1use std::collections::VecDeque;
7use std::fmt;
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use rustc_hash::{FxHashMap, FxHashSet};
12use smallvec::SmallVec;
13
14use super::error::DagError;
15
/// Upper bound on the number of outgoing edges a single node may have.
///
/// The limit is enforced when the DAG is validated or finalized; a node whose
/// output count exceeds it is reported as `DagError::FanOutLimitExceeded`.
pub const MAX_FAN_OUT: usize = 8;
20
/// Identifier of a node inside a DAG, wrapping a raw `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u32);

impl fmt::Display for NodeId {
    /// Renders the identifier as `NodeId(<raw>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let NodeId(raw) = *self;
        write!(f, "NodeId({raw})")
    }
}
30
/// Identifier of an edge inside a DAG, wrapping a raw `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EdgeId(pub u32);

impl fmt::Display for EdgeId {
    /// Renders the identifier as `EdgeId(<raw>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let EdgeId(raw) = *self;
        write!(f, "EdgeId({raw})")
    }
}
40
/// Identifier of the state partition assigned to a node, wrapping a raw `u32`.
///
/// `StreamingDag::add_node` gives each new node a partition id equal to the
/// raw value of its node id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StatePartitionId(pub u32);
44
/// Role a node plays in the DAG.
///
/// The connectivity check treats the variants differently: a `Source` must
/// have at least one output, a `Sink` must have at least one input, and every
/// other kind must have at least one input or output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DagNodeType {
    /// Entry point of data (base tables are registered with this type).
    Source,
    /// Operator declared as stateless.
    StatelessOperator,
    /// Operator declared as stateful.
    StatefulOperator,
    /// Materialized view node.
    MaterializedView,
    /// Terminal consumer of data.
    Sink,
}
59
/// Channel flavour assigned to an edge, derived from the fan-in of the edge's
/// target and the fan-out of its source when the DAG is finalized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DagChannelType {
    /// Single producer, single consumer: the source has at most one output
    /// and the target has at most one input. Also the value every edge starts
    /// with before finalization.
    Spsc,
    /// Single producer, multiple consumers: the source has more than one
    /// output and the target has at most one input.
    Spmc,
    /// Multiple producers, single consumer: the target has more than one
    /// input (regardless of the source's fan-out).
    Mpsc,
}
75
/// Shareable, thread-safe callback that maps a byte slice to a `usize`.
///
/// NOTE(review): presumably the input is a serialized key and the result a
/// partition index — confirm against the code that invokes it.
pub type PartitionFn = Arc<dyn Fn(&[u8]) -> usize + Send + Sync>;

/// How records travelling along an edge are distributed.
#[derive(Clone, Default)]
pub enum PartitioningStrategy {
    /// No partitioning; this is the default.
    #[default]
    Single,
    /// Round-robin distribution.
    RoundRobin,
    /// Hash-based distribution keyed by the given string.
    HashBy(String),
    /// Distribution decided by a user-supplied [`PartitionFn`].
    Custom(PartitionFn),
}

