1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
23#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
25pub enum DelayType {
26 Tick,
28 TickLazy,
30}
31
32pub enum PortListSpec {
34 Variadic,
36 Fixed(Punctuated<PortIndex, Token![,]>),
38}
39
40pub struct OperatorConstraints {
42 pub name: &'static str,
44 pub categories: &'static [OperatorCategory],
46
47 pub hard_range_inn: &'static dyn RangeTrait<usize>,
50 pub soft_range_inn: &'static dyn RangeTrait<usize>,
52 pub hard_range_out: &'static dyn RangeTrait<usize>,
54 pub soft_range_out: &'static dyn RangeTrait<usize>,
56 pub num_args: usize,
58 pub persistence_args: &'static dyn RangeTrait<usize>,
60 pub type_args: &'static dyn RangeTrait<usize>,
64 pub is_external_input: bool,
67 pub flo_type: Option<FloType>,
69
70 pub ports_inn: Option<fn() -> PortListSpec>,
72 pub ports_out: Option<fn() -> PortListSpec>,
74
75 pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
77 pub write_fn: WriteFn,
79}
80
81pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
83
84impl Debug for OperatorConstraints {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("OperatorConstraints")
87 .field("name", &self.name)
88 .field("hard_range_inn", &self.hard_range_inn)
89 .field("soft_range_inn", &self.soft_range_inn)
90 .field("hard_range_out", &self.hard_range_out)
91 .field("soft_range_out", &self.soft_range_out)
92 .field("num_args", &self.num_args)
93 .field("persistence_args", &self.persistence_args)
94 .field("type_args", &self.type_args)
95 .field("is_external_input", &self.is_external_input)
96 .field("ports_inn", &self.ports_inn)
97 .field("ports_out", &self.ports_out)
98 .finish()
102 }
103}
104
105#[derive(Default)]
109pub struct OperatorWriteOutput {
110 pub write_prologue: TokenStream,
113 pub write_iterator: TokenStream,
120 pub write_iterator_after: TokenStream,
122 pub write_tick_end: TokenStream,
125}
126
127pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
129pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
131pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
133
134pub fn identity_write_iterator_fn(
137 &WriteContextArgs {
138 root,
139 op_span,
140 ident,
141 inputs,
142 outputs,
143 is_pull,
144 op_inst:
145 OperatorInstance {
146 generics: OpInstGenerics { type_args, .. },
147 ..
148 },
149 ..
150 }: &WriteContextArgs,
151) -> TokenStream {
152 let generic_type = type_args
153 .first()
154 .map(quote::ToTokens::to_token_stream)
155 .unwrap_or(quote_spanned!(op_span=> _));
156
157 if is_pull {
158 let input = &inputs[0];
159 quote_spanned! {op_span=>
160 let #ident = {
161 fn check_input<Pull, Item>(pull: Pull) -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = Pull::Meta, CanPend = Pull::CanPend, CanEnd = Pull::CanEnd>
162 where
163 Pull: #root::dfir_pipes::pull::Pull<Item = Item>,
164 {
165 pull
166 }
167 check_input::<_, #generic_type>(#input)
168 };
169 }
170 } else {
171 let output = &outputs[0];
172 quote_spanned! {op_span=>
173 let #ident = {
174 fn check_output<Psh, Item>(push: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
175 where
176 Psh: #root::dfir_pipes::push::Push<Item, ()>,
177 {
178 push
179 }
180 check_output::<_, #generic_type>(#output)
181 };
182 }
183 }
184}
185
186pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
188 let write_iterator = identity_write_iterator_fn(write_context_args);
189 Ok(OperatorWriteOutput {
190 write_iterator,
191 ..Default::default()
192 })
193};
194
195pub fn null_write_iterator_fn(
198 &WriteContextArgs {
199 root,
200 op_span,
201 ident,
202 inputs,
203 outputs,
204 is_pull,
205 op_inst:
206 OperatorInstance {
207 generics: OpInstGenerics { type_args, .. },
208 ..
209 },
210 ..
211 }: &WriteContextArgs,
212) -> TokenStream {
213 let default_type = parse_quote_spanned! {op_span=> _};
214 let iter_type = type_args.first().unwrap_or(&default_type);
215
216 if is_pull {
217 quote_spanned! {op_span=>
218 let #ident = #root::dfir_pipes::pull::poll_fn({
219 #(
220 let mut #inputs = ::std::boxed::Box::pin(#inputs);
221 )*
222 move |_cx| {
223 #(
227 let #inputs = #root::dfir_pipes::pull::Pull::pull(
228 ::std::pin::Pin::as_mut(&mut #inputs),
229 <_ as #root::dfir_pipes::Context>::from_task(_cx),
230 );
231 )*
232 #(
233 if let #root::dfir_pipes::pull::PullStep::Pending(_) = #inputs {
234 return #root::dfir_pipes::pull::PullStep::Pending(#root::dfir_pipes::Yes);
235 }
236 )*
237 #root::dfir_pipes::pull::PullStep::<_, _, #root::dfir_pipes::Yes, _>::Ended(#root::dfir_pipes::Yes)
238 }
239 });
240 }
241 } else {
242 quote_spanned! {op_span=>
243 #[allow(clippy::let_unit_value)]
244 let _ = (#(#outputs),*);
245 let #ident = #root::dfir_pipes::push::for_each::<_, #iter_type>(::std::mem::drop::<#iter_type>);
246 }
247 }
248}
249
250pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
253 let write_iterator = null_write_iterator_fn(write_context_args);
254 Ok(OperatorWriteOutput {
255 write_iterator,
256 ..Default::default()
257 })
258};
259
/// Declares one `pub(crate)` submodule per operator and collects each module's operator
/// constant into `OPERATORS`, preserving the order in which they are listed.
///
/// Each entry has the form `module::CONSTANT,` (the trailing comma is required).
macro_rules! declare_ops {
    ( $( $module:ident :: $constant:ident, )* ) => {
        $(
            pub(crate) mod $module;
        )*

        /// Every built-in operator, in declaration order.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $module :: $constant, )*
        ];
    };
}
// Registry of all built-in operators. For each `module::CONSTANT` entry, `declare_ops!` emits a
// `pub(crate) mod module;` declaration and places `module::CONSTANT` into `OPERATORS`, in the
// order listed here.
// NOTE(review): the list is only partly alphabetical (e.g. `fold_keyed`, `union` and the
// `defer_*` operators are out of place); confirm nothing depends on `OPERATORS` order before
// re-sorting it.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    chain_first_n::CHAIN_FIRST_N,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flat_map_stream_blocking::FLAT_MAP_STREAM_BLOCKING,
    flatten::FLATTEN,
    flatten_stream_blocking::FLATTEN_STREAM_BLOCKING,
    fold::FOLD,
    fold_no_replay::FOLD_NO_REPLAY,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    iter_ref::ITER_REF,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    join_multiset_half::JOIN_MULTISET_HALF,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    resolve_futures::RESOLVE_FUTURES,
    resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
    resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
    resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
    reduce::REDUCE,
    reduce_no_replay::REDUCE_NO_REPLAY,
    scan::SCAN,
    scan_async_blocking::SCAN_ASYNC_BLOCKING,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
