1use std::sync::Arc;
8
9use arrow_schema::SchemaRef;
10use datafusion_expr::LogicalPlan;
11
12use super::error::ExtractError;
13use super::pipeline::{Pipeline, PipelineBreaker, PipelineId, PipelineStage};
14
/// Result of decomposing a `LogicalPlan` into streaming pipelines connected
/// by blocking (pipeline-breaking) operators.
#[derive(Debug)]
pub struct ExtractedPlan {
    /// All pipelines, in allocation order (ids are assigned sequentially).
    pub pipelines: Vec<Pipeline>,
    /// Edges `(upstream, breaker, downstream)` recording the blocking
    /// operator that separates two pipelines.
    pub breakers: Vec<(PipelineId, PipelineBreaker, PipelineId)>,
    /// Pipelines fed by leaf plan nodes (table scans / empty relations).
    pub sources: Vec<PipelineId>,
    /// Pipelines whose output is the final query result.
    pub sinks: Vec<PipelineId>,
}
27
/// Mutable state threaded through the recursive plan walk; converted into an
/// [`ExtractedPlan`] by `finalize`.
struct ExtractionContext {
    /// In-progress pipelines; looked up by id via linear search.
    pipelines: Vec<PipelineBuilder>,
    breakers: Vec<(PipelineId, PipelineBreaker, PipelineId)>,
    sources: Vec<PipelineId>,
    sinks: Vec<PipelineId>,
    /// Next unallocated pipeline id; incremented by `new_pipeline`.
    next_id: u32,
}
36
/// Accumulates the stages and schemas of one pipeline before finalization.
struct PipelineBuilder {
    id: PipelineId,
    /// Row-transforming stages, executed in push order.
    stages: Vec<PipelineStage>,
    /// Schema of batches entering the pipeline.
    input_schema: SchemaRef,
    /// Schema after the last stage; updated as schema-changing stages
    /// (e.g. projections) are appended.
    output_schema: SchemaRef,
}
44
45impl ExtractionContext {
46 fn new() -> Self {
47 Self {
48 pipelines: Vec::new(),
49 breakers: Vec::new(),
50 sources: Vec::new(),
51 sinks: Vec::new(),
52 next_id: 0,
53 }
54 }
55
56 fn new_pipeline(&mut self, input_schema: SchemaRef, output_schema: SchemaRef) -> PipelineId {
58 let id = PipelineId(self.next_id);
59 self.next_id += 1;
60 self.pipelines.push(PipelineBuilder {
61 id,
62 stages: Vec::new(),
63 input_schema,
64 output_schema,
65 });
66 id
67 }
68
69 fn add_stage(&mut self, pipeline_id: PipelineId, stage: PipelineStage) {
71 let builder = self
72 .pipelines
73 .iter_mut()
74 .find(|p| p.id == pipeline_id)
75 .expect("pipeline not found");
76 builder.stages.push(stage);
77 }
78
79 fn update_output_schema(&mut self, pipeline_id: PipelineId, schema: SchemaRef) {
81 let builder = self
82 .pipelines
83 .iter_mut()
84 .find(|p| p.id == pipeline_id)
85 .expect("pipeline not found");
86 builder.output_schema = schema;
87 }
88
89 fn finalize(self) -> ExtractedPlan {
91 let pipelines = self
92 .pipelines
93 .into_iter()
94 .map(|b| Pipeline {
95 id: b.id,
96 stages: b.stages,
97 input_schema: b.input_schema,
98 output_schema: b.output_schema,
99 })
100 .collect();
101
102 ExtractedPlan {
103 pipelines,
104 breakers: self.breakers,
105 sources: self.sources,
106 sinks: self.sinks,
107 }
108 }
109}
110
111pub struct PipelineExtractor;
113
114impl PipelineExtractor {
115 pub fn extract(plan: &LogicalPlan) -> Result<ExtractedPlan, ExtractError> {
124 let mut ctx = ExtractionContext::new();
125 let terminal_id = extract_impl(plan, &mut ctx)?;
126
127 ctx.sinks.push(terminal_id);
129
130 Ok(ctx.finalize())
131 }
132}
133
134fn arrow_schema(plan: &LogicalPlan) -> SchemaRef {
136 Arc::new(plan.schema().as_arrow().clone())
137}
138
/// Recursively walks `plan` bottom-up. Pipelinable operators (filter,
/// project, key-extract) are fused into their child's pipeline; blocking
/// operators (aggregate, sort, join, limit, distinct) close the upstream
/// pipeline with a breaker and start a new downstream pipeline.
///
/// Returns the id of the pipeline producing this node's output.
///
/// # Errors
/// Returns `ExtractError::UnsupportedPlan` for unrecognized multi-input
/// plan nodes (see the fallback arm).
#[allow(clippy::too_many_lines)]
fn extract_impl(
    plan: &LogicalPlan,
    ctx: &mut ExtractionContext,
) -> Result<PipelineId, ExtractError> {
    match plan {
        // Leaves: each starts a fresh pipeline and is recorded as a source.
        LogicalPlan::TableScan(scan) => {
            let schema = Arc::new(scan.projected_schema.as_arrow().clone());
            let pipeline_id = ctx.new_pipeline(Arc::clone(&schema), schema);
            ctx.sources.push(pipeline_id);
            Ok(pipeline_id)
        }

        LogicalPlan::EmptyRelation(empty) => {
            let schema = Arc::new(empty.schema.as_arrow().clone());
            let pipeline_id = ctx.new_pipeline(Arc::clone(&schema), schema);
            ctx.sources.push(pipeline_id);
            Ok(pipeline_id)
        }

        // Filter fuses into the child's pipeline; it does not change the
        // schema, so no output-schema update is needed.
        LogicalPlan::Filter(filter) => {
            let child_id = extract_impl(&filter.input, ctx)?;
            ctx.add_stage(
                child_id,
                PipelineStage::Filter {
                    predicate: filter.predicate.clone(),
                },
            );
            Ok(child_id)
        }

        // Projection fuses into the child's pipeline. Output names are taken
        // positionally from the projection's Arrow schema, and the pipeline's
        // output schema is updated to the projected schema.
        LogicalPlan::Projection(proj) => {
            let child_id = extract_impl(&proj.input, ctx)?;

            let arrow_out = proj.schema.as_arrow();
            let expressions: Vec<(datafusion_expr::Expr, String)> = proj
                .expr
                .iter()
                .enumerate()
                .map(|(i, expr)| (expr.clone(), arrow_out.field(i).name().clone()))
                .collect();

            ctx.add_stage(child_id, PipelineStage::Project { expressions });
            ctx.update_output_schema(child_id, Arc::new(arrow_out.clone()));
            Ok(child_id)
        }

        // Aggregate is blocking: key extraction (if grouped) runs in the
        // upstream pipeline, then an Aggregate breaker feeds a new pipeline.
        LogicalPlan::Aggregate(agg) => {
            let upstream_id = extract_impl(&agg.input, ctx)?;

            if !agg.group_expr.is_empty() {
                ctx.add_stage(
                    upstream_id,
                    PipelineStage::KeyExtract {
                        key_exprs: agg.group_expr.clone(),
                    },
                );
            }

            let output_schema = Arc::new(agg.schema.as_arrow().clone());
            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);

            ctx.breakers.push((
                upstream_id,
                PipelineBreaker::Aggregate {
                    group_exprs: agg.group_expr.clone(),
                    aggr_exprs: agg.aggr_expr.clone(),
                },
                downstream_id,
            ));

            Ok(downstream_id)
        }

        // Sort is blocking: the order expressions travel on the breaker.
        LogicalPlan::Sort(sort) => {
            let upstream_id = extract_impl(&sort.input, ctx)?;

            let output_schema = arrow_schema(plan);
            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);

            let order_exprs = sort.expr.iter().map(|se| se.expr.clone()).collect();
            ctx.breakers.push((
                upstream_id,
                PipelineBreaker::Sort { order_exprs },
                downstream_id,
            ));

            Ok(downstream_id)
        }

        LogicalPlan::Join(join) => {
            let left_id = extract_impl(&join.left, ctx)?;
            // NOTE(review): the right input's pipeline is extracted but its
            // id is dropped — only `left_id` is recorded as the breaker's
            // upstream, so `breakers` never connects the right pipeline to
            // the join. Confirm the breaker tuple is intentionally
            // single-upstream before relying on this graph for scheduling.
            let _right_id = extract_impl(&join.right, ctx)?;

            let output_schema = Arc::new(join.schema.as_arrow().clone());
            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);

            // Equi-join keys split into left/right expression lists; the
            // join type is stored as its Debug rendering.
            let left_keys = join.on.iter().map(|(l, _)| l.clone()).collect();
            let right_keys = join.on.iter().map(|(_, r)| r.clone()).collect();

            ctx.breakers.push((
                left_id,
                PipelineBreaker::Join {
                    join_type: format!("{:?}", join.join_type),
                    left_keys,
                    right_keys,
                },
                downstream_id,
            ));

            Ok(downstream_id)
        }

        // Aliases are transparent: recurse straight into the input.
        LogicalPlan::SubqueryAlias(alias) => extract_impl(&alias.input, ctx),

        // NOTE(review): there is no dedicated Limit breaker variant used
        // here; a Sort breaker with empty `order_exprs` acts as a stand-in,
        // and the limit's skip/fetch values are not captured. Verify that
        // downstream consumers do not need them.
        LogicalPlan::Limit(limit) => {
            let upstream_id = extract_impl(&limit.input, ctx)?;
            let output_schema = arrow_schema(plan);
            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);

            ctx.breakers.push((
                upstream_id,
                PipelineBreaker::Sort {
                    order_exprs: vec![],
                },
                downstream_id,
            ));

            Ok(downstream_id)
        }

        // NOTE(review): Distinct also reuses an empty Sort breaker as a
        // stand-in; the DISTINCT ON expressions (for `Distinct::On`) are not
        // propagated — confirm this is sufficient for the executor.
        LogicalPlan::Distinct(distinct) => {
            let input = match distinct {
                datafusion_expr::Distinct::All(input) => input.as_ref(),
                datafusion_expr::Distinct::On(d) => d.input.as_ref(),
            };
            let upstream_id = extract_impl(input, ctx)?;
            let output_schema = arrow_schema(plan);
            let downstream_id = ctx.new_pipeline(Arc::clone(&output_schema), output_schema);

            ctx.breakers.push((
                upstream_id,
                PipelineBreaker::Sort {
                    order_exprs: vec![],
                },
                downstream_id,
            ));

            Ok(downstream_id)
        }

        // Fallback: pass through single-input nodes unchanged, treat
        // zero-input nodes as sources, and reject anything multi-input.
        other => {
            let inputs = other.inputs();
            if inputs.len() == 1 {
                extract_impl(inputs[0], ctx)
            } else if inputs.is_empty() {
                let schema = arrow_schema(other);
                let pipeline_id = ctx.new_pipeline(Arc::clone(&schema), schema);
                ctx.sources.push(pipeline_id);
                Ok(pipeline_id)
            } else {
                Err(ExtractError::UnsupportedPlan(format!(
                    "multi-input plan node: {}",
                    other.display()
                )))
            }
        }
    }
}
315
316#[cfg(test)]
317mod tests {
318 use super::*;
319 use arrow_schema::{DataType, Field, Schema};
320 use datafusion_common::DFSchema;
321 use datafusion_expr::{col, lit, LogicalPlanBuilder};
322
323 fn table_scan_plan(fields: Vec<(&str, DataType)>) -> LogicalPlan {
325 let arrow_schema = Arc::new(Schema::new(
326 fields
327 .into_iter()
328 .map(|(name, dt)| Field::new(name, dt, true))
329 .collect::<Vec<_>>(),
330 ));
331 let df_schema = DFSchema::try_from(arrow_schema.as_ref().clone()).unwrap();
332 LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation {
333 produce_one_row: false,
334 schema: Arc::new(df_schema),
335 })
336 }
337
338 #[test]
339 fn extract_simple_table_scan() {
340 let plan = table_scan_plan(vec![("x", DataType::Int64)]);
341 let extracted = PipelineExtractor::extract(&plan).unwrap();
342
343 assert_eq!(extracted.pipelines.len(), 1);
344 assert_eq!(extracted.sources.len(), 1);
345 assert_eq!(extracted.sinks.len(), 1);
346 assert_eq!(extracted.breakers.len(), 0);
347 assert!(extracted.pipelines[0].stages.is_empty());
348 }
349
350 #[test]
351 fn extract_filter_only() {
352 let scan = table_scan_plan(vec![("x", DataType::Int64)]);
353 let plan = LogicalPlan::Filter(
354 datafusion_expr::Filter::try_new(col("x").gt(lit(10_i64)), Arc::new(scan)).unwrap(),
355 );
356
357 let extracted = PipelineExtractor::extract(&plan).unwrap();
358
359 assert_eq!(extracted.pipelines.len(), 1);
360 assert_eq!(extracted.pipelines[0].stages.len(), 1);
361 assert!(matches!(
362 &extracted.pipelines[0].stages[0],
363 PipelineStage::Filter { .. }
364 ));
365 assert_eq!(extracted.breakers.len(), 0);
366 }
367
368 #[test]
369 fn extract_projection_only() {
370 let scan = table_scan_plan(vec![("x", DataType::Int64), ("y", DataType::Int64)]);
371
372 let plan = LogicalPlanBuilder::from(scan)
373 .project(vec![col("x"), col("y") + lit(1_i64)])
374 .unwrap()
375 .build()
376 .unwrap();
377
378 let extracted = PipelineExtractor::extract(&plan).unwrap();
379
380 assert_eq!(extracted.pipelines.len(), 1);
381 assert_eq!(extracted.pipelines[0].stages.len(), 1);
382 assert!(matches!(
383 &extracted.pipelines[0].stages[0],
384 PipelineStage::Project { expressions } if expressions.len() == 2
385 ));
386 }
387
388 #[test]
389 fn extract_filter_then_project() {
390 let scan = table_scan_plan(vec![("x", DataType::Int64), ("y", DataType::Int64)]);
391
392 let plan = LogicalPlanBuilder::from(scan)
393 .filter(col("x").gt(lit(0_i64)))
394 .unwrap()
395 .project(vec![col("x"), col("y") * lit(2_i64)])
396 .unwrap()
397 .build()
398 .unwrap();
399
400 let extracted = PipelineExtractor::extract(&plan).unwrap();
401
402 assert_eq!(extracted.pipelines.len(), 1);
403 assert_eq!(extracted.pipelines[0].stages.len(), 2);
404 assert!(matches!(
405 &extracted.pipelines[0].stages[0],
406 PipelineStage::Filter { .. }
407 ));
408 assert!(matches!(
409 &extracted.pipelines[0].stages[1],
410 PipelineStage::Project { .. }
411 ));
412 }
413
    #[test]
    fn extract_aggregate_manual() {
        let scan = table_scan_plan(vec![("key", DataType::Int64), ("val", DataType::Int64)]);

        // Build the aggregate by hand with an explicit output schema; the
        // schema must match the group exprs exactly for
        // `try_new_with_schema` to accept it.
        let agg_schema = Arc::new(Schema::new(vec![Field::new("key", DataType::Int64, true)]));
        let df_schema = DFSchema::try_from(agg_schema.as_ref().clone()).unwrap();
        let agg = datafusion_expr::Aggregate::try_new_with_schema(
            Arc::new(scan),
            vec![col("key")],
            vec![],
            Arc::new(df_schema),
        )
        .unwrap();

        let plan = LogicalPlan::Aggregate(agg);
        let extracted = PipelineExtractor::extract(&plan).unwrap();

        // The aggregate is blocking: two pipelines joined by one
        // Aggregate breaker.
        assert_eq!(extracted.pipelines.len(), 2);
        assert_eq!(extracted.breakers.len(), 1);
        assert!(matches!(
            &extracted.breakers[0].1,
            PipelineBreaker::Aggregate { .. }
        ));
        // The grouped aggregate adds a KeyExtract stage to the upstream
        // pipeline.
        let upstream = &extracted.pipelines[0];
        assert!(upstream
            .stages
            .iter()
            .any(|s| matches!(s, PipelineStage::KeyExtract { .. })));
    }