impl fmt::Debug for PartitioningStrategy {
    /// Hand-written because the closure held by `Custom` is not `Debug`; it
    /// is printed as the placeholder `Custom(...)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Single => f.write_str("Single"),
            Self::RoundRobin => f.write_str("RoundRobin"),
            Self::HashBy(key) => write!(f, "HashBy({key})"),
            Self::Custom(_) => f.write_str("Custom(...)"),
        }
    }
}
103
104pub struct DagNode {
109 pub id: NodeId,
111 pub name: String,
113 pub inputs: SmallVec<[EdgeId; 4]>,
115 pub outputs: SmallVec<[EdgeId; 4]>,
117 pub output_schema: SchemaRef,
119 pub state_partition: StatePartitionId,
121 pub node_type: DagNodeType,
123}
124
125impl fmt::Debug for DagNode {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("DagNode")
128 .field("id", &self.id)
129 .field("name", &self.name)
130 .field("inputs", &self.inputs)
131 .field("outputs", &self.outputs)
132 .field("node_type", &self.node_type)
133 .field("state_partition", &self.state_partition)
134 .finish_non_exhaustive()
135 }
136}
137
138#[derive(Debug)]
143pub struct DagEdge {
144 pub id: EdgeId,
146 pub source: NodeId,
148 pub target: NodeId,
150 pub channel_type: DagChannelType,
152 pub partitioning: PartitioningStrategy,
154 pub source_port: u8,
156 pub target_port: u8,
158}
159
160#[derive(Debug)]
164pub struct SharedStageMetadata {
165 pub producer_node: NodeId,
167 pub consumer_count: usize,
169 pub consumer_nodes: Vec<NodeId>,
171}
172
173pub struct StreamingDag {
178 nodes: FxHashMap<NodeId, DagNode>,
180 edges: FxHashMap<EdgeId, DagEdge>,
182 execution_order: Vec<NodeId>,
185 shared_stages: FxHashMap<NodeId, SharedStageMetadata>,
187 source_nodes: Vec<NodeId>,
189 sink_nodes: Vec<NodeId>,
191 name_index: FxHashMap<String, NodeId>,
193 next_node_id: u32,
195 next_edge_id: u32,
197 finalized: bool,
199}
200
201impl fmt::Debug for StreamingDag {
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 f.debug_struct("StreamingDag")
204 .field("node_count", &self.nodes.len())
205 .field("edge_count", &self.edges.len())
206 .field("source_nodes", &self.source_nodes)
207 .field("sink_nodes", &self.sink_nodes)
208 .field("execution_order", &self.execution_order)
209 .field("finalized", &self.finalized)
210 .finish_non_exhaustive()
211 }
212}
213
214impl StreamingDag {
215 #[must_use]
217 pub fn new() -> Self {
218 Self {
219 nodes: FxHashMap::default(),
220 edges: FxHashMap::default(),
221 execution_order: Vec::new(),
222 shared_stages: FxHashMap::default(),
223 source_nodes: Vec::new(),
224 sink_nodes: Vec::new(),
225 name_index: FxHashMap::default(),
226 next_node_id: 0,
227 next_edge_id: 0,
228 finalized: false,
229 }
230 }
231
232 pub fn add_node(
238 &mut self,
239 name: impl Into<String>,
240 node_type: DagNodeType,
241 output_schema: SchemaRef,
242 ) -> Result<NodeId, DagError> {
243 let name = name.into();
244 if self.name_index.contains_key(&name) {
245 return Err(DagError::DuplicateNode(name));
246 }
247
248 let id = NodeId(self.next_node_id);
249 self.next_node_id += 1;
250
251 let node = DagNode {
252 id,
253 name: name.clone(),
254 inputs: SmallVec::new(),
255 outputs: SmallVec::new(),
256 output_schema,
257 state_partition: StatePartitionId(id.0),
258 node_type,
259 };
260
261 self.nodes.insert(id, node);
262 self.name_index.insert(name, id);
263 self.finalized = false;
264
265 Ok(id)
266 }
267
268 pub fn add_edge(&mut self, source: NodeId, target: NodeId) -> Result<EdgeId, DagError> {
275 if source == target {
277 let name = self.node_name(source).unwrap_or_default();
278 return Err(DagError::CycleDetected(name));
279 }
280
281 if !self.nodes.contains_key(&source) {
282 return Err(DagError::NodeNotFound(format!("{source}")));
283 }
284 if !self.nodes.contains_key(&target) {
285 return Err(DagError::NodeNotFound(format!("{target}")));
286 }
287
288 let id = EdgeId(self.next_edge_id);
289 self.next_edge_id += 1;
290
291 #[allow(clippy::cast_possible_truncation)]
293 let source_port = self.nodes.get(&source).map_or(0, |n| n.outputs.len() as u8);
294 #[allow(clippy::cast_possible_truncation)]
295 let target_port = self.nodes.get(&target).map_or(0, |n| n.inputs.len() as u8);
296
297 let edge = DagEdge {
298 id,
299 source,
300 target,
301 channel_type: DagChannelType::Spsc, partitioning: PartitioningStrategy::default(),
303 source_port,
304 target_port,
305 };
306
307 self.edges.insert(id, edge);
308
309 if let Some(node) = self.nodes.get_mut(&source) {
311 node.outputs.push(id);
312 }
313 if let Some(node) = self.nodes.get_mut(&target) {
314 node.inputs.push(id);
315 }
316
317 self.finalized = false;
318
319 Ok(id)
320 }
321
322 pub fn finalize(&mut self) -> Result<(), DagError> {
334 if self.nodes.is_empty() {
335 return Err(DagError::EmptyDag);
336 }
337
338 self.check_fan_out_limits()?;
339 self.compute_execution_order()?;
340 self.check_connected()?;
341 self.derive_channel_types();
342 self.identify_shared_stages();
343 self.classify_source_sink_nodes();
344 self.finalized = true;
345
346 Ok(())
347 }
348
349 pub fn validate(&self) -> Result<(), DagError> {
355 if self.nodes.is_empty() {
356 return Err(DagError::EmptyDag);
357 }
358 self.check_fan_out_limits()?;
359 self.check_acyclic()?;
360 self.check_connected()?;
361 self.check_schemas()?;
362 Ok(())
363 }
364
365 #[must_use]
369 pub fn node_count(&self) -> usize {
370 self.nodes.len()
371 }
372
373 #[must_use]
375 pub fn edge_count(&self) -> usize {
376 self.edges.len()
377 }
378
379 #[must_use]
381 pub fn node(&self, id: NodeId) -> Option<&DagNode> {
382 self.nodes.get(&id)
383 }
384
385 #[must_use]
387 pub fn edge(&self, id: EdgeId) -> Option<&DagEdge> {
388 self.edges.get(&id)
389 }
390
391 #[must_use]
393 pub fn nodes(&self) -> &FxHashMap<NodeId, DagNode> {
394 &self.nodes
395 }
396
397 #[must_use]
399 pub fn edges(&self) -> &FxHashMap<EdgeId, DagEdge> {
400 &self.edges
401 }
402
403 #[must_use]
405 pub fn node_id_by_name(&self, name: &str) -> Option<NodeId> {
406 self.name_index.get(name).copied()
407 }
408
409 #[must_use]
411 pub fn node_name(&self, id: NodeId) -> Option<String> {
412 self.nodes.get(&id).map(|n| n.name.clone())
413 }
414
415 #[inline]
417 #[must_use]
418 pub fn outgoing_edge_count(&self, node: NodeId) -> usize {
419 self.nodes.get(&node).map_or(0, |n| n.outputs.len())
420 }
421
422 #[inline]
424 #[must_use]
425 pub fn incoming_edge_count(&self, node: NodeId) -> usize {
426 self.nodes.get(&node).map_or(0, |n| n.inputs.len())
427 }
428
429 #[must_use]
431 pub fn sources(&self) -> &[NodeId] {
432 &self.source_nodes
433 }
434
435 #[must_use]
437 pub fn sinks(&self) -> &[NodeId] {
438 &self.sink_nodes
439 }
440
441 #[must_use]
443 pub fn execution_order(&self) -> &[NodeId] {
444 &self.execution_order
445 }
446
447 #[must_use]
449 pub fn shared_stages(&self) -> &FxHashMap<NodeId, SharedStageMetadata> {
450 &self.shared_stages
451 }
452
453 #[must_use]
455 pub fn is_finalized(&self) -> bool {
456 self.finalized
457 }
458
459 fn check_fan_out_limits(&self) -> Result<(), DagError> {
463 for node in self.nodes.values() {
464 if node.outputs.len() > MAX_FAN_OUT {
465 return Err(DagError::FanOutLimitExceeded {
466 node: node.name.clone(),
467 count: node.outputs.len(),
468 max: MAX_FAN_OUT,
469 });
470 }
471 }
472 Ok(())
473 }
474
475 fn check_acyclic(&self) -> Result<(), DagError> {
480 let (order, _) = self.kahn_topo_sort();
481 if order.len() < self.nodes.len() {
482 let ordered_set: FxHashSet<NodeId> = order.into_iter().collect();
484 for node in self.nodes.values() {
485 if !ordered_set.contains(&node.id) {
486 return Err(DagError::CycleDetected(node.name.clone()));
487 }
488 }
489 return Err(DagError::CycleDetected("unknown".to_string()));
490 }
491 Ok(())
492 }
493
494 fn check_connected(&self) -> Result<(), DagError> {
496 for node in self.nodes.values() {
497 match node.node_type {
498 DagNodeType::Source => {
499 if node.outputs.is_empty() {
500 return Err(DagError::DisconnectedNode(node.name.clone()));
501 }
502 }
503 DagNodeType::Sink => {
504 if node.inputs.is_empty() {
505 return Err(DagError::DisconnectedNode(node.name.clone()));
506 }
507 }
508 _ => {
509 if node.inputs.is_empty() && node.outputs.is_empty() {
510 return Err(DagError::DisconnectedNode(node.name.clone()));
511 }
512 }
513 }
514 }
515 Ok(())
516 }
517
518 fn check_schemas(&self) -> Result<(), DagError> {
523 for edge in self.edges.values() {
524 let source_node = self.nodes.get(&edge.source);
525 let target_node = self.nodes.get(&edge.target);
526
527 if let (Some(source), Some(target)) = (source_node, target_node) {
528 let source_schema = &source.output_schema;
529 let target_schema = &target.output_schema;
530
531 if source_schema.fields().is_empty() || target_schema.fields().is_empty() {
533 continue;
534 }
535
536 if source_schema.fields().len() != target_schema.fields().len() {
538 return Err(DagError::SchemaMismatch {
539 source_node: source.name.clone(),
540 target_node: target.name.clone(),
541 reason: format!(
542 "field count mismatch: {} vs {}",
543 source_schema.fields().len(),
544 target_schema.fields().len()
545 ),
546 });
547 }
548
549 for (sf, tf) in source_schema
550 .fields()
551 .iter()
552 .zip(target_schema.fields().iter())
553 {
554 if sf.data_type() != tf.data_type() {
555 return Err(DagError::SchemaMismatch {
556 source_node: source.name.clone(),
557 target_node: target.name.clone(),
558 reason: format!(
559 "type mismatch for field '{}': {:?} vs '{}':{:?}",
560 sf.name(),
561 sf.data_type(),
562 tf.name(),
563 tf.data_type()
564 ),
565 });
566 }
567 }
568 }
569 }
570 Ok(())
571 }
572
573 fn compute_execution_order(&mut self) -> Result<(), DagError> {
578 let (order, processed) = self.kahn_topo_sort();
579 if processed < self.nodes.len() {
580 let ordered_set: FxHashSet<NodeId> = order.iter().copied().collect();
581 for node in self.nodes.values() {
582 if !ordered_set.contains(&node.id) {
583 return Err(DagError::CycleDetected(node.name.clone()));
584 }
585 }
586 return Err(DagError::CycleDetected("unknown".to_string()));
587 }
588 self.execution_order = order;
589 Ok(())
590 }
591
592 fn kahn_topo_sort(&self) -> (Vec<NodeId>, usize) {
596 let mut in_degree: FxHashMap<NodeId, usize> = FxHashMap::default();
598 for node in self.nodes.values() {
599 in_degree.entry(node.id).or_insert(0);
600 }
601 for edge in self.edges.values() {
602 *in_degree.entry(edge.target).or_insert(0) += 1;
603 }
604
605 let mut queue: VecDeque<NodeId> = VecDeque::new();
607 for (&node_id, °) in &in_degree {
608 if deg == 0 {
609 queue.push_back(node_id);
610 }
611 }
612
613 let mut initial: Vec<NodeId> = queue.drain(..).collect();
615 initial.sort_by_key(|n| n.0);
616 for id in initial {
617 queue.push_back(id);
618 }
619
620 let mut order = Vec::with_capacity(self.nodes.len());
621 let mut processed = 0;
622
623 while let Some(node_id) = queue.pop_front() {
624 order.push(node_id);
625 processed += 1;
626
627 if let Some(node) = self.nodes.get(&node_id) {
629 let mut successors: Vec<NodeId> = Vec::new();
630 for &edge_id in &node.outputs {
631 if let Some(edge) = self.edges.get(&edge_id) {
632 let target = edge.target;
633 if let Some(deg) = in_degree.get_mut(&target) {
634 *deg = deg.saturating_sub(1);
635 if *deg == 0 {
636 successors.push(target);
637 }
638 }
639 }
640 }
641 successors.sort_by_key(|n| n.0);
642 queue.extend(successors);
643 }
644 }
645
646 (order, processed)
647 }
648
649 fn derive_channel_types(&mut self) {
655 let edge_ids: Vec<EdgeId> = self.edges.keys().copied().collect();
656
657 for edge_id in edge_ids {
658 let (source_fan_out, target_fan_in) = {
659 let edge = &self.edges[&edge_id];
660 (
661 self.outgoing_edge_count(edge.source),
662 self.incoming_edge_count(edge.target),
663 )
664 };
665
666 let channel_type = match (target_fan_in > 1, source_fan_out > 1) {
667 (false, false) => DagChannelType::Spsc,
668 (false, true) => DagChannelType::Spmc,
669 (true, _) => DagChannelType::Mpsc,
673 };
674
675 if let Some(edge) = self.edges.get_mut(&edge_id) {
676 edge.channel_type = channel_type;
677 }
678 }
679 }
680
681 fn identify_shared_stages(&mut self) {
683 self.shared_stages.clear();
684
685 for node in self.nodes.values() {
686 if node.outputs.len() > 1 {
687 let consumer_nodes: Vec<NodeId> = node
688 .outputs
689 .iter()
690 .filter_map(|&edge_id| self.edges.get(&edge_id).map(|e| e.target))
691 .collect();
692
693 self.shared_stages.insert(
694 node.id,
695 SharedStageMetadata {
696 producer_node: node.id,
697 consumer_count: consumer_nodes.len(),
698 consumer_nodes,
699 },
700 );
701 }
702 }
703 }
704
705 fn classify_source_sink_nodes(&mut self) {
707 self.source_nodes.clear();
708 self.sink_nodes.clear();
709
710 for node in self.nodes.values() {
711 if node.inputs.is_empty() {
712 self.source_nodes.push(node.id);
713 }
714 if node.outputs.is_empty() {
715 self.sink_nodes.push(node.id);
716 }
717 }
718
719 self.source_nodes.sort_by_key(|n| n.0);
721 self.sink_nodes.sort_by_key(|n| n.0);
722 }
723}
724
725impl StreamingDag {
726 pub fn from_mv_registry(
742 registry: &crate::mv::MvRegistry,
743 base_table_schemas: &FxHashMap<String, SchemaRef>,
744 ) -> Result<Self, DagError> {
745 if registry.is_empty() && registry.base_tables().is_empty() {
746 return Err(DagError::EmptyDag);
747 }
748
749 let mut dag = Self::new();
750
751 for base_table in registry.base_tables() {
753 let schema = base_table_schemas
754 .get(base_table)
755 .ok_or_else(|| DagError::BaseTableSchemaNotFound(base_table.clone()))?;
756 dag.add_node(base_table, DagNodeType::Source, schema.clone())?;
757 }
758
759 for mv_name in registry.topo_order() {
761 let mv = registry
762 .get(mv_name)
763 .ok_or_else(|| DagError::NodeNotFound(mv_name.clone()))?;
764 dag.add_node(mv_name, DagNodeType::MaterializedView, mv.schema.clone())?;
765 }
766
767 for mv_name in registry.topo_order() {
769 let mv = registry
770 .get(mv_name)
771 .ok_or_else(|| DagError::NodeNotFound(mv_name.clone()))?;
772 let target_id = dag
773 .node_id_by_name(mv_name)
774 .ok_or_else(|| DagError::NodeNotFound(mv_name.clone()))?;
775 for source_name in &mv.sources {
776 let source_id = dag
777 .node_id_by_name(source_name)
778 .ok_or_else(|| DagError::NodeNotFound(source_name.clone()))?;
779 dag.add_edge(source_id, target_id)?;
780 }
781 }
782
783 dag.finalize()?;
784 Ok(dag)
785 }
786}
787
788impl Default for StreamingDag {
789 fn default() -> Self {
790 Self::new()
791 }
792}