355
356pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
358 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
359 OnceLock::new();
360 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
361}
362pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
364 if let GraphNode::Operator(operator) = node {
365 find_op_op_constraints(operator)
366 } else {
367 None
368 }
369}
370pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
372 let name = &*operator.name_string();
373 operator_lookup().get(name).copied()
374}
375
376#[derive(Clone)]
378pub struct WriteContextArgs<'a> {
379 pub root: &'a TokenStream,
381 pub context: &'a Ident,
384 pub df_ident: &'a Ident,
388 pub subgraph_id: GraphSubgraphId,
390 pub node_id: GraphNodeId,
392 pub loop_id: Option<GraphLoopId>,
394 pub op_span: Span,
396 pub op_tag: Option<String>,
398 pub work_fn: &'a Ident,
400 pub work_fn_async: &'a Ident,
402
403 pub ident: &'a Ident,
405 pub is_pull: bool,
407 pub inputs: &'a [Ident],
409 pub outputs: &'a [Ident],
411
412 pub op_name: &'static str,
414 pub op_inst: &'a OperatorInstance,
416 pub arguments: &'a Punctuated<Expr, Token![,]>,
422}
423impl WriteContextArgs<'_> {
424 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
430 Ident::new(
431 &format!(
432 "sg_{:?}_node_{:?}_{}",
433 self.subgraph_id.data(),
434 self.node_id.data(),
435 suffix.as_ref(),
436 ),
437 self.op_span,
438 )
439 }
440
441 pub fn persistence_args_disallow_mutable<const N: usize>(
443 &self,
444 diagnostics: &mut Diagnostics,
445 ) -> [Persistence; N] {
446 let len = self.op_inst.generics.persistence_args.len();
447 if 0 != len && 1 != len && N != len {
448 diagnostics.push(Diagnostic::spanned(
449 self.op_span,
450 Level::Error,
451 format!(
452 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
453 self.op_name, N
454 ),
455 ));
456 }
457
458 let default_persistence = if self.loop_id.is_some() {
459 Persistence::None
460 } else {
461 Persistence::Tick
462 };
463 let mut out = [default_persistence; N];
464 self.op_inst
465 .generics
466 .persistence_args
467 .iter()
468 .copied()
469 .cycle() .take(N)
471 .enumerate()
472 .filter(|&(_i, p)| {
473 if p == Persistence::Mutable {
474 diagnostics.push(Diagnostic::spanned(
475 self.op_span,
476 Level::Error,
477 format!(
478 "An implementation of `'{}` does not exist",
479 p.to_str_lowercase()
480 ),
481 ));
482 false
483 } else {
484 true
485 }
486 })
487 .for_each(|(i, p)| {
488 out[i] = p;
489 });
490 out
491 }
492}
493
/// Object-safe view of a range's bounds, so that ranges of different concrete types can be
/// stored as `&'static dyn RangeTrait<usize>` (e.g. the arity constraints on operators).
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;

    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;

    /// Returns `true` if `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Renders the range as an English phrase for diagnostics, such as `"exactly 1"` or
    /// `"at least 2 and less than 5"`.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        use std::ops::Bound::{Excluded, Included, Unbounded};

        match (self.start_bound(), self.end_bound()) {
            (Unbounded, Unbounded) => String::from("any number of"),
            (Unbounded, Included(hi)) => format!("at most {}", hi),
            (Unbounded, Excluded(hi)) => format!("less than {}", hi),
            (Included(lo), Unbounded) => format!("at least {}", lo),
            (Included(lo), Included(hi)) if lo == hi => format!("exactly {}", lo),
            (Included(lo), Included(hi)) => format!("at least {} and at most {}", lo, hi),
            (Included(lo), Excluded(hi)) => format!("at least {} and less than {}", lo, hi),
            (Excluded(lo), Unbounded) => format!("more than {}", lo),
            (Excluded(lo), Included(hi)) => format!("more than {} and at most {}", lo, hi),
            (Excluded(lo), Excluded(hi)) => format!("more than {} and less than {}", lo, hi),
        }
    }
}
538
539impl<R, T> RangeTrait<T> for R
540where
541 R: RangeBounds<T> + Send + Sync + Debug,
542{
543 fn start_bound(&self) -> Bound<&T> {
544 self.start_bound()
545 }
546
547 fn end_bound(&self) -> Bound<&T> {
548 self.end_bound()
549 }
550
551 fn contains(&self, item: &T) -> bool
552 where
553 T: PartialOrd<T>,
554 {
555 self.contains(item)
556 }
557}
558
559#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
561pub enum Persistence {
562 None,
564 Loop,
566 Tick,
568 Static,
570 Mutable,
572}
573impl Persistence {
574 pub fn to_str_lowercase(self) -> &'static str {
576 match self {
577 Persistence::None => "none",
578 Persistence::Tick => "tick",
579 Persistence::Loop => "loop",
580 Persistence::Static => "static",
581 Persistence::Mutable => "mutable",
582 }
583 }
584}
585
586fn make_missing_runtime_msg(op_name: &str) -> Literal {
588 Literal::string(&format!(
589 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
590 op_name
591 ))
592}
593
/// Grouping used to categorize operators (see `OperatorConstraints::categories`).
///
/// NOTE: `name()` and `description()` read each variant's doc comment at runtime through
/// `DocumentedVariants` and split it at the first `:`. Variant docs must therefore follow the
/// `Name: description` format, or those methods panic.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}
628impl OperatorCategory {
629 pub fn name(self) -> &'static str {
631 self.get_variant_docs().split_once(":").unwrap().0
632 }
633 pub fn description(self) -> &'static str {
635 self.get_variant_docs().split_once(":").unwrap().1
636 }
637}
638
/// Role an operator plays in the graph's flow structure, if it has one
/// (see `OperatorConstraints::flo_type`).
///
/// Variant declaration order is significant because `Ord` is derived.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum FloType {
    /// NOTE(review): presumably operators that introduce data into the graph — confirm.
    Source,
    /// NOTE(review): presumably operators that group a stream into windows — confirm.
    Windowing,
    /// NOTE(review): presumably operators that flatten windows back into a stream — confirm.
    Unwindowing,
    /// NOTE(review): presumably the `next_iteration()` operator's role — confirm.
    NextIteration,
}