446
447 #[test]
448 fn extract_sort_creates_breaker() {
449 let scan = table_scan_plan(vec![("x", DataType::Int64)]);
450
451 let plan = LogicalPlanBuilder::from(scan)
452 .sort(vec![col("x").sort(true, true)])
453 .unwrap()
454 .build()
455 .unwrap();
456
457 let extracted = PipelineExtractor::extract(&plan).unwrap();
458
459 assert_eq!(extracted.pipelines.len(), 2);
460 assert_eq!(extracted.breakers.len(), 1);
461 assert!(matches!(
462 &extracted.breakers[0].1,
463 PipelineBreaker::Sort { order_exprs } if order_exprs.len() == 1
464 ));
465 }
466
467 #[test]
468 fn extract_subquery_alias_passthrough() {
469 let scan = table_scan_plan(vec![("x", DataType::Int64)]);
470
471 let plan = LogicalPlanBuilder::from(scan)
472 .alias("t")
473 .unwrap()
474 .filter(col("x").gt(lit(5_i64)))
475 .unwrap()
476 .build()
477 .unwrap();
478
479 let extracted = PipelineExtractor::extract(&plan).unwrap();
480
481 assert_eq!(extracted.pipelines.len(), 1);
483 assert_eq!(extracted.pipelines[0].stages.len(), 1);
484 assert!(matches!(
485 &extracted.pipelines[0].stages[0],
486 PipelineStage::Filter { .. }
487 ));
488 }
489
490 #[test]
491 fn extract_nested_filters() {
492 let scan = table_scan_plan(vec![("x", DataType::Int64)]);
493
494 let plan = LogicalPlanBuilder::from(scan)
495 .filter(col("x").gt(lit(0_i64)))
496 .unwrap()
497 .filter(col("x").lt(lit(100_i64)))
498 .unwrap()
499 .build()
500 .unwrap();
501
502 let extracted = PipelineExtractor::extract(&plan).unwrap();
503
504 assert_eq!(extracted.pipelines.len(), 1);
505 assert_eq!(extracted.pipelines[0].stages.len(), 2);
506 assert!(extracted.pipelines[0]
507 .stages
508 .iter()
509 .all(|s| matches!(s, PipelineStage::Filter { .. })));
510 }
511
    #[test]
    fn extract_multi_pipeline_filter_agg_project() {
        let scan = table_scan_plan(vec![("key", DataType::Int64), ("val", DataType::Int64)]);

        // Pipelinable filter below the aggregate — should fuse upstream.
        let filtered = LogicalPlanBuilder::from(scan)
            .filter(col("val").gt(lit(0_i64)))
            .unwrap()
            .build()
            .unwrap();

        // Hand-built grouped aggregate; schema must match the group exprs
        // for `try_new_with_schema` to accept it.
        let agg_schema = Arc::new(Schema::new(vec![Field::new("key", DataType::Int64, true)]));
        let df_schema = DFSchema::try_from(agg_schema.as_ref().clone()).unwrap();
        let agg = datafusion_expr::Aggregate::try_new_with_schema(
            Arc::new(filtered),
            vec![col("key")],
            vec![],
            Arc::new(df_schema),
        )
        .unwrap();

        // Pipelinable projection above the aggregate — should fuse
        // downstream.
        let plan = LogicalPlanBuilder::from(LogicalPlan::Aggregate(agg))
            .project(vec![col("key")])
            .unwrap()
            .build()
            .unwrap();

        let extracted = PipelineExtractor::extract(&plan).unwrap();

        // Exactly one break (the aggregate); filter/project do not split.
        assert_eq!(extracted.pipelines.len(), 2);
        assert_eq!(extracted.breakers.len(), 1);

        // Upstream pipeline carries the filter plus the aggregate's
        // KeyExtract stage.
        let upstream = &extracted.pipelines[0];
        assert!(upstream
            .stages
            .iter()
            .any(|s| matches!(s, PipelineStage::Filter { .. })));
        assert!(upstream
            .stages
            .iter()
            .any(|s| matches!(s, PipelineStage::KeyExtract { .. })));

        // Downstream pipeline carries the projection.
        let downstream = &extracted.pipelines[1];
        assert!(downstream
            .stages
            .iter()
            .any(|s| matches!(s, PipelineStage::Project { .. })));
    }
