//! Pipeline extraction from `DataFusion` logical plans.
//!
//! [`PipelineExtractor`] walks a [`LogicalPlan`] top-down, decomposing it into
//! compilable [`Pipeline`] segments separated by [`PipelineBreaker`]s (stateful
//! operators like aggregations, sorts, and joins).

use std::sync::Arc;

use arrow_schema::SchemaRef;
use datafusion_expr::LogicalPlan;

use super::error::ExtractError;
use super::pipeline::{Pipeline, PipelineBreaker, PipelineId, PipelineStage};
/// The result of extracting pipelines from a logical plan.
///
/// Returned by [`PipelineExtractor::extract`]. Pipelines are the nodes of the
/// extracted dataflow; `breakers` are its directed edges.
#[derive(Debug)]
pub struct ExtractedPlan {
    /// Compilable pipeline segments.
    pub pipelines: Vec<Pipeline>,
    /// Breakers connecting pipelines: `(upstream_id, breaker, downstream_id)`.
    pub breakers: Vec<(PipelineId, PipelineBreaker, PipelineId)>,
    /// Pipeline IDs that read from sources (table scans).
    pub sources: Vec<PipelineId>,
    /// Pipeline IDs that write to sinks (terminal outputs).
    pub sinks: Vec<PipelineId>,
}
27
28/// Mutable state accumulated during plan extraction.
29struct ExtractionContext {
30    pipelines: Vec<PipelineBuilder>,
31    breakers: Vec<(PipelineId, PipelineBreaker, PipelineId)>,
32    sources: Vec<PipelineId>,
33    sinks: Vec<PipelineId>,
34    next_id: u32,
35}
36
/// In-progress pipeline being built during extraction.
struct PipelineBuilder {
    /// Identifier assigned by `ExtractionContext::new_pipeline`.
    id: PipelineId,
    /// Stages appended so far, in execution order.
    stages: Vec<PipelineStage>,
    /// Schema of rows entering the pipeline.
    input_schema: SchemaRef,
    /// Schema of rows leaving the pipeline; updated as stages (e.g.
    /// projections) change the row shape.
    output_schema: SchemaRef,
}
44
45impl ExtractionContext {
46    fn new() -> Self {
47        Self {
48            pipelines: Vec::new(),
49            breakers: Vec::new(),
50            sources: Vec::new(),
51            sinks: Vec::new(),
52            next_id: 0,
53        }
54    }
55
56    /// Creates a new pipeline with the given schemas and returns its ID.
57    fn new_pipeline(&mut self, input_schema: SchemaRef, output_schema: SchemaRef) -> PipelineId {
58        let id = PipelineId(self.next_id);
59        self.next_id += 1;
60        self.pipelines.push(PipelineBuilder {
61            id,
62            stages: Vec::new(),
63            input_schema,
64            output_schema,
65        });
66        id
67    }
68
69    /// Adds a stage to the pipeline with the given ID.
70    fn add_stage(&mut self, pipeline_id: PipelineId, stage: PipelineStage) {
71        let builder = self
72            .pipelines
73            .iter_mut()
74            .find(|p| p.id == pipeline_id)
75            .expect("pipeline not found");
76        builder.stages.push(stage);
77    }
78
79    /// Updates the output schema of a pipeline (e.g., after a Projection stage).
80    fn update_output_schema(&mut self, pipeline_id: PipelineId, schema: SchemaRef) {
81        let builder = self
82            .pipelines
83            .iter_mut()
84            .find(|p| p.id == pipeline_id)
85            .expect("pipeline not found");
86        builder.output_schema = schema;
87    }
88
89    /// Finalizes all pipeline builders into an [`ExtractedPlan`].
90    fn finalize(self) -> ExtractedPlan {
91        let pipelines = self
92            .pipelines
93            .into_iter()
94            .map(|b| Pipeline {
95                id: b.id,
96                stages: b.stages,
97                input_schema: b.input_schema,
98                output_schema: b.output_schema,
99            })
100            .collect();
101
102        ExtractedPlan {
103            pipelines,
104            breakers: self.breakers,
105            sources: self.sources,
106            sinks: self.sinks,
107        }
108    }
109}
110
111/// Extracts compilable pipelines from `DataFusion` logical plans.
112pub struct PipelineExtractor;
113
114impl PipelineExtractor {
115    /// Extracts pipelines from a [`LogicalPlan`].
116    ///
117    /// Walks the plan top-down, collecting compilable stages (Filter, Projection)
118    /// into pipelines and breaking at stateful operators (Aggregate, Sort, Join).
119    ///
120    /// # Errors
121    ///
122    /// Returns [`ExtractError`] if the plan contains unsupported nodes.
123    pub fn extract(plan: &LogicalPlan) -> Result<ExtractedPlan, ExtractError> {
124        let mut ctx = ExtractionContext::new();
125        let terminal_id = extract_impl(plan, &mut ctx)?;
126
127        // The terminal pipeline is a sink.
128        ctx.sinks.push(terminal_id);
129
130        Ok(ctx.finalize())
131    }
132}
133
134/// Converts a `DFSchemaRef` (via the plan's `schema()` method) to an Arrow `SchemaRef`.
135fn arrow_schema(plan: &LogicalPlan) -> SchemaRef {
136    Arc::new(plan.schema().as_arrow().clone())
137}
138
139/// Recursively extracts pipelines from a logical plan node.
140#[allow(clippy::too_many_lines)]
141fn extract_impl(
142    plan: &LogicalPlan,
143    ctx: &mut ExtractionContext,
144) -> Result<PipelineId, ExtractError> {
145    match plan {
146        LogicalPlan::TableScan(scan) => {
147            let schema = Arc::new(scan.projected_schema.as_arrow().clone());
148            let pipeline_id = ctx.new_pipeline(Arc::clone(&schema), schema);
149            ctx.sources.push(pipeline_id);
150            Ok(pipeline_id)
151        }
152
153        LogicalPlan::EmptyRelation(empty) => {
154            let schema = Arc::new(empty.schema.as_arrow().clone());
155            let pipeline_id = ctx.new_pipeline(Arc::clone(&schema), schema);
156            ctx.sources.push(pipeline_id);
157            Ok(pipeline_id)
158        }
159
160        LogicalPlan::Filter(filter) => {
161            let child_id = extract_impl(&filter.input, ctx)?;
162            ctx.add_stage(
163                child_id,
164                PipelineStage::Filter {
165                    predicate: filter.predicate.clone(),
166                },
167            );
168            // Filter does not change the schema.
169            Ok(child_id)
170        }
171
172        LogicalPlan::Projection(proj) => {
173            let child_id = extract_impl(&proj.input, ctx)?;
174
175            let arrow_out = proj.schema.as_arrow();
176            let expressions: Vec<(datafusion_expr::Expr, String)> = proj
177                .expr
178                .iter()
179                .enumerate()
180                .map(|(i, expr)| (expr.clone(), arrow_out.field(i).name().clone()))
181                .collect();
182
183            ctx.add_stage(child_id, PipelineStage::Project { expressions });
184            ctx.update_output_schema(child_id, Arc::new(arrow_out.clone()));
185            Ok(child_id)
186        }
187
188        LogicalPlan::Aggregate(agg) => {
189            let upstream_id = extract_impl(&agg.input, ctx)?;
190
191            // Add KeyExtract stage for group expressions to the upstream pipeline.
192            if !agg.group_expr.is_empty() {
193                ctx.add_stage(
194                    upstream_id,
195                    PipelineStage::KeyExtract {
196                        key_exprs: agg.group_expr.clone(),
197                    },
198                );
199            }
200
201            // Create downstream pipeline with the aggregate's output schema.
202            let output_schema = Arc::new(agg.schema.as_arrow().clone());
203            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);
204
205            ctx.breakers.push((
206                upstream_id,
207                PipelineBreaker::Aggregate {
208                    group_exprs: agg.group_expr.clone(),
209                    aggr_exprs: agg.aggr_expr.clone(),
210                },
211                downstream_id,
212            ));
213
214            Ok(downstream_id)
215        }
216
217        LogicalPlan::Sort(sort) => {
218            let upstream_id = extract_impl(&sort.input, ctx)?;
219
220            let output_schema = arrow_schema(plan);
221            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);
222
223            let order_exprs = sort.expr.iter().map(|se| se.expr.clone()).collect();
224            ctx.breakers.push((
225                upstream_id,
226                PipelineBreaker::Sort { order_exprs },
227                downstream_id,
228            ));
229
230            Ok(downstream_id)
231        }
232
233        LogicalPlan::Join(join) => {
234            let left_id = extract_impl(&join.left, ctx)?;
235            let _right_id = extract_impl(&join.right, ctx)?;
236
237            let output_schema = Arc::new(join.schema.as_arrow().clone());
238            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);
239
240            let left_keys = join.on.iter().map(|(l, _)| l.clone()).collect();
241            let right_keys = join.on.iter().map(|(_, r)| r.clone()).collect();
242
243            ctx.breakers.push((
244                left_id,
245                PipelineBreaker::Join {
246                    join_type: format!("{:?}", join.join_type),
247                    left_keys,
248                    right_keys,
249                },
250                downstream_id,
251            ));
252
253            Ok(downstream_id)
254        }
255
256        LogicalPlan::SubqueryAlias(alias) => extract_impl(&alias.input, ctx),
257
258        LogicalPlan::Limit(limit) => {
259            // Limit requires materialization — treat as a breaker.
260            let upstream_id = extract_impl(&limit.input, ctx)?;
261            let output_schema = arrow_schema(plan);
262            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);
263
264            ctx.breakers.push((
265                upstream_id,
266                PipelineBreaker::Sort {
267                    order_exprs: vec![],
268                },
269                downstream_id,
270            ));
271
272            Ok(downstream_id)
273        }
274
275        LogicalPlan::Distinct(distinct) => {
276            let input = match distinct {
277                datafusion_expr::Distinct::All(input) => input.as_ref(),
278                datafusion_expr::Distinct::On(d) => d.input.as_ref(),
279            };
280            let upstream_id = extract_impl(input, ctx)?;
281            let output_schema = arrow_schema(plan);
282            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);
283
284            ctx.breakers.push((
285                upstream_id,
286                PipelineBreaker::Sort {
287                    order_exprs: vec![],
288                },
289                downstream_id,
290            ));
291
292            Ok(downstream_id)
293        }
294
295        // Unknown nodes: try single-input passthrough.
296        other => {
297            let inputs = other.inputs();
298            if inputs.len() == 1 {
299                extract_impl(inputs[0], ctx)
300            } else if inputs.is_empty() {
301                // Unknown leaf — create a source pipeline with the plan's schema.
302                let schema = arrow_schema(other);
303                let pipeline_id = ctx.new_pipeline(Arc::clone(&schema), schema);
304                ctx.sources.push(pipeline_id);
305                Ok(pipeline_id)
306            } else {
307                Err(ExtractError::UnsupportedPlan(format!(
308                    "multi-input plan node: {}",
309                    other.display()
310                )))
311            }
312        }
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::DFSchema;
    use datafusion_expr::{col, lit, LogicalPlanBuilder};

    /// Helper: creates a simple table scan plan.
    ///
    /// Implemented as an `EmptyRelation` carrying the requested schema; the
    /// extractor treats it as a source leaf just like a real scan.
    fn table_scan_plan(fields: Vec<(&str, DataType)>) -> LogicalPlan {
        let field_vec: Vec<Field> = fields
            .into_iter()
            .map(|(name, dt)| Field::new(name, dt, true))
            .collect();
        let arrow_schema = Arc::new(Schema::new(field_vec));
        let df_schema = DFSchema::try_from(arrow_schema.as_ref().clone()).unwrap();
        LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(df_schema),
        })
    }

    #[test]
    fn extract_simple_table_scan() {
        let plan = table_scan_plan(vec![("x", DataType::Int64)]);
        let result = PipelineExtractor::extract(&plan).unwrap();

        assert_eq!(result.pipelines.len(), 1);
        assert_eq!(result.sources.len(), 1);
        assert_eq!(result.sinks.len(), 1);
        assert!(result.breakers.is_empty());
        assert!(result.pipelines[0].stages.is_empty());
    }

    #[test]
    fn extract_filter_only() {
        let source = table_scan_plan(vec![("x", DataType::Int64)]);
        let filter =
            datafusion_expr::Filter::try_new(col("x").gt(lit(10_i64)), Arc::new(source)).unwrap();
        let plan = LogicalPlan::Filter(filter);

        let result = PipelineExtractor::extract(&plan).unwrap();

        assert_eq!(result.pipelines.len(), 1);
        let stages = &result.pipelines[0].stages;
        assert_eq!(stages.len(), 1);
        assert!(matches!(stages[0], PipelineStage::Filter { .. }));
        assert!(result.breakers.is_empty());
    }

    #[test]
    fn extract_projection_only() {
        let source = table_scan_plan(vec![("x", DataType::Int64), ("y", DataType::Int64)]);

        let plan = LogicalPlanBuilder::from(source)
            .project(vec![col("x"), col("y") + lit(1_i64)])
            .unwrap()
            .build()
            .unwrap();

        let result = PipelineExtractor::extract(&plan).unwrap();

        assert_eq!(result.pipelines.len(), 1);
        let stages = &result.pipelines[0].stages;
        assert_eq!(stages.len(), 1);
        // Projection must carry one (expr, name) pair per output column.
        match &stages[0] {
            PipelineStage::Project { expressions } => assert_eq!(expressions.len(), 2),
            _ => panic!("expected Project stage"),
        }
    }

    #[test]
    fn extract_filter_then_project() {
        let source = table_scan_plan(vec![("x", DataType::Int64), ("y", DataType::Int64)]);

        let plan = LogicalPlanBuilder::from(source)
            .filter(col("x").gt(lit(0_i64)))
            .unwrap()
            .project(vec![col("x"), col("y") * lit(2_i64)])
            .unwrap()
            .build()
            .unwrap();

        let result = PipelineExtractor::extract(&plan).unwrap();

        assert_eq!(result.pipelines.len(), 1);
        let stages = &result.pipelines[0].stages;
        assert_eq!(stages.len(), 2);
        // Stage order mirrors plan order: filter first, then projection.
        assert!(matches!(stages[0], PipelineStage::Filter { .. }));
        assert!(matches!(stages[1], PipelineStage::Project { .. }));
    }

    #[test]
    fn extract_aggregate_manual() {
        let source = table_scan_plan(vec![("key", DataType::Int64), ("val", DataType::Int64)]);

        // Build aggregate manually. Schema must match group_expr.len() + aggr_expr.len().
        let out_schema = Arc::new(Schema::new(vec![Field::new("key", DataType::Int64, true)]));
        let df_schema = DFSchema::try_from(out_schema.as_ref().clone()).unwrap();
        let agg = datafusion_expr::Aggregate::try_new_with_schema(
            Arc::new(source),
            vec![col("key")],
            vec![],
            Arc::new(df_schema),
        )
        .unwrap();

        let result = PipelineExtractor::extract(&LogicalPlan::Aggregate(agg)).unwrap();

        // Upstream + downstream pipelines, linked by a single Aggregate breaker.
        assert_eq!(result.pipelines.len(), 2);
        assert_eq!(result.breakers.len(), 1);
        assert!(matches!(
            result.breakers[0].1,
            PipelineBreaker::Aggregate { .. }
        ));
        // Upstream pipeline should have a KeyExtract stage.
        let has_key_extract = result.pipelines[0]
            .stages
            .iter()
            .any(|s| matches!(s, PipelineStage::KeyExtract { .. }));
        assert!(has_key_extract);
    }

    #[test]
    fn extract_sort_creates_breaker() {
        let source = table_scan_plan(vec![("x", DataType::Int64)]);

        let plan = LogicalPlanBuilder::from(source)
            .sort(vec![col("x").sort(true, true)])
            .unwrap()
            .build()
            .unwrap();

        let result = PipelineExtractor::extract(&plan).unwrap();

        assert_eq!(result.pipelines.len(), 2);
        assert_eq!(result.breakers.len(), 1);
        match &result.breakers[0].1 {
            PipelineBreaker::Sort { order_exprs } => assert_eq!(order_exprs.len(), 1),
            _ => panic!("expected Sort breaker"),
        }
    }

    #[test]
    fn extract_subquery_alias_passthrough() {
        let source = table_scan_plan(vec![("x", DataType::Int64)]);

        let plan = LogicalPlanBuilder::from(source)
            .alias("t")
            .unwrap()
            .filter(col("x").gt(lit(5_i64)))
            .unwrap()
            .build()
            .unwrap();

        let result = PipelineExtractor::extract(&plan).unwrap();

        // SubqueryAlias is transparent — should still be 1 pipeline.
        assert_eq!(result.pipelines.len(), 1);
        let stages = &result.pipelines[0].stages;
        assert_eq!(stages.len(), 1);
        assert!(matches!(stages[0], PipelineStage::Filter { .. }));
    }

    #[test]
    fn extract_nested_filters() {
        let source = table_scan_plan(vec![("x", DataType::Int64)]);

        let plan = LogicalPlanBuilder::from(source)
            .filter(col("x").gt(lit(0_i64)))
            .unwrap()
            .filter(col("x").lt(lit(100_i64)))
            .unwrap()
            .build()
            .unwrap();

        let result = PipelineExtractor::extract(&plan).unwrap();

        assert_eq!(result.pipelines.len(), 1);
        let stages = &result.pipelines[0].stages;
        assert_eq!(stages.len(), 2);
        assert!(stages
            .iter()
            .all(|s| matches!(s, PipelineStage::Filter { .. })));
    }

    #[test]
    fn extract_multi_pipeline_filter_agg_project() {
        let source = table_scan_plan(vec![("key", DataType::Int64), ("val", DataType::Int64)]);

        // Build: Filter → Aggregate → Project (manually for Aggregate)
        let filtered = LogicalPlanBuilder::from(source)
            .filter(col("val").gt(lit(0_i64)))
            .unwrap()
            .build()
            .unwrap();

        let out_schema = Arc::new(Schema::new(vec![Field::new("key", DataType::Int64, true)]));
        let df_schema = DFSchema::try_from(out_schema.as_ref().clone()).unwrap();
        let agg = datafusion_expr::Aggregate::try_new_with_schema(
            Arc::new(filtered),
            vec![col("key")],
            vec![],
            Arc::new(df_schema),
        )
        .unwrap();

        let plan = LogicalPlanBuilder::from(LogicalPlan::Aggregate(agg))
            .project(vec![col("key")])
            .unwrap()
            .build()
            .unwrap();

        let result = PipelineExtractor::extract(&plan).unwrap();

        // Pipeline 0: [Filter, KeyExtract] feeds the aggregate breaker;
        // pipeline 1: [Project] consumes its output.
        assert_eq!(result.pipelines.len(), 2);
        assert_eq!(result.breakers.len(), 1);

        let has_stage = |p: &Pipeline, f: fn(&PipelineStage) -> bool| p.stages.iter().any(f);
        assert!(has_stage(&result.pipelines[0], |s| matches!(
            s,
            PipelineStage::Filter { .. }
        )));
        assert!(has_stage(&result.pipelines[0], |s| matches!(
            s,
            PipelineStage::KeyExtract { .. }
        )));
        assert!(has_stage(&result.pipelines[1], |s| matches!(
            s,
            PipelineStage::Project { .. }
        )));
    }

    #[test]
    fn extract_empty_relation() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let df_schema = DFSchema::try_from(schema.as_ref().clone()).unwrap();
        let plan = LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(df_schema),
        });

        let result = PipelineExtractor::extract(&plan).unwrap();
        assert_eq!(result.pipelines.len(), 1);
        assert!(result.pipelines[0].stages.is_empty());
        assert_eq!(result.sources.len(), 1);
    }

    #[test]
    fn extract_preserves_pipeline_ids() {
        let source = table_scan_plan(vec![("x", DataType::Int64)]);

        let plan = LogicalPlanBuilder::from(source)
            .sort(vec![col("x").sort(true, true)])
            .unwrap()
            .build()
            .unwrap();

        let result = PipelineExtractor::extract(&plan).unwrap();

        // Breaker endpoints must refer to the pipelines by their real IDs.
        let (upstream, _, downstream) = &result.breakers[0];
        assert_eq!(*upstream, result.pipelines[0].id);
        assert_eq!(*downstream, result.pipelines[1].id);
    }

    #[test]
    fn extract_schema_tracking_through_project() {
        let source = table_scan_plan(vec![
            ("a", DataType::Int64),
            ("b", DataType::Float64),
            ("c", DataType::Boolean),
        ]);

        // Project down to just 2 columns.
        let plan = LogicalPlanBuilder::from(source)
            .project(vec![col("a"), col("b")])
            .unwrap()
            .build()
            .unwrap();

        let result = PipelineExtractor::extract(&plan).unwrap();
        let pipeline = &result.pipelines[0];

        // Input schema has 3 fields, output has 2.
        assert_eq!(pipeline.input_schema.fields().len(), 3);
        assert_eq!(pipeline.output_schema.fields().len(), 2);
    }

    #[test]
    fn extract_limit_creates_breaker() {
        let source = table_scan_plan(vec![("x", DataType::Int64)]);

        let plan = LogicalPlanBuilder::from(source)
            .limit(0, Some(10))
            .unwrap()
            .build()
            .unwrap();

        let result = PipelineExtractor::extract(&plan).unwrap();

        assert_eq!(result.pipelines.len(), 2);
        assert_eq!(result.breakers.len(), 1);
    }
}
629
630        assert_eq!(extracted.pipelines.len(), 2);
631        assert_eq!(extracted.breakers.len(), 1);
632    }
633}