Skip to main content

laminar_sql/datafusion/
lookup_join.rs

1//! `LookupJoinNode` — custom DataFusion logical plan node for lookup joins.
2//!
3//! This node represents a join between a streaming input and a registered
4//! lookup table. It is produced by the `LookupJoinRewriteRule` optimizer
5//! rule when a standard JOIN references a registered lookup table.
6
7#[allow(clippy::disallowed_types)] // cold path: DataFusion integration
8use std::collections::HashSet;
9use std::fmt;
10use std::hash::{Hash, Hasher};
11use std::sync::Arc;
12
13use datafusion::common::DFSchemaRef;
14use datafusion::logical_expr::logical_plan::LogicalPlan;
15use datafusion::logical_expr::{Expr, UserDefinedLogicalNodeCore};
16use datafusion_common::Result;
17
/// The join semantics supported by a lookup join.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` sort
/// `Inner` before `LeftOuter`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LookupJoinType {
    /// Inner join: a stream row is emitted only when the lookup matches.
    Inner,
    /// Left outer join: every stream row is emitted, with NULLs in the
    /// lookup columns when there is no match.
    LeftOuter,
}

impl fmt::Display for LookupJoinType {
    /// Writes the variant name (`Inner` / `LeftOuter`) verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Inner => "Inner",
            Self::LeftOuter => "LeftOuter",
        };
        f.write_str(label)
    }
}
35
36/// A pair of expressions defining how stream keys map to lookup columns.
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub struct JoinKeyPair {
39    /// Expression on the stream side (e.g., `stream.customer_id`).
40    pub stream_expr: Expr,
41    /// Column name on the lookup table side.
42    pub lookup_column: String,
43}
44
/// Descriptive properties of a lookup table, carried on the plan node for
/// plan construction.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare
/// fields lexicographically in declaration order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LookupTableMetadata {
    /// Connector type backing the table (e.g., "postgres-cdc").
    pub connector: String,
    /// Lookup strategy (e.g., "replicated").
    pub strategy: String,
    /// Predicate pushdown mode (e.g., "auto").
    pub pushdown_mode: String,
    /// Names of the primary key columns.
    pub primary_key: Vec<String>,
}
57
58/// Custom logical plan node for a lookup join.
59///
60/// Represents a join between a streaming input plan and a lookup table.
61/// The lookup table is not a DataFusion table; it is resolved at execution
62/// time via the lookup source connector.
63#[derive(Debug, Clone)]
64pub struct LookupJoinNode {
65    /// The streaming input plan.
66    input: Arc<LogicalPlan>,
67    /// Name of the lookup table.
68    lookup_table: String,
69    /// Schema of the lookup table columns.
70    lookup_schema: DFSchemaRef,
71    /// Join key pairs (stream expression -> lookup column).
72    join_keys: Vec<JoinKeyPair>,
73    /// Join type (Inner or LeftOuter).
74    join_type: LookupJoinType,
75    /// Predicates to push down to the lookup source.
76    pushdown_predicates: Vec<Expr>,
77    /// Predicates evaluated locally after the join.
78    local_predicates: Vec<Expr>,
79    /// Required columns from the lookup table.
80    required_lookup_columns: HashSet<String>,
81    /// Combined output schema (stream + lookup columns).
82    output_schema: DFSchemaRef,
83    /// Metadata about the lookup table.
84    metadata: LookupTableMetadata,
85    /// Alias for the lookup table (for qualified column resolution).
86    lookup_alias: Option<String>,
87    /// Alias for the stream input (for qualified column resolution).
88    stream_alias: Option<String>,
89}
90
91impl PartialEq for LookupJoinNode {
92    fn eq(&self, other: &Self) -> bool {
93        self.lookup_table == other.lookup_table
94            && self.join_keys == other.join_keys
95            && self.join_type == other.join_type
96            && self.pushdown_predicates == other.pushdown_predicates
97            && self.local_predicates == other.local_predicates
98            && self.required_lookup_columns == other.required_lookup_columns
99            && self.metadata == other.metadata
100    }
101}
102
103impl Eq for LookupJoinNode {}
104
105impl Hash for LookupJoinNode {
106    fn hash<H: Hasher>(&self, state: &mut H) {
107        self.lookup_table.hash(state);
108        self.join_keys.hash(state);
109        self.join_type.hash(state);
110        self.pushdown_predicates.hash(state);
111        self.local_predicates.hash(state);
112        self.metadata.hash(state);
113        // HashSet doesn't implement Hash; hash sorted elements instead
114        let mut cols: Vec<&String> = self.required_lookup_columns.iter().collect();
115        cols.sort();
116        cols.hash(state);
117    }
118}
119
120impl PartialOrd for LookupJoinNode {
121    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122        self.lookup_table.partial_cmp(&other.lookup_table)
123    }
124}
125
126impl LookupJoinNode {
127    /// Creates a new lookup join node.
128    #[must_use]
129    #[allow(clippy::too_many_arguments)]
130    pub fn new(
131        input: LogicalPlan,
132        lookup_table: String,
133        lookup_schema: DFSchemaRef,
134        join_keys: Vec<JoinKeyPair>,
135        join_type: LookupJoinType,
136        pushdown_predicates: Vec<Expr>,
137        required_lookup_columns: HashSet<String>,
138        output_schema: DFSchemaRef,
139        metadata: LookupTableMetadata,
140    ) -> Self {
141        Self {
142            input: Arc::new(input),
143            lookup_table,
144            lookup_schema,
145            join_keys,
146            join_type,
147            pushdown_predicates,
148            local_predicates: vec![],
149            required_lookup_columns,
150            output_schema,
151            metadata,
152            lookup_alias: None,
153            stream_alias: None,
154        }
155    }
156
157    /// Sets predicates to be evaluated locally after the join.
158    #[must_use]
159    pub fn with_local_predicates(mut self, predicates: Vec<Expr>) -> Self {
160        self.local_predicates = predicates;
161        self
162    }
163
164    /// Sets table aliases for qualified column resolution.
165    #[must_use]
166    pub fn with_aliases(
167        mut self,
168        lookup_alias: Option<String>,
169        stream_alias: Option<String>,
170    ) -> Self {
171        self.lookup_alias = lookup_alias;
172        self.stream_alias = stream_alias;
173        self
174    }
175
176    /// Returns the lookup table name.
177    #[must_use]
178    pub fn lookup_table_name(&self) -> &str {
179        &self.lookup_table
180    }
181
182    /// Returns the join key pairs.
183    #[must_use]
184    pub fn join_keys(&self) -> &[JoinKeyPair] {
185        &self.join_keys
186    }
187
188    /// Returns the join type.
189    #[must_use]
190    pub fn join_type(&self) -> LookupJoinType {
191        self.join_type
192    }
193
194    /// Returns the pushdown predicates.
195    #[must_use]
196    pub fn pushdown_predicates(&self) -> &[Expr] {
197        &self.pushdown_predicates
198    }
199
200    /// Returns the required lookup columns.
201    #[must_use]
202    pub fn required_lookup_columns(&self) -> &HashSet<String> {
203        &self.required_lookup_columns
204    }
205
206    /// Returns the lookup table metadata.
207    #[must_use]
208    pub fn metadata(&self) -> &LookupTableMetadata {
209        &self.metadata
210    }
211
212    /// Returns the lookup table schema.
213    #[must_use]
214    pub fn lookup_schema(&self) -> &DFSchemaRef {
215        &self.lookup_schema
216    }
217
218    /// Returns the local predicates (evaluated after the join).
219    #[must_use]
220    pub fn local_predicates(&self) -> &[Expr] {
221        &self.local_predicates
222    }
223
224    /// Returns the lookup table alias.
225    #[must_use]
226    pub fn lookup_alias(&self) -> Option<&str> {
227        self.lookup_alias.as_deref()
228    }
229
230    /// Returns the stream input alias.
231    #[must_use]
232    pub fn stream_alias(&self) -> Option<&str> {
233        self.stream_alias.as_deref()
234    }
235}
236
237impl UserDefinedLogicalNodeCore for LookupJoinNode {
238    fn name(&self) -> &'static str {
239        "LookupJoin"
240    }
241
242    fn inputs(&self) -> Vec<&LogicalPlan> {
243        vec![&self.input]
244    }
245
246    fn schema(&self) -> &DFSchemaRef {
247        &self.output_schema
248    }
249
250    fn expressions(&self) -> Vec<Expr> {
251        self.join_keys
252            .iter()
253            .map(|k| k.stream_expr.clone())
254            .chain(self.pushdown_predicates.clone())
255            .chain(self.local_predicates.clone())
256            .collect()
257    }
258
259    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
260        let keys: Vec<String> = self
261            .join_keys
262            .iter()
263            .map(|k| format!("{}={}", k.stream_expr, k.lookup_column))
264            .collect();
265        write!(
266            f,
267            "LookupJoin: table={}, keys=[{}], type={}, pushdown={}, local={}",
268            self.lookup_table,
269            keys.join(", "),
270            self.join_type,
271            self.pushdown_predicates.len(),
272            self.local_predicates.len(),
273        )
274    }
275
276    fn with_exprs_and_inputs(
277        &self,
278        exprs: Vec<Expr>,
279        mut inputs: Vec<LogicalPlan>,
280    ) -> Result<Self> {
281        let input = inputs.swap_remove(0);
282
283        // Split expressions: keys | pushdown predicates | local predicates
284        let num_keys = self.join_keys.len();
285        let num_pushdown = self.pushdown_predicates.len();
286        let (key_exprs, rest) = exprs.split_at(num_keys.min(exprs.len()));
287        let (pushdown_exprs, local_exprs) = rest.split_at(num_pushdown.min(rest.len()));
288
289        let join_keys: Vec<JoinKeyPair> = key_exprs
290            .iter()
291            .zip(self.join_keys.iter())
292            .map(|(expr, old)| JoinKeyPair {
293                stream_expr: expr.clone(),
294                lookup_column: old.lookup_column.clone(),
295            })
296            .collect();
297
298        Ok(Self {
299            input: Arc::new(input),
300            lookup_table: self.lookup_table.clone(),
301            lookup_schema: Arc::clone(&self.lookup_schema),
302            join_keys,
303            join_type: self.join_type,
304            pushdown_predicates: pushdown_exprs.to_vec(),
305            local_predicates: local_exprs.to_vec(),
306            required_lookup_columns: self.required_lookup_columns.clone(),
307            output_schema: Arc::clone(&self.output_schema),
308            metadata: self.metadata.clone(),
309            lookup_alias: self.lookup_alias.clone(),
310            stream_alias: self.stream_alias.clone(),
311        })
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::{col, EmptyRelation};

    /// Wraps a list of Arrow fields into a shared, unqualified `DFSchema`.
    fn schema_of(fields: Vec<Field>) -> DFSchemaRef {
        Arc::new(DFSchema::try_from(Schema::new(fields)).unwrap())
    }

    /// Columns of the streaming (orders) side.
    fn stream_fields() -> Vec<Field> {
        vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]
    }

    /// Columns of the lookup (customers) side.
    fn lookup_fields() -> Vec<Field> {
        vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("region", DataType::Utf8, true),
        ]
    }

    fn test_stream_schema() -> DFSchemaRef {
        schema_of(stream_fields())
    }

    fn test_lookup_schema() -> DFSchemaRef {
        schema_of(lookup_fields())
    }

    /// Output schema: stream columns followed by lookup columns.
    fn test_output_schema() -> DFSchemaRef {
        let mut fields = stream_fields();
        fields.extend(lookup_fields());
        schema_of(fields)
    }

    fn test_metadata() -> LookupTableMetadata {
        LookupTableMetadata {
            connector: String::from("postgres-cdc"),
            strategy: String::from("replicated"),
            pushdown_mode: String::from("auto"),
            primary_key: vec![String::from("id")],
        }
    }

    /// Empty relation carrying the stream schema, used as the join input.
    fn empty_stream_input() -> LogicalPlan {
        LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: test_stream_schema(),
        })
    }

    /// Builds a `customers` lookup join keyed on `customer_id = id`.
    fn build_node(join_type: LookupJoinType, required: HashSet<String>) -> LookupJoinNode {
        let key = JoinKeyPair {
            stream_expr: col("customer_id"),
            lookup_column: String::from("id"),
        };

        LookupJoinNode::new(
            empty_stream_input(),
            String::from("customers"),
            test_lookup_schema(),
            vec![key],
            join_type,
            Vec::new(),
            required,
            test_output_schema(),
            test_metadata(),
        )
    }

    fn test_node() -> LookupJoinNode {
        build_node(
            LookupJoinType::Inner,
            HashSet::from(["name".to_string(), "region".to_string()]),
        )
    }

    /// Adapter that exposes `fmt_for_explain` through `Display`.
    struct DisplayExplain<'a>(&'a LookupJoinNode);

    impl fmt::Display for DisplayExplain<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            UserDefinedLogicalNodeCore::fmt_for_explain(self.0, f)
        }
    }

    #[test]
    fn test_name() {
        assert_eq!(test_node().name(), "LookupJoin");
    }

    #[test]
    fn test_inputs() {
        assert_eq!(test_node().inputs().len(), 1);
    }

    #[test]
    fn test_schema() {
        assert_eq!(test_node().schema().fields().len(), 6);
    }

    #[test]
    fn test_expressions() {
        // One join key and no predicates of either kind.
        assert_eq!(test_node().expressions().len(), 1);
    }

    #[test]
    fn test_fmt_for_explain() {
        let node = test_node();

        let debug_repr = format!("{node:?}");
        assert!(debug_repr.contains("LookupJoin"));

        // The explain line produced by the trait method.
        let explain = DisplayExplain(&node).to_string();
        assert!(explain.contains("LookupJoin: table=customers"));
        assert!(explain.contains("type=Inner"));
    }

    #[test]
    fn test_with_exprs_and_inputs_roundtrip() {
        let node = test_node();
        let inputs: Vec<LogicalPlan> = node.inputs().into_iter().cloned().collect();
        let exprs = node.expressions();

        let rebuilt = node.with_exprs_and_inputs(exprs, inputs).unwrap();

        assert_eq!(rebuilt.lookup_table, "customers");
        assert_eq!(rebuilt.join_keys.len(), 1);
        assert_eq!(rebuilt.join_type, LookupJoinType::Inner);
    }

    #[test]
    fn test_left_outer_join() {
        let node = build_node(LookupJoinType::LeftOuter, HashSet::new());
        assert_eq!(node.join_type(), LookupJoinType::LeftOuter);
    }
}