562
563 #[test]
564 fn extract_empty_relation() {
565 let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
566 let df_schema = DFSchema::try_from(schema.as_ref().clone()).unwrap();
567 let plan = LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation {
568 produce_one_row: false,
569 schema: Arc::new(df_schema),
570 });
571
572 let extracted = PipelineExtractor::extract(&plan).unwrap();
573 assert_eq!(extracted.pipelines.len(), 1);
574 assert!(extracted.pipelines[0].stages.is_empty());
575 assert_eq!(extracted.sources.len(), 1);
576 }
577
578 #[test]
579 fn extract_preserves_pipeline_ids() {
580 let scan = table_scan_plan(vec![("x", DataType::Int64)]);
581
582 let plan = LogicalPlanBuilder::from(scan)
583 .sort(vec![col("x").sort(true, true)])
584 .unwrap()
585 .build()
586 .unwrap();
587
588 let extracted = PipelineExtractor::extract(&plan).unwrap();
589
590 let (upstream_id, _, downstream_id) = &extracted.breakers[0];
591 assert_eq!(*upstream_id, extracted.pipelines[0].id);
592 assert_eq!(*downstream_id, extracted.pipelines[1].id);
593 }
594
595 #[test]
596 fn extract_schema_tracking_through_project() {
597 let scan = table_scan_plan(vec![
598 ("a", DataType::Int64),
599 ("b", DataType::Float64),
600 ("c", DataType::Boolean),
601 ]);
602
603 let plan = LogicalPlanBuilder::from(scan)
605 .project(vec![col("a"), col("b")])
606 .unwrap()
607 .build()
608 .unwrap();
609
610 let extracted = PipelineExtractor::extract(&plan).unwrap();
611 let pipeline = &extracted.pipelines[0];
612
613 assert_eq!(pipeline.input_schema.fields().len(), 3);
615 assert_eq!(pipeline.output_schema.fields().len(), 2);
616 }
617
618 #[test]
619 fn extract_limit_creates_breaker() {
620 let scan = table_scan_plan(vec![("x", DataType::Int64)]);
621
622 let plan = LogicalPlanBuilder::from(scan)
623 .limit(0, Some(10))
624 .unwrap()
625 .build()
626 .unwrap();
627
628 let extracted = PipelineExtractor::extract(&plan).unwrap();
629
630 assert_eq!(extracted.pipelines.len(), 2);
631 assert_eq!(extracted.breakers.len(), 1);
632 }
633}