// laminar_sql/planner/lookup_join.rs
//! Optimizer rules that rewrite standard JOINs to `LookupJoinNode`.
//!
//! When a query joins a streaming source with a registered lookup table,
//! the `LookupJoinRewriteRule` replaces the standard hash/merge join
//! with a `LookupJoinNode` that uses the lookup source connector.

#[allow(clippy::disallowed_types)] // cold path: query planning
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::sync::Arc;

use datafusion::common::{DFSchema, Result};
use datafusion::logical_expr::logical_plan::LogicalPlan;
use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
use datafusion_common::tree_node::Transformed;
use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};

use crate::datafusion::lookup_join::{
    JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
};
use crate::planner::LookupTableInfo;
/// Rewrites standard JOIN nodes that reference a lookup table into
/// `LookupJoinNode` extension nodes.
///
/// Constructed with the set of registered lookup tables; any join whose
/// left or right side scans one of those tables is a rewrite candidate.
#[derive(Debug)]
pub struct LookupJoinRewriteRule {
    /// Registered lookup tables, keyed by name.
    lookup_tables: HashMap<String, LookupTableInfo>,
}
30
31impl LookupJoinRewriteRule {
32    /// Creates a new rewrite rule with the given set of registered lookup tables.
33    #[must_use]
34    pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
35        Self { lookup_tables }
36    }
37
38    /// Detects which side of a join (if any) is a lookup table scan.
39    /// Returns `Some((lookup_side_is_right, table_name))`.
40    fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
41        // Check right side
42        if let Some(name) = scan_table_name(&join.right) {
43            if self.lookup_tables.contains_key(&name) {
44                return Some((true, name));
45            }
46        }
47        // Check left side
48        if let Some(name) = scan_table_name(&join.left) {
49            if self.lookup_tables.contains_key(&name) {
50                return Some((false, name));
51            }
52        }
53        None
54    }
55}
56
57impl OptimizerRule for LookupJoinRewriteRule {
58    fn name(&self) -> &'static str {
59        "lookup_join_rewrite"
60    }
61
62    fn apply_order(&self) -> Option<ApplyOrder> {
63        Some(ApplyOrder::BottomUp)
64    }
65
66    fn rewrite(
67        &self,
68        plan: LogicalPlan,
69        _config: &dyn OptimizerConfig,
70    ) -> Result<Transformed<LogicalPlan>> {
71        let LogicalPlan::Join(join) = &plan else {
72            return Ok(Transformed::no(plan));
73        };
74
75        let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
76            return Ok(Transformed::no(plan));
77        };
78
79        let info = &self.lookup_tables[&table_name];
80
81        // Determine which side is stream and which is lookup
82        let (stream_plan, lookup_plan) = if lookup_is_right {
83            (join.left.as_ref(), join.right.as_ref())
84        } else {
85            (join.right.as_ref(), join.left.as_ref())
86        };
87
88        // Extract aliases for qualified column resolution (C7)
89        let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
90        let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
91
92        let lookup_schema = lookup_plan.schema().clone();
93
94        // Build join key pairs from the equijoin conditions
95        let join_keys: Vec<JoinKeyPair> = join
96            .on
97            .iter()
98            .map(|(left_expr, right_expr)| {
99                if lookup_is_right {
100                    JoinKeyPair {
101                        stream_expr: left_expr.clone(),
102                        lookup_column: right_expr.to_string(),
103                    }
104                } else {
105                    JoinKeyPair {
106                        stream_expr: right_expr.clone(),
107                        lookup_column: left_expr.to_string(),
108                    }
109                }
110            })
111            .collect();
112
113        // Convert DataFusion join type to our lookup join type
114        let join_type = match join.join_type {
115            datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
116            datafusion::logical_expr::JoinType::Left if lookup_is_right => {
117                LookupJoinType::LeftOuter
118            }
119            datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
120                LookupJoinType::LeftOuter
121            }
122            _ => return Ok(Transformed::no(plan)),
123        };
124
125        // All lookup columns are required initially; pruning is done later
126        let required_columns: HashSet<String> = lookup_schema
127            .fields()
128            .iter()
129            .map(|f| f.name().clone())
130            .collect();
131
132        // Build output schema from stream + lookup
133        let stream_schema = stream_plan.schema();
134        let merged_fields: Vec<_> = stream_schema
135            .fields()
136            .iter()
137            .chain(lookup_schema.fields().iter())
138            .cloned()
139            .collect();
140        let output_schema = Arc::new(DFSchema::from_unqualified_fields(
141            merged_fields.into(),
142            HashMap::new(),
143        )?);
144
145        let metadata = LookupTableMetadata {
146            connector: info.properties.connector.to_string(),
147            strategy: info.properties.strategy.to_string(),
148            pushdown_mode: info.properties.pushdown_mode.to_string(),
149            primary_key: info.primary_key.clone(),
150        };
151
152        let node = LookupJoinNode::new(
153            stream_plan.clone(),
154            table_name,
155            lookup_schema,
156            join_keys,
157            join_type,
158            vec![], // predicates pushed down later
159            required_columns,
160            output_schema,
161            metadata,
162        )
163        .with_aliases(lookup_alias, stream_alias);
164
165        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
166            node: Arc::new(node),
167        })))
168    }
169}
170
/// Column pruning rule for `LookupJoinNode`.
///
/// Narrows `required_lookup_columns` to only the columns referenced
/// by downstream plan nodes.
///
/// Stateless unit struct: all information it needs lives on the
/// `LookupJoinNode` it rewrites.
#[derive(Debug)]
pub struct LookupColumnPruningRule;
177
178impl OptimizerRule for LookupColumnPruningRule {
179    fn name(&self) -> &'static str {
180        "lookup_column_pruning"
181    }
182
183    fn apply_order(&self) -> Option<ApplyOrder> {
184        Some(ApplyOrder::TopDown)
185    }
186
187    fn rewrite(
188        &self,
189        plan: LogicalPlan,
190        _config: &dyn OptimizerConfig,
191    ) -> Result<Transformed<LogicalPlan>> {
192        let LogicalPlan::Extension(ext) = &plan else {
193            return Ok(Transformed::no(plan));
194        };
195
196        let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
197            return Ok(Transformed::no(plan));
198        };
199
200        // Collect columns actually used downstream by walking the parent plan.
201        // For now, we use the node's schema to determine which lookup columns
202        // appear in the output. A full implementation would track column usage
203        // from parent nodes; this is a conservative starting point.
204        let schema = UserDefinedLogicalNodeCore::schema(node);
205        let used: HashSet<String> = schema
206            .fields()
207            .iter()
208            .filter(|f| node.required_lookup_columns().contains(f.name()))
209            .map(|f| f.name().clone())
210            .collect();
211
212        if used == *node.required_lookup_columns() {
213            return Ok(Transformed::no(plan));
214        }
215
216        // Rebuild with narrowed columns
217        let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
218        let pruned = LookupJoinNode::new(
219            node_inputs[0].clone(),
220            node.lookup_table_name().to_string(),
221            node.lookup_schema().clone(),
222            node.join_keys().to_vec(),
223            node.join_type(),
224            node.pushdown_predicates().to_vec(),
225            used,
226            schema.clone(),
227            node.metadata().clone(),
228        )
229        .with_local_predicates(node.local_predicates().to_vec())
230        .with_aliases(
231            node.lookup_alias().map(String::from),
232            node.stream_alias().map(String::from),
233        );
234
235        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
236            node: Arc::new(pruned),
237        })))
238    }
239}
240
241/// Extracts the table name and optional alias from a plan node.
242///
243/// Returns `(base_table_name, alias)` — alias is the `SubqueryAlias` name
244/// if the scan is wrapped in one, `None` otherwise.
245fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
246    match plan {
247        LogicalPlan::TableScan(TableScan { table_name, .. }) => {
248            Some((table_name.table().to_string(), None))
249        }
250        LogicalPlan::SubqueryAlias(alias) => {
251            let alias_name = alias.alias.table().to_string();
252            scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
253        }
254        _ => None,
255    }
256}
257
258/// Extracts the table name from a `TableScan` node, unwrapping aliases.
259fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
260    scan_table_name_and_alias(plan).map(|(name, _)| name)
261}
262
263/// Display helpers for connector/strategy/pushdown types.
264impl fmt::Display for crate::parser::lookup_table::ConnectorType {
265    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266        match self {
267            Self::PostgresCdc => write!(f, "postgres-cdc"),
268            Self::MysqlCdc => write!(f, "mysql-cdc"),
269            Self::Redis => write!(f, "redis"),
270            Self::S3Parquet => write!(f, "s3-parquet"),
271            Self::Static => write!(f, "static"),
272            Self::Custom(s) => write!(f, "{s}"),
273        }
274    }
275}
276
277impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
278    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279        match self {
280            Self::Replicated => write!(f, "replicated"),
281            Self::Partitioned => write!(f, "partitioned"),
282            Self::OnDemand => write!(f, "on-demand"),
283        }
284    }
285}
286
287impl fmt::Display for crate::parser::lookup_table::PushdownMode {
288    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289        match self {
290            Self::Auto => write!(f, "auto"),
291            Self::Enabled => write!(f, "enabled"),
292            Self::Disabled => write!(f, "disabled"),
293        }
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Builds a `LookupTableInfo` fixture for a two-column `customers`
    /// lookup table (postgres-cdc, replicated, auto pushdown).
    fn test_lookup_info() -> LookupTableInfo {
        let arrow_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties: LookupTableProperties {
                connector: ConnectorType::PostgresCdc,
                connection: Some("postgresql://localhost/db".to_string()),
                strategy: LookupStrategy::Replicated,
                cache_memory: Some(ByteSize(512 * 1024 * 1024)),
                cache_disk: None,
                cache_ttl: None,
                pushdown_mode: PushdownMode::Auto,
            },
            arrow_schema,
            #[allow(clippy::disallowed_types)] // cold path: query planning
            raw_options: std::collections::HashMap::new(),
        }
    }

    /// Registers empty `orders` and `customers` batches so SQL against
    /// them parses and plans without needing real data.
    fn register_test_tables(ctx: &SessionContext) {
        let orders_schema = Arc::new(Schema::new(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]));
        let customers_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        ctx.register_batch(
            "orders",
            arrow::array::RecordBatch::new_empty(orders_schema),
        )
        .unwrap();
        ctx.register_batch(
            "customers",
            arrow::array::RecordBatch::new_empty(customers_schema),
        )
        .unwrap();
    }

    /// An inner join against a registered lookup table should be rewritten
    /// into a `LookupJoin` extension node.
    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let plan = ctx
            .sql("SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let mut lookup_tables = HashMap::new();
        lookup_tables.insert("customers".to_string(), test_lookup_info());
        let rule = LookupJoinRewriteRule::new(lookup_tables);

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        // Verify rewrite happened; checks the Debug output for the node name.
        assert!(transformed.transformed);
        let has_lookup = format!("{:?}", transformed.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    /// A join between two plain tables must pass through untouched when
    /// no lookup tables are registered with the rule.
    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = create_session_context();
        // Register both as regular tables (neither is a lookup table)
        let schema_a = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let schema_b = Arc::new(Schema::new(vec![Field::new(
            "a_id",
            DataType::Int64,
            false,
        )]));
        ctx.register_batch("a", arrow::array::RecordBatch::new_empty(schema_a))
            .unwrap();
        ctx.register_batch("b", arrow::array::RecordBatch::new_empty(schema_b))
            .unwrap();

        let plan = ctx
            .sql("SELECT * FROM a JOIN b ON a.id = b.a_id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        // No lookup tables registered
        let rule = LookupJoinRewriteRule::new(HashMap::new());

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(!transformed.transformed);
    }

    /// LEFT JOIN with the lookup table on the right should map to the
    /// `LeftOuter` lookup join type.
    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = create_session_context();
        register_test_tables(&ctx);

        let plan = ctx
            .sql("SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id")
            .await
            .unwrap()
            .into_unoptimized_plan();

        let mut lookup_tables = HashMap::new();
        lookup_tables.insert("customers".to_string(), test_lookup_info());
        let rule = LookupJoinRewriteRule::new(lookup_tables);

        let transformed = plan
            .transform_down(|p| rule.rewrite(p, &OptimizerContext::new()))
            .unwrap();

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    /// `Display` round-trips for connector names, including custom ones.
    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    /// `Display` for lookup strategies uses kebab-case names.
    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    /// `Display` for pushdown modes uses lowercase names.
    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}