1#[allow(clippy::disallowed_types)] use std::collections::{HashMap, HashSet};
9use std::fmt;
10use std::sync::Arc;
11
12use datafusion::common::Result;
13use datafusion::logical_expr::logical_plan::LogicalPlan;
14use datafusion::logical_expr::{Extension, Join, TableScan, UserDefinedLogicalNodeCore};
15use datafusion_common::tree_node::Transformed;
16use datafusion_optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
17
18use crate::datafusion::lookup_join::{
19 JoinKeyPair, LookupJoinNode, LookupJoinType, LookupTableMetadata,
20};
21use crate::planner::LookupTableInfo;
22
/// Optimizer rule that rewrites joins against registered lookup tables into
/// [`LookupJoinNode`] extension nodes.
#[derive(Debug)]
pub struct LookupJoinRewriteRule {
    /// Registered lookup tables, keyed by the table name that is compared
    /// against the scan names found on either side of a join.
    lookup_tables: HashMap<String, LookupTableInfo>,
}
30
31impl LookupJoinRewriteRule {
32 #[must_use]
34 pub fn new(lookup_tables: HashMap<String, LookupTableInfo>) -> Self {
35 Self { lookup_tables }
36 }
37
38 fn detect_lookup_side(&self, join: &Join) -> Option<(bool, String)> {
41 if let Some(name) = scan_table_name(&join.right) {
43 if self.lookup_tables.contains_key(&name) {
44 return Some((true, name));
45 }
46 }
47 if let Some(name) = scan_table_name(&join.left) {
49 if self.lookup_tables.contains_key(&name) {
50 return Some((false, name));
51 }
52 }
53 None
54 }
55}
56
57impl OptimizerRule for LookupJoinRewriteRule {
58 fn name(&self) -> &'static str {
59 "lookup_join_rewrite"
60 }
61
62 fn apply_order(&self) -> Option<ApplyOrder> {
63 Some(ApplyOrder::BottomUp)
64 }
65
66 fn rewrite(
67 &self,
68 plan: LogicalPlan,
69 _config: &dyn OptimizerConfig,
70 ) -> Result<Transformed<LogicalPlan>> {
71 let LogicalPlan::Join(join) = &plan else {
72 return Ok(Transformed::no(plan));
73 };
74
75 let Some((lookup_is_right, table_name)) = self.detect_lookup_side(join) else {
76 return Ok(Transformed::no(plan));
77 };
78
79 let info = &self.lookup_tables[&table_name];
80
81 let (stream_plan, lookup_plan) = if lookup_is_right {
83 (join.left.as_ref(), join.right.as_ref())
84 } else {
85 (join.right.as_ref(), join.left.as_ref())
86 };
87
88 let stream_alias = scan_table_name_and_alias(stream_plan).and_then(|(_, a)| a);
90 let lookup_alias = scan_table_name_and_alias(lookup_plan).and_then(|(_, a)| a);
91
92 let lookup_schema = lookup_plan.schema().clone();
93
94 let join_keys: Vec<JoinKeyPair> = join
96 .on
97 .iter()
98 .map(|(left_expr, right_expr)| {
99 let lookup_expr = if lookup_is_right {
100 right_expr
101 } else {
102 left_expr
103 };
104 let stream_expr = if lookup_is_right {
105 left_expr
106 } else {
107 right_expr
108 };
109 let lookup_column = match lookup_expr {
110 datafusion::logical_expr::Expr::Column(col) => col.name.clone(),
111 other => other.to_string(),
112 };
113 JoinKeyPair {
114 stream_expr: stream_expr.clone(),
115 lookup_column,
116 }
117 })
118 .collect();
119
120 let join_type = match join.join_type {
122 datafusion::logical_expr::JoinType::Inner => LookupJoinType::Inner,
123 datafusion::logical_expr::JoinType::Left if lookup_is_right => {
124 LookupJoinType::LeftOuter
125 }
126 datafusion::logical_expr::JoinType::Right if !lookup_is_right => {
127 LookupJoinType::LeftOuter
128 }
129 _ => return Ok(Transformed::no(plan)),
130 };
131
132 let required_columns: HashSet<String> = lookup_schema
134 .fields()
135 .iter()
136 .map(|f| f.name().clone())
137 .collect();
138
139 let stream_schema = stream_plan.schema();
141 let output_schema = Arc::new(stream_schema.join(lookup_schema.as_ref())?);
142
143 let metadata = LookupTableMetadata {
144 connector: info.properties.connector.to_string(),
145 strategy: info.properties.strategy.to_string(),
146 pushdown_mode: info.properties.pushdown_mode.to_string(),
147 primary_key: info.primary_key.clone(),
148 };
149
150 let node = LookupJoinNode::new(
151 stream_plan.clone(),
152 table_name,
153 lookup_schema,
154 join_keys,
155 join_type,
156 vec![], required_columns,
158 output_schema,
159 metadata,
160 )
161 .with_aliases(lookup_alias, stream_alias);
162
163 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
164 node: Arc::new(node),
165 })))
166 }
167}
168
/// Optimizer rule that narrows the required-lookup-column set of a
/// [`LookupJoinNode`] to the columns that actually appear in the node's schema.
#[derive(Debug)]
pub struct LookupColumnPruningRule;
175
176impl OptimizerRule for LookupColumnPruningRule {
177 fn name(&self) -> &'static str {
178 "lookup_column_pruning"
179 }
180
181 fn apply_order(&self) -> Option<ApplyOrder> {
182 Some(ApplyOrder::TopDown)
183 }
184
185 fn rewrite(
186 &self,
187 plan: LogicalPlan,
188 _config: &dyn OptimizerConfig,
189 ) -> Result<Transformed<LogicalPlan>> {
190 let LogicalPlan::Extension(ext) = &plan else {
191 return Ok(Transformed::no(plan));
192 };
193
194 let Some(node) = ext.node.as_any().downcast_ref::<LookupJoinNode>() else {
195 return Ok(Transformed::no(plan));
196 };
197
198 let schema = UserDefinedLogicalNodeCore::schema(node);
203 let used: HashSet<String> = schema
204 .fields()
205 .iter()
206 .filter(|f| node.required_lookup_columns().contains(f.name()))
207 .map(|f| f.name().clone())
208 .collect();
209
210 if used == *node.required_lookup_columns() {
211 return Ok(Transformed::no(plan));
212 }
213
214 let node_inputs = UserDefinedLogicalNodeCore::inputs(node);
216 let pruned = LookupJoinNode::new(
217 node_inputs[0].clone(),
218 node.lookup_table_name().to_string(),
219 node.lookup_schema().clone(),
220 node.join_keys().to_vec(),
221 node.join_type(),
222 node.pushdown_predicates().to_vec(),
223 used,
224 schema.clone(),
225 node.metadata().clone(),
226 )
227 .with_local_predicates(node.local_predicates().to_vec())
228 .with_aliases(
229 node.lookup_alias().map(String::from),
230 node.stream_alias().map(String::from),
231 );
232
233 Ok(Transformed::yes(LogicalPlan::Extension(Extension {
234 node: Arc::new(pruned),
235 })))
236 }
237}
238
239fn scan_table_name_and_alias(plan: &LogicalPlan) -> Option<(String, Option<String>)> {
244 match plan {
245 LogicalPlan::TableScan(TableScan { table_name, .. }) => {
246 Some((table_name.table().to_string(), None))
247 }
248 LogicalPlan::SubqueryAlias(alias) => {
249 let alias_name = alias.alias.table().to_string();
250 scan_table_name_and_alias(&alias.input).map(|(base, _)| (base, Some(alias_name)))
251 }
252 _ => None,
253 }
254}
255
256fn scan_table_name(plan: &LogicalPlan) -> Option<String> {
258 scan_table_name_and_alias(plan).map(|(name, _)| name)
259}
260
261impl fmt::Display for crate::parser::lookup_table::ConnectorType {
263 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264 match self {
265 Self::Postgres => write!(f, "postgres"),
266 Self::PostgresCdc => write!(f, "postgres-cdc"),
267 Self::MysqlCdc => write!(f, "mysql-cdc"),
268 Self::Redis => write!(f, "redis"),
269 Self::S3Parquet => write!(f, "s3-parquet"),
270 Self::DeltaLake => write!(f, "delta-lake"),
271 Self::Static => write!(f, "static"),
272 Self::Custom(s) => write!(f, "{s}"),
273 }
274 }
275}
276
277impl fmt::Display for crate::parser::lookup_table::LookupStrategy {
278 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279 match self {
280 Self::Replicated => write!(f, "replicated"),
281 Self::Partitioned => write!(f, "partitioned"),
282 Self::OnDemand => write!(f, "on-demand"),
283 }
284 }
285}
286
287impl fmt::Display for crate::parser::lookup_table::PushdownMode {
288 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289 match self {
290 Self::Auto => write!(f, "auto"),
291 Self::Enabled => write!(f, "enabled"),
292 Self::Disabled => write!(f, "disabled"),
293 }
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::parser::lookup_table::{
        ByteSize, ConnectorType, LookupStrategy, LookupTableProperties, PushdownMode,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_common::tree_node::TreeNode;
    use datafusion_optimizer::optimizer::OptimizerContext;

    /// Lookup-table description for the `customers` table used by the tests.
    fn test_lookup_info() -> LookupTableInfo {
        let properties = LookupTableProperties {
            connector: ConnectorType::PostgresCdc,
            connection: Some("postgresql://localhost/db".to_string()),
            strategy: LookupStrategy::Replicated,
            cache_memory: Some(ByteSize(512 * 1024 * 1024)),
            cache_disk: None,
            cache_ttl: None,
            pushdown_mode: PushdownMode::Auto,
        };
        let arrow_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        LookupTableInfo {
            name: "customers".to_string(),
            columns: vec![
                ("id".to_string(), "INT".to_string()),
                ("name".to_string(), "VARCHAR".to_string()),
            ],
            primary_key: vec!["id".to_string()],
            properties,
            arrow_schema,
            #[allow(clippy::disallowed_types)]
            raw_options: std::collections::HashMap::new(),
        }
    }

    /// Builds an empty record batch with the given fields.
    fn empty_batch(fields: Vec<Field>) -> arrow::array::RecordBatch {
        arrow::array::RecordBatch::new_empty(Arc::new(Schema::new(fields)))
    }

    /// Registers empty `orders` and `customers` tables on `ctx`.
    fn register_test_tables(ctx: &SessionContext) {
        let orders = empty_batch(vec![
            Field::new("order_id", DataType::Int64, false),
            Field::new("customer_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]);
        let customers = empty_batch(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]);

        ctx.register_batch("orders", orders).unwrap();
        ctx.register_batch("customers", customers).unwrap();
    }

    /// Rule that knows `customers` as its only lookup table.
    fn customers_rule() -> LookupJoinRewriteRule {
        LookupJoinRewriteRule::new(HashMap::from([(
            "customers".to_string(),
            test_lookup_info(),
        )]))
    }

    /// Plans `sql` without optimization and applies `rule` to every node,
    /// walking the plan from the root downwards.
    async fn rewrite_sql(
        ctx: &SessionContext,
        rule: &LookupJoinRewriteRule,
        sql: &str,
    ) -> Transformed<LogicalPlan> {
        let unoptimized = ctx.sql(sql).await.unwrap().into_unoptimized_plan();
        unoptimized
            .transform_down(|node| rule.rewrite(node, &OptimizerContext::new()))
            .unwrap()
    }

    #[tokio::test]
    async fn test_rewrite_join_on_lookup_table() {
        let ctx = create_session_context();
        register_test_tables(&ctx);
        let rule = customers_rule();

        let transformed = rewrite_sql(
            &ctx,
            &rule,
            "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.id",
        )
        .await;

        assert!(transformed.transformed);
        let has_lookup = format!("{:?}", transformed.data).contains("LookupJoin");
        assert!(has_lookup, "Expected LookupJoin in plan");
    }

    #[tokio::test]
    async fn test_non_lookup_join_not_rewritten() {
        let ctx = create_session_context();
        ctx.register_batch(
            "a",
            empty_batch(vec![Field::new("id", DataType::Int64, false)]),
        )
        .unwrap();
        ctx.register_batch(
            "b",
            empty_batch(vec![Field::new("a_id", DataType::Int64, false)]),
        )
        .unwrap();

        // No lookup tables registered, so no join qualifies for rewriting.
        let rule = LookupJoinRewriteRule::new(HashMap::new());

        let transformed = rewrite_sql(&ctx, &rule, "SELECT * FROM a JOIN b ON a.id = b.a_id").await;

        assert!(!transformed.transformed);
    }

    #[tokio::test]
    async fn test_left_outer_produces_left_outer_type() {
        let ctx = create_session_context();
        register_test_tables(&ctx);
        let rule = customers_rule();

        let transformed = rewrite_sql(
            &ctx,
            &rule,
            "SELECT o.order_id, c.name FROM orders o LEFT JOIN customers c ON o.customer_id = c.id",
        )
        .await;

        assert!(transformed.transformed);
        let debug_str = format!("{:?}", transformed.data);
        assert!(
            debug_str.contains("LeftOuter"),
            "Expected LeftOuter join type, got: {debug_str}"
        );
    }

    #[test]
    fn test_fmt_display_connector_type() {
        assert_eq!(ConnectorType::PostgresCdc.to_string(), "postgres-cdc");
        assert_eq!(ConnectorType::Redis.to_string(), "redis");
        assert_eq!(
            ConnectorType::Custom("my-conn".into()).to_string(),
            "my-conn"
        );
    }

    #[test]
    fn test_fmt_display_strategy() {
        assert_eq!(LookupStrategy::Replicated.to_string(), "replicated");
        assert_eq!(LookupStrategy::OnDemand.to_string(), "on-demand");
    }

    #[test]
    fn test_fmt_display_pushdown_mode() {
        assert_eq!(PushdownMode::Auto.to_string(), "auto");
        assert_eq!(PushdownMode::Disabled.to_string(), "disabled");
    }
}