1#[allow(clippy::disallowed_types)] use std::collections::HashSet;
9use std::fmt;
10use std::hash::{Hash, Hasher};
11use std::sync::Arc;
12
13use datafusion::common::DFSchemaRef;
14use datafusion::logical_expr::logical_plan::LogicalPlan;
15use datafusion::logical_expr::{Expr, UserDefinedLogicalNodeCore};
16use datafusion_common::Result;
17
/// Kind of join performed by a lookup join node.
///
/// The derived ordering follows declaration order, so `Inner < LeftOuter`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum LookupJoinType {
    /// Inner join.
    Inner,
    /// Left outer join.
    LeftOuter,
}
26
27impl fmt::Display for LookupJoinType {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 match self {
30 Self::Inner => write!(f, "Inner"),
31 Self::LeftOuter => write!(f, "LeftOuter"),
32 }
33 }
34}
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub struct JoinKeyPair {
39 pub stream_expr: Expr,
41 pub lookup_column: String,
43}
44
/// Descriptive metadata attached to a lookup table.
///
/// All values are free-form strings. The derived ordering compares fields in
/// declaration order (`connector` first).
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct LookupTableMetadata {
    /// Connector identifier (e.g. `"postgres-cdc"`).
    pub connector: String,
    /// Lookup strategy name (e.g. `"replicated"`).
    pub strategy: String,
    /// Pushdown mode name (e.g. `"auto"`).
    pub pushdown_mode: String,
    /// Names of the lookup table's primary-key columns.
    pub primary_key: Vec<String>,
}
57
58#[derive(Debug, Clone)]
64pub struct LookupJoinNode {
65 input: Arc<LogicalPlan>,
67 lookup_table: String,
69 lookup_schema: DFSchemaRef,
71 join_keys: Vec<JoinKeyPair>,
73 join_type: LookupJoinType,
75 pushdown_predicates: Vec<Expr>,
77 local_predicates: Vec<Expr>,
79 required_lookup_columns: HashSet<String>,
81 output_schema: DFSchemaRef,
83 metadata: LookupTableMetadata,
85 lookup_alias: Option<String>,
87 stream_alias: Option<String>,
89}
90
91impl PartialEq for LookupJoinNode {
92 fn eq(&self, other: &Self) -> bool {
93 self.lookup_table == other.lookup_table
94 && self.join_keys == other.join_keys
95 && self.join_type == other.join_type
96 && self.pushdown_predicates == other.pushdown_predicates
97 && self.local_predicates == other.local_predicates
98 && self.required_lookup_columns == other.required_lookup_columns
99 && self.metadata == other.metadata
100 }
101}
102
103impl Eq for LookupJoinNode {}
104
105impl Hash for LookupJoinNode {
106 fn hash<H: Hasher>(&self, state: &mut H) {
107 self.lookup_table.hash(state);
108 self.join_keys.hash(state);
109 self.join_type.hash(state);
110 self.pushdown_predicates.hash(state);
111 self.local_predicates.hash(state);
112 self.metadata.hash(state);
113 let mut cols: Vec<&String> = self.required_lookup_columns.iter().collect();
115 cols.sort();
116 cols.hash(state);
117 }
118}
119
120impl PartialOrd for LookupJoinNode {
121 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122 self.lookup_table.partial_cmp(&other.lookup_table)
123 }
124}
125
126impl LookupJoinNode {
127 #[must_use]
129 #[allow(clippy::too_many_arguments)]
130 pub fn new(
131 input: LogicalPlan,
132 lookup_table: String,
133 lookup_schema: DFSchemaRef,
134 join_keys: Vec<JoinKeyPair>,
135 join_type: LookupJoinType,
136 pushdown_predicates: Vec<Expr>,
137 required_lookup_columns: HashSet<String>,
138 output_schema: DFSchemaRef,
139 metadata: LookupTableMetadata,
140 ) -> Self {
141 Self {
142 input: Arc::new(input),
143 lookup_table,
144 lookup_schema,
145 join_keys,
146 join_type,
147 pushdown_predicates,
148 local_predicates: vec![],
149 required_lookup_columns,
150 output_schema,
151 metadata,
152 lookup_alias: None,
153 stream_alias: None,
154 }
155 }
156
157 #[must_use]
159 pub fn with_local_predicates(mut self, predicates: Vec<Expr>) -> Self {
160 self.local_predicates = predicates;
161 self
162 }
163
164 #[must_use]
166 pub fn with_aliases(
167 mut self,
168 lookup_alias: Option<String>,
169 stream_alias: Option<String>,
170 ) -> Self {
171 self.lookup_alias = lookup_alias;
172 self.stream_alias = stream_alias;
173 self
174 }
175
176 #[must_use]
178 pub fn lookup_table_name(&self) -> &str {
179 &self.lookup_table
180 }
181
182 #[must_use]
184 pub fn join_keys(&self) -> &[JoinKeyPair] {
185 &self.join_keys
186 }
187
188 #[must_use]
190 pub fn join_type(&self) -> LookupJoinType {
191 self.join_type
192 }
193
194 #[must_use]
196 pub fn pushdown_predicates(&self) -> &[Expr] {
197 &self.pushdown_predicates
198 }
199
200 #[must_use]
202 pub fn required_lookup_columns(&self) -> &HashSet<String> {
203 &self.required_lookup_columns
204 }
205
206 #[must_use]
208 pub fn metadata(&self) -> &LookupTableMetadata {
209 &self.metadata
210 }
211
212 #[must_use]
214 pub fn lookup_schema(&self) -> &DFSchemaRef {
215 &self.lookup_schema
216 }
217
218 #[must_use]
220 pub fn local_predicates(&self) -> &[Expr] {
221 &self.local_predicates
222 }
223
224 #[must_use]
226 pub fn lookup_alias(&self) -> Option<&str> {
227 self.lookup_alias.as_deref()
228 }
229
230 #[must_use]
232 pub fn stream_alias(&self) -> Option<&str> {
233 self.stream_alias.as_deref()
234 }
235}
236
237impl UserDefinedLogicalNodeCore for LookupJoinNode {
238 fn name(&self) -> &'static str {
239 "LookupJoin"
240 }
241
242 fn inputs(&self) -> Vec<&LogicalPlan> {
243 vec![&self.input]
244 }
245
246 fn schema(&self) -> &DFSchemaRef {
247 &self.output_schema
248 }
249
250 fn expressions(&self) -> Vec<Expr> {
251 self.join_keys
252 .iter()
253 .map(|k| k.stream_expr.clone())
254 .chain(self.pushdown_predicates.clone())
255 .chain(self.local_predicates.clone())
256 .collect()
257 }
258
259 fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
260 let keys: Vec<String> = self
261 .join_keys
262 .iter()
263 .map(|k| format!("{}={}", k.stream_expr, k.lookup_column))
264 .collect();
265 write!(
266 f,
267 "LookupJoin: table={}, keys=[{}], type={}, pushdown={}, local={}",
268 self.lookup_table,
269 keys.join(", "),
270 self.join_type,
271 self.pushdown_predicates.len(),
272 self.local_predicates.len(),
273 )
274 }
275
276 fn with_exprs_and_inputs(
277 &self,
278 exprs: Vec<Expr>,
279 mut inputs: Vec<LogicalPlan>,
280 ) -> Result<Self> {
281 let input = inputs.swap_remove(0);
282
283 let num_keys = self.join_keys.len();
285 let num_pushdown = self.pushdown_predicates.len();
286 let (key_exprs, rest) = exprs.split_at(num_keys.min(exprs.len()));
287 let (pushdown_exprs, local_exprs) = rest.split_at(num_pushdown.min(rest.len()));
288
289 let join_keys: Vec<JoinKeyPair> = key_exprs
290 .iter()
291 .zip(self.join_keys.iter())
292 .map(|(expr, old)| JoinKeyPair {
293 stream_expr: expr.clone(),
294 lookup_column: old.lookup_column.clone(),
295 })
296 .collect();
297
298 Ok(Self {
299 input: Arc::new(input),
300 lookup_table: self.lookup_table.clone(),
301 lookup_schema: Arc::clone(&self.lookup_schema),
302 join_keys,
303 join_type: self.join_type,
304 pushdown_predicates: pushdown_exprs.to_vec(),
305 local_predicates: local_exprs.to_vec(),
306 required_lookup_columns: self.required_lookup_columns.clone(),
307 output_schema: Arc::clone(&self.output_schema),
308 metadata: self.metadata.clone(),
309 lookup_alias: self.lookup_alias.clone(),
310 stream_alias: self.stream_alias.clone(),
311 })
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::col;

    /// Adapter that renders a node through `fmt_for_explain` via `Display`.
    struct DisplayExplain<'a>(&'a LookupJoinNode);

    impl fmt::Display for DisplayExplain<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            UserDefinedLogicalNodeCore::fmt_for_explain(self.0, f)
        }
    }

    /// Wraps a list of Arrow fields into a shared DataFusion schema.
    fn schema_of(fields: Vec<Field>) -> DFSchemaRef {
        Arc::new(DFSchema::try_from(Schema::new(fields)).unwrap())
    }

    /// Columns of the stream side.
    fn stream_fields() -> Vec<Field> {
        vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]
    }

    /// Columns of the lookup side.
    fn lookup_fields() -> Vec<Field> {
        vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("region", DataType::Utf8, true),
        ]
    }

    fn test_stream_schema() -> DFSchemaRef {
        schema_of(stream_fields())
    }

    fn test_lookup_schema() -> DFSchemaRef {
        schema_of(lookup_fields())
    }

    /// Stream columns followed by lookup columns.
    fn test_output_schema() -> DFSchemaRef {
        let mut fields = stream_fields();
        fields.extend(lookup_fields());
        schema_of(fields)
    }

    fn test_metadata() -> LookupTableMetadata {
        LookupTableMetadata {
            connector: String::from("postgres-cdc"),
            strategy: String::from("replicated"),
            pushdown_mode: String::from("auto"),
            primary_key: vec![String::from("id")],
        }
    }

    /// Builds a `customers` lookup join on `customer_id = id` over an empty relation.
    fn node_with(join_type: LookupJoinType, required: HashSet<String>) -> LookupJoinNode {
        let input = LogicalPlan::EmptyRelation(datafusion::logical_expr::EmptyRelation {
            produce_one_row: false,
            schema: test_stream_schema(),
        });
        let keys = vec![JoinKeyPair {
            stream_expr: col("customer_id"),
            lookup_column: String::from("id"),
        }];

        LookupJoinNode::new(
            input,
            String::from("customers"),
            test_lookup_schema(),
            keys,
            join_type,
            Vec::new(),
            required,
            test_output_schema(),
            test_metadata(),
        )
    }

    fn test_node() -> LookupJoinNode {
        node_with(
            LookupJoinType::Inner,
            HashSet::from([String::from("name"), String::from("region")]),
        )
    }

    #[test]
    fn test_name() {
        assert_eq!(test_node().name(), "LookupJoin");
    }

    #[test]
    fn test_inputs() {
        assert_eq!(test_node().inputs().len(), 1);
    }

    #[test]
    fn test_schema() {
        assert_eq!(test_node().schema().fields().len(), 6);
    }

    #[test]
    fn test_expressions() {
        // Only the single join key contributes: there are no predicates.
        assert_eq!(test_node().expressions().len(), 1);
    }

    #[test]
    fn test_fmt_for_explain() {
        let node = test_node();

        let debug_repr = format!("{node:?}");
        assert!(debug_repr.contains("LookupJoin"));

        let rendered = DisplayExplain(&node).to_string();
        assert!(rendered.contains("LookupJoin: table=customers"));
        assert!(rendered.contains("type=Inner"));
    }

    #[test]
    fn test_with_exprs_and_inputs_roundtrip() {
        let node = test_node();
        let inputs: Vec<LogicalPlan> = node.inputs().into_iter().cloned().collect();
        let exprs = node.expressions();

        let rebuilt = node.with_exprs_and_inputs(exprs, inputs).unwrap();
        assert_eq!(rebuilt.lookup_table, "customers");
        assert_eq!(rebuilt.join_keys.len(), 1);
        assert_eq!(rebuilt.join_type, LookupJoinType::Inner);
    }

    #[test]
    fn test_left_outer_join() {
        let node = node_with(LookupJoinType::LeftOuter, HashSet::new());
        assert_eq!(node.join_type(), LookupJoinType::LeftOuter);
    }
}