Skip to main content

laminar_sql/planner/
lookup_join.rs

1//! Optimizer rules that rewrite standard JOINs to `LookupJoinNode`.
2//!
3//! When a query joins a streaming source with a registered lookup table,
4//! the `LookupJoinRewriteRule` replaces the standard hash/merge join
5//! with a `LookupJoinNode` that uses the lookup source connector.
6
7#[allow(clippy::disallowed_types)] // cold path: query planning
8use std::collections::{HashMap, HashSet};
9use std::fmt;
10use std::sync::Arc;
11
12use datafusion::common::Result;
13use datafusion::logical_expr::logical_plan::LogicalPlan;
14use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
15use datafusion_common::tree_node::Transformed;
16use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
17
18use crate::datafusion::lookup_join::{
19    JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
20};
21use crate::planner::LookupTableInfo;
22
23/// Rewrites standard JOIN nodes that reference a lookup table into
24/// `LookupJoinNode` extension nodes.
25#[derive(Debug)]
26pub struct LookupJoinRewriteRule {
27    /// Registered lookup tables, keyed by name.
28    lookup_tables: HashMap<String, LookupTableInfo>,
29}
30
31impl LookupJoinRewriteRule {
32    /// Creates a new rewrite rule with the given set of registered lookup tables.
33    #[must_use]
34    pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
35        Self { lookup_tables }
36    }
37
38    /// Detects which side of a join (if any) is a lookup table scan.
39    /// Returns `Some((lookup_side_is_right, table_name))`.
40    fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
41        // Check right side
42        if let Some(name) = scan_table_name(&join.right) {
43            if self.lookup_tables.contains_key(&name) {
44                return Some((true, name));
45            }
46        }
47        // Check left side
48        if let Some(name) = scan_table_name(&join.left) {
49            if self.lookup_tables.contains_key(&name) {
50                return Some((false, name));
51            }
52        }
53        None
54    }
55}
56
57impl OptimizerRule for LookupJoinRewriteRule {
58    fn name(&self) -> &'static str {
59        "lookup_join_rewrite"
60    }
61
62    fn apply_order(&self) -> Option<ApplyOrder> {
63        Some(ApplyOrder::BottomUp)
64    }
65
66    fn rewrite(
67        &self,
68        plan: LogicalPlan,
69        _config: &dyn OptimizerConfig,
70    ) -> Result<Transformed<LogicalPlan>> {
71        let LogicalPlan::Join(join) = &plan else {
72            return Ok(Transformed::no(plan));
73        };
74
75        let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
76            return Ok(Transformed::no(plan));
77        };
78
79        let info = &self.lookup_tables[&table_name];
80
81        // Determine which side is stream and which is lookup
82        let (stream_plan, lookup_plan) = if lookup_is_right {
83            (join.left.as_ref(), join.right.as_ref())
84        } else {
85            (join.right.as_ref(), join.left.as_ref())
86        };
87
88        // Extract aliases for qualified column resolution (C7)
89        let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
90        let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
91
92        let lookup_schema = lookup_plan.schema().clone();
93
94        // Build join key pairs from the equijoin conditions
95        let join_keys: Vec<JoinKeyPair> = join
96            .on
97            .iter()
98            .map(|(left_expr, right_expr)| {
99                let lookup_expr = if lookup_is_right {
100                    right_expr
101                } else {
102                    left_expr
103                };
104                let stream_expr = if lookup_is_right {
105                    left_expr
106                } else {
107                    right_expr
108                };
109                let lookup_column = match lookup_expr {
110                    datafusion::logical_expr::Expr::Column(col) => col.name.clone(),
111                    other => other.to_string(),
112                };
113                JoinKeyPair {
114                    stream_expr: stream_expr.clone(),
115                    lookup_column,
116                }
117            })
118            .collect();
119
120        // Convert DataFusion join type to our lookup join type
121        let join_type = match join.join_type {
122            datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
123            datafusion::logical_expr::JoinType::Left if lookup_is_right => {
124                LookupJoinType::LeftOuter
125            }
126            datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
127                LookupJoinType::LeftOuter
128            }
129            _ => return Ok(Transformed::no(plan)),
130        };
131
132        // All lookup columns are required initially; pruning is done later
133        let required_columns: HashSet<String> = lookup_schema
134            .fields()
135            .iter()
136            .map(|f| f.name().clone())
137            .collect();
138
139        // Build output schema from stream + lookup
140        let stream_schema = stream_plan.schema();
141        let output_schema = Arc::new(stream_schema.join(lookup_schema.as_ref())?);
142
143        let metadata = LookupTableMetadata {
144            connector: info.properties.connector.to_string(),
145            strategy: info.properties.strategy.to_string(),
146            pushdown_mode: info.properties.pushdown_mode.to_string(),
147            primary_key: info.primary_key.clone(),
148        };
149
150        let node = LookupJoinNode::new(
151            stream_plan.clone(),
152            table_name,
153            lookup_schema,
154            join_keys,
155            join_type,
156            vec![], // predicates pushed down later
157            required_columns,
158            output_schema,
159            metadata,
160        )
161        .with_aliases(lookup_alias, stream_alias);
162
163        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
164            node: Arc::new(node),
165        })))
166    }
167}
168
169/// Column pruning rule for `LookupJoinNode`.
170///
171/// Narrows `required_lookup_columns` to only the columns referenced
172/// by downstream plan nodes.
173#[derive(Debug)]
174pub struct LookupColumnPruningRule;
175
176impl OptimizerRule for LookupColumnPruningRule {
177    fn name(&self) -> &'static str {
178        "lookup_column_pruning"
179    }
180
181    fn apply_order(&self) -> Option<ApplyOrder> {
182        Some(ApplyOrder::TopDown)
183    }
184
185    fn rewrite(
186        &self,
187        plan: LogicalPlan,
188        _config: &dyn OptimizerConfig,
189    ) -> Result<Transformed<LogicalPlan>> {
190        let LogicalPlan::Extension(ext) = &plan else {
191            return Ok(Transformed::no(plan));
192        };
193
194        let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
195            return Ok(Transformed::no(plan));
196        };
197
198        // Collect columns actually used downstream by walking the parent plan.
199        // For now, we use the node's schema to determine which lookup columns
200        // appear in the output. A full implementation would track column usage
201        // from parent nodes; this is a conservative starting point.
202        let schema = UserDefinedLogicalNodeCore::schema(node);
203        let used: HashSet<String> = schema
204            .fields()
205            .iter()
206            .filter(|f| node.required_lookup_columns().contains(f.name()))
207            .map(|f| f.name().clone())
208            .collect();
209
210        if used == *node.required_lookup_columns() {
211            return Ok(Transformed::no(plan));
212        }
213
214        // Rebuild with narrowed columns
215        let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
216        let pruned = LookupJoinNode::new(
217            node_inputs[0].clone(),
218            node.lookup_table_name().to_string(),
219            node.lookup_schema().clone(),
220            node.join_keys().to_vec(),
221            node.join_type(),
222            node.pushdown_predicates().to_vec(),
223            used,
224            schema.clone(),
225            node.metadata().clone(),
226        )
227        .with_local_predicates(node.local_predicates().to_vec())
228        .with_aliases(
229            node.lookup_alias().map(String::from),
230            node.stream_alias().map(String::from),
231        );
232
233        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
234            node: Arc::new(pruned),
235        })))
236    }
237}
238
239/// Extracts the table name and optional alias from a plan node.
240///
241/// Returns `(base_table_name, alias)` — alias is the `SubqueryAlias` name
242/// if the scan is wrapped in one, `None` otherwise.
243fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
244    match plan {
245        LogicalPlan::TableScan(TableScan { table_name, .. }) => {
246            Some((table_name.table().to_string(), None))
247        }
248        LogicalPlan::SubqueryAlias(alias) => {
249            let alias_name = alias.alias.table().to_string();
250            scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
251        }
252        _ => None,
253    }
254}
255
256/// Extracts the table name from a `TableScan` node, unwrapping aliases.
257fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
258    scan_table_name_and_alias(plan).map(|(name, _)| name)
259}
260
261/// Display helpers for connector/strategy/pushdown types.
262impl fmt::Display for crate::parser::lookup_table::ConnectorType {
263    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264        match self {
265            Self::Postgres => write!(f, "postgres"),
266            Self::PostgresCdc => write!(f, "postgres-cdc"),
267            Self::MysqlCdc => write!(f, "mysql-cdc"),
268            Self::Redis => write!(f, "redis"),
269            Self::S3Parquet => write!(f, "s3-parquet"),
270            Self::DeltaLake => write!(f, "delta-lake"),
271            Self::Static => write!(f, "static"),
272            Self::Custom(s) => write!(f, "{s}"),
273        }
274    }
275}
276
277impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
278    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279        match self {
280            Self::Replicated => write!(f, "replicated"),
281            Self::Partitioned => write!(f, "partitioned"),
282            Self::OnDemand => write!(f, "on-demand"),
283        }
284    }
285}
286
287impl fmt::Display for crate::parser::lookup_table::PushdownMode {
288    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289        match self {
290            Self::Auto => write!(f, "auto"),
291            Self::Enabled => write!(f, "enabled"),
292            Self::Disabled => write!(f, "disabled"),
293        }
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Descriptor for the `customers` lookup table used by the rewrite tests.
    fn test_lookup_info() -> LookupTableInfo {
        let properties = LookupTableProperties {
            connector: ConnectorType::PostgresCdc,
            connection: Some("postgresql://localhost/db".to_string()),
            strategy: LookupStrategy::Replicated,
            cache_memory: Some(ByteSize(512 * 1024 * 1024)),
            cache_disk: None,
            cache_ttl: None,
            pushdown_mode: PushdownMode::Auto,
        };
        let columns = vec![
            ("id".to_string(), "INT".to_string()),
            ("name".to_string(), "VARCHAR".to_string()),
        ];
        let arrow_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        LookupTableInfo {
            name: "customers".to_string(),
            columns,
            primary_key: vec!["id".to_string()],
            properties,
            arrow_schema,
            #[allow(clippy::disallowed_types)] // cold path: query planning
            raw_options: std::collections::HashMap::new(),
        }
    }

    /// Registers empty `orders` and `customers` tables on the context.
    fn register_test_tables(ctx: &SessionContext) {
        let orders = Schema::new(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]);
        let customers = Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]);

        for (table, schema) in [("orders", orders), ("customers", customers)] {
            ctx.register_batch(
                table,
                arrow::array::RecordBatch::new_empty(Arc::new(schema)),
            )
            .unwrap();
        }
    }

    /// Rewrite rule that knows `customers` as its only lookup table.
    fn customers_rule() -> LookupJoinRewriteRule {
        let mut lookup_tables = HashMap::new();
        lookup_tables.insert("customers".to_string(), test_lookup_info());
        LookupJoinRewriteRule::new(lookup_tables)
    }

    /// Plans `sql` without optimization and runs `rule` over every node.
    async fn plan_and_rewrite(
        ctx: &SessionContext,
        sql: &str,
        rule: &LookupJoinRewriteRule,
    ) -> Transformed<LogicalPlan> {
        let plan = ctx.sql(sql).await.unwrap().into_unoptimized_plan();
        plan.transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap()
    }

    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = create_session_context();
        register_test_tables(&ctx);
        let rule = customers_rule();

        let transformed = plan_and_rewrite(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
            &rule,
        )
        .await;

        // Verify rewrite happened
        assert!(transformed.transformed);
        let has_lookup = format!("{:?}", transformed.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = create_session_context();
        // Register both as regular tables (neither is a lookup table)
        for (table, column) in [("a", "id"), ("b", "a_id")] {
            let schema = Arc::new(Schema::new(vec![Field::new(
                column,
                DataType::Int64,
                false,
            )]));
            ctx.register_batch(table, arrow::array::RecordBatch::new_empty(schema))
                .unwrap();
        }

        // No lookup tables registered
        let rule = LookupJoinRewriteRule::new(HashMap::new());

        let transformed =
            plan_and_rewrite(&ctx, "SELECT * FROM a JOIN b ON a.id = b.a_id", &rule).await;

        assert!(!transformed.transformed);
    }

    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = create_session_context();
        register_test_tables(&ctx);
        let rule = customers_rule();

        let transformed = plan_and_rewrite(
            &ctx,
            "SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id",
            &rule,
        )
        .await;

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}