1use std::time::Duration;
10
11use sqlparser::ast::{
12 BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, JoinConstraint,
13 JoinOperator, Select, TableFactor, TableVersion,
14};
15
16use super::window_rewriter::WindowRewriter;
17use super::ParseError;
18
/// Join flavour recognised for a single join step.
///
/// Produced from the SQL join operator by `map_join_operator`; operators
/// without a dedicated variant map to [`JoinType::Inner`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `[INNER] JOIN` / `STRAIGHT_JOIN` (and the fallback for other operators).
    Inner,
    /// `LEFT [OUTER] JOIN`.
    Left,
    /// `RIGHT [OUTER] JOIN`.
    Right,
    /// `FULL [OUTER] JOIN`.
    Full,
    /// `LEFT SEMI JOIN` or plain `SEMI JOIN`.
    LeftSemi,
    /// `LEFT ANTI JOIN` or plain `ANTI JOIN`.
    LeftAnti,
    /// `RIGHT SEMI JOIN`.
    RightSemi,
    /// `RIGHT ANTI JOIN`.
    RightAnti,
    /// `ASOF JOIN ... MATCH_CONDITION(...)`.
    AsOf,
}
41
/// Direction of an `ASOF JOIN`, taken from the shape of its `MATCH_CONDITION`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// `MATCH_CONDITION(l >= r)`: right timestamp at or before the left one.
    Backward,
    /// `MATCH_CONDITION(l <= r)`: right timestamp at or after the left one.
    Forward,
    /// `MATCH_CONDITION(NEAREST(l, r))`.
    Nearest,
}

impl std::fmt::Display for AsofSqlDirection {
    /// Renders the direction as its upper-case keyword.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Backward => "BACKWARD",
            Self::Forward => "FORWARD",
            Self::Nearest => "NEAREST",
        };
        f.write_str(keyword)
    }
}
62
/// Time-column references captured from a `x BETWEEN low AND high` join
/// predicate, before it is known which join input each one belongs to.
#[derive(Debug, Clone)]
struct RawTimeCols {
    /// Qualifier (table name or alias) written on the `BETWEEN` operand, if any.
    expr_qualifier: Option<String>,
    /// Column name of the `BETWEEN` operand.
    expr_col: String,
    /// Qualifier written on the lower-bound expression, if any.
    low_qualifier: Option<String>,
    /// Column name of the lower-bound expression.
    low_col: String,
}

/// Assigns the two columns of `raw` to the left and right join inputs.
///
/// A qualifier refers to a side when it equals that side's table name or alias.
/// Resolution order:
/// 1. the qualifiers name opposite sides: use them as written;
/// 2. exactly one operand names a side, and the other is unqualified or names
///    neither table: place the named operand on its side and the other one
///    opposite;
/// 3. otherwise use the `p.ts BETWEEN o.ts AND ...` convention, i.e. the lower
///    bound is the left time column.
///
/// Returns `(left_time_column, right_time_column)`.
fn resolve_time_cols(
    raw: &RawTimeCols,
    left_table: &str,
    right_table: &str,
    left_alias: Option<&str>,
    right_alias: Option<&str>,
) -> (String, String) {
    let refers_to = |q: &Option<String>, table: &str, alias: Option<&str>| -> bool {
        q.as_deref()
            .is_some_and(|t| t == table || alias.is_some_and(|a| a == t))
    };

    let expr_left = refers_to(&raw.expr_qualifier, left_table, left_alias);
    let expr_right = refers_to(&raw.expr_qualifier, right_table, right_alias);
    let low_left = refers_to(&raw.low_qualifier, left_table, left_alias);
    let low_right = refers_to(&raw.low_qualifier, right_table, right_alias);

    let expr_unresolved = !expr_left && !expr_right;
    let low_unresolved = !low_left && !low_right;

    let expr_is_left = if expr_right && low_left {
        false
    } else if expr_left && low_right {
        true
    } else {
        // Only one operand names a side unambiguously. Previously this fell
        // through to the convention below, which mis-assigned e.g.
        // `o.ts BETWEEN pay_ts AND ...` (operand is the left side).
        (expr_left && !expr_right && low_unresolved)
            || (low_right && !low_left && expr_unresolved)
    };

    if expr_is_left {
        (raw.expr_col.clone(), raw.low_col.clone())
    } else {
        (raw.low_col.clone(), raw.expr_col.clone())
    }
}
98
99#[derive(Debug, Clone)]
101pub struct JoinAnalysis {
102 pub join_type: JoinType,
104 pub left_table: String,
106 pub right_table: String,
108 pub left_key_column: String,
110 pub right_key_column: String,
112 pub time_bound: Option<Duration>,
114 pub is_lookup_join: bool,
116 pub left_alias: Option<String>,
118 pub right_alias: Option<String>,
120 pub is_asof_join: bool,
122 pub asof_direction: Option<AsofSqlDirection>,
124 pub left_time_column: Option<String>,
126 pub right_time_column: Option<String>,
128 pub asof_tolerance: Option<Duration>,
130 pub is_temporal_join: bool,
132 pub temporal_version_column: Option<String>,
134 pub additional_key_columns: Vec<(String, String)>,
136}
137
138impl JoinAnalysis {
139 #[must_use]
141 pub fn stream_stream(
142 left_table: String,
143 right_table: String,
144 left_key: String,
145 right_key: String,
146 time_bound: Duration,
147 join_type: JoinType,
148 ) -> Self {
149 Self {
150 join_type,
151 left_table,
152 right_table,
153 left_key_column: left_key,
154 right_key_column: right_key,
155 time_bound: Some(time_bound),
156 is_lookup_join: false,
157 left_alias: None,
158 right_alias: None,
159 is_asof_join: false,
160 asof_direction: None,
161 left_time_column: None,
162 right_time_column: None,
163 asof_tolerance: None,
164 is_temporal_join: false,
165 temporal_version_column: None,
166 additional_key_columns: vec![],
167 }
168 }
169
170 #[must_use]
172 pub fn lookup(
173 left_table: String,
174 right_table: String,
175 left_key: String,
176 right_key: String,
177 join_type: JoinType,
178 ) -> Self {
179 Self {
180 join_type,
181 left_table,
182 right_table,
183 left_key_column: left_key,
184 right_key_column: right_key,
185 time_bound: None,
186 is_lookup_join: true,
187 left_alias: None,
188 right_alias: None,
189 is_asof_join: false,
190 asof_direction: None,
191 left_time_column: None,
192 right_time_column: None,
193 asof_tolerance: None,
194 is_temporal_join: false,
195 temporal_version_column: None,
196 additional_key_columns: vec![],
197 }
198 }
199
200 #[must_use]
202 #[allow(clippy::too_many_arguments)]
203 pub fn asof(
204 left_table: String,
205 right_table: String,
206 left_key: String,
207 right_key: String,
208 direction: AsofSqlDirection,
209 left_time_col: String,
210 right_time_col: String,
211 tolerance: Option<Duration>,
212 ) -> Self {
213 Self {
214 join_type: JoinType::AsOf,
215 left_table,
216 right_table,
217 left_key_column: left_key,
218 right_key_column: right_key,
219 time_bound: None,
220 is_lookup_join: false,
221 left_alias: None,
222 right_alias: None,
223 is_asof_join: true,
224 asof_direction: Some(direction),
225 left_time_column: Some(left_time_col),
226 right_time_column: Some(right_time_col),
227 asof_tolerance: tolerance,
228 is_temporal_join: false,
229 temporal_version_column: None,
230 additional_key_columns: vec![],
231 }
232 }
233
234 #[must_use]
236 pub fn temporal(
237 left_table: String,
238 right_table: String,
239 left_key: String,
240 right_key: String,
241 version_column: String,
242 join_type: JoinType,
243 ) -> Self {
244 Self {
245 join_type,
246 left_table,
247 right_table,
248 left_key_column: left_key,
249 right_key_column: right_key,
250 time_bound: None,
251 is_lookup_join: false,
252 left_alias: None,
253 right_alias: None,
254 is_asof_join: false,
255 asof_direction: None,
256 left_time_column: None,
257 right_time_column: None,
258 asof_tolerance: None,
259 is_temporal_join: true,
260 temporal_version_column: Some(version_column),
261 additional_key_columns: vec![],
262 }
263 }
264
265 #[must_use]
268 pub fn is_bounded(&self) -> bool {
269 self.time_bound.is_some() || self.is_asof_join || self.is_temporal_join
270 }
271}
272
273pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
281 let from = &select.from;
282 if from.is_empty() {
283 return Ok(None);
284 }
285
286 let first_table = &from[0];
287 if first_table.joins.is_empty() {
288 return Ok(None);
289 }
290
291 let left_table = extract_table_name(&first_table.relation)?;
293 let left_alias = extract_table_alias(&first_table.relation);
294
295 let join = &first_table.joins[0];
297 let right_table = extract_table_name(&join.relation)?;
298 let right_alias = extract_table_alias(&join.relation);
299
300 let join_type = map_join_operator(&join.join_operator);
301
302 if let JoinOperator::AsOf {
304 match_condition,
305 constraint,
306 } = &join.join_operator
307 {
308 let (direction, left_time, right_time, tolerance) =
309 analyze_asof_match_condition(match_condition)?;
310
311 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
313
314 let mut analysis = JoinAnalysis::asof(
315 left_table,
316 right_table,
317 left_key,
318 right_key,
319 direction,
320 left_time,
321 right_time,
322 tolerance,
323 );
324 analysis.left_alias = left_alias;
325 analysis.right_alias = right_alias;
326 return Ok(Some(analysis));
327 }
328
329 if let Some(version_col) = extract_temporal_version(&join.relation) {
331 let (left_key, right_key, additional, _, _) = analyze_join_constraint(&join.join_operator)?;
332 let mut analysis = JoinAnalysis::temporal(
333 left_table,
334 right_table,
335 left_key,
336 right_key,
337 version_col,
338 join_type,
339 );
340 analysis.left_alias = left_alias;
341 analysis.right_alias = right_alias;
342 analysis.additional_key_columns = additional;
343 return Ok(Some(analysis));
344 }
345
346 let (left_key, right_key, additional, time_bound, time_cols) =
348 analyze_join_constraint(&join.join_operator)?;
349
350 let mut analysis = if let Some(tb) = time_bound {
351 JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
352 } else {
353 JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
354 };
355
356 analysis.left_alias.clone_from(&left_alias);
357 analysis.right_alias.clone_from(&right_alias);
358 analysis.additional_key_columns = additional;
359
360 if let Some(ref raw) = time_cols {
361 let (lt, rt) = resolve_time_cols(
362 raw,
363 &analysis.left_table,
364 &analysis.right_table,
365 left_alias.as_deref(),
366 right_alias.as_deref(),
367 );
368 analysis.left_time_column = Some(lt);
369 analysis.right_time_column = Some(rt);
370 }
371
372 Ok(Some(analysis))
373}
374
375fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
377 match factor {
378 TableFactor::Table { name, .. } => Ok(name.to_string()),
379 TableFactor::Derived { alias, .. } => {
380 if let Some(alias) = alias {
381 Ok(alias.name.value.clone())
382 } else {
383 Err(ParseError::StreamingError(
384 "Derived table without alias not supported".to_string(),
385 ))
386 }
387 }
388 _ => Err(ParseError::StreamingError(
389 "Unsupported table factor type".to_string(),
390 )),
391 }
392}
393
394fn extract_temporal_version(factor: &TableFactor) -> Option<String> {
399 if let TableFactor::Table {
400 version: Some(TableVersion::ForSystemTimeAsOf(expr)),
401 ..
402 } = factor
403 {
404 Some(extract_column_name_from_expr(expr))
405 } else {
406 None
407 }
408}
409
410fn extract_column_name_from_expr(expr: &Expr) -> String {
414 match expr {
415 Expr::Identifier(ident) => ident.value.clone(),
416 Expr::CompoundIdentifier(parts) => parts
417 .last()
418 .map_or_else(|| expr.to_string(), |p| p.value.clone()),
419 _ => expr.to_string(),
420 }
421}
422
423fn extract_table_alias(factor: &TableFactor) -> Option<String> {
425 match factor {
426 TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
427 TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
428 _ => None,
429 }
430}
431
432fn map_join_operator(op: &JoinOperator) -> JoinType {
434 match op {
435 JoinOperator::Inner(_) | JoinOperator::Join(_) | JoinOperator::StraightJoin(_) => {
436 JoinType::Inner
437 }
438 JoinOperator::Left(_) | JoinOperator::LeftOuter(_) => JoinType::Left,
439 JoinOperator::LeftSemi(_) | JoinOperator::Semi(_) => JoinType::LeftSemi,
440 JoinOperator::LeftAnti(_) | JoinOperator::Anti(_) => JoinType::LeftAnti,
441 JoinOperator::AsOf { .. } => JoinType::AsOf,
442 JoinOperator::Right(_) | JoinOperator::RightOuter(_) => JoinType::Right,
443 JoinOperator::RightSemi(_) => JoinType::RightSemi,
444 JoinOperator::RightAnti(_) => JoinType::RightAnti,
445 JoinOperator::FullOuter(_) => JoinType::Full,
446 _ => JoinType::Inner,
448 }
449}
450
451#[allow(clippy::type_complexity)]
454fn analyze_join_constraint(
455 op: &JoinOperator,
456) -> Result<
457 (
458 String,
459 String,
460 Vec<(String, String)>,
461 Option<Duration>,
462 Option<RawTimeCols>,
463 ),
464 ParseError,
465> {
466 let constraint = get_join_constraint(op)?;
467
468 match constraint {
469 JoinConstraint::On(expr) => {
470 let (key_pairs, time_bound, time_cols) = analyze_on_expression(expr)?;
471 if key_pairs.is_empty() {
472 return Ok((String::new(), String::new(), vec![], time_bound, time_cols));
473 }
474 let (first_left, first_right) = key_pairs[0].clone();
475 let additional = key_pairs[1..].to_vec();
476 Ok((first_left, first_right, additional, time_bound, time_cols))
477 }
478 JoinConstraint::Using(cols) => {
479 if cols.is_empty() {
480 return Err(ParseError::StreamingError(
481 "USING clause requires at least one column".to_string(),
482 ));
483 }
484 let first_col = cols[0].to_string();
486 let additional: Vec<(String, String)> = cols[1..]
488 .iter()
489 .map(|c| {
490 let col = c.to_string();
491 (col.clone(), col)
492 })
493 .collect();
494 Ok((first_col.clone(), first_col, additional, None, None))
495 }
496 JoinConstraint::Natural => Err(ParseError::StreamingError(
497 "NATURAL JOIN not supported for streaming".to_string(),
498 )),
499 JoinConstraint::None => Err(ParseError::StreamingError(
500 "JOIN without condition not supported for streaming".to_string(),
501 )),
502 }
503}
504
505fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
507 match op {
508 JoinOperator::Inner(constraint)
509 | JoinOperator::Join(constraint)
510 | JoinOperator::Left(constraint)
511 | JoinOperator::LeftOuter(constraint)
512 | JoinOperator::Right(constraint)
513 | JoinOperator::RightOuter(constraint)
514 | JoinOperator::FullOuter(constraint)
515 | JoinOperator::LeftSemi(constraint)
516 | JoinOperator::RightSemi(constraint)
517 | JoinOperator::LeftAnti(constraint)
518 | JoinOperator::RightAnti(constraint)
519 | JoinOperator::Semi(constraint)
520 | JoinOperator::Anti(constraint)
521 | JoinOperator::StraightJoin(constraint)
522 | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
523 JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
524 ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
525 ),
526 }
527}
528
529#[allow(clippy::type_complexity)]
532fn analyze_on_expression(
533 expr: &Expr,
534) -> Result<(Vec<(String, String)>, Option<Duration>, Option<RawTimeCols>), ParseError> {
535 match expr {
537 Expr::BinaryOp {
538 left,
539 op: BinaryOperator::And,
540 right,
541 } => {
542 let left_result = analyze_on_expression(left);
544 let right_result = analyze_on_expression(right);
545
546 match (left_result, right_result) {
548 (Ok((mut lk, lt, ltc)), Ok((rk, rt, rtc))) => {
549 lk.extend(rk);
550 Ok((lk, lt.or(rt), ltc.or(rtc)))
551 }
552 (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
553 (Err(e), Err(_)) => Err(e),
554 }
555 }
556 Expr::BinaryOp {
558 left,
559 op: BinaryOperator::Eq,
560 right,
561 } => {
562 let left_col = extract_column_ref(left);
563 let right_col = extract_column_ref(right);
564
565 match (left_col, right_col) {
566 (Some(l), Some(r)) => Ok((vec![(l, r)], None, None)),
567 _ => Err(ParseError::StreamingError(
568 "Cannot extract column references from equality condition".to_string(),
569 )),
570 }
571 }
572 Expr::Between {
574 expr: between_expr,
575 low,
576 high,
577 ..
578 } => {
579 let time_bound = extract_time_bound_from_expr(high).ok();
581 let between_col = extract_qualified_column_ref(between_expr);
582 let low_col = extract_qualified_column_ref(low);
583 let time_cols = if let (Some((bt, bc)), Some((lt, lc))) = (between_col, low_col) {
584 Some(RawTimeCols {
585 expr_qualifier: bt,
586 expr_col: bc,
587 low_qualifier: lt,
588 low_col: lc,
589 })
590 } else {
591 if time_bound.is_some() {
592 tracing::warn!(
593 "BETWEEN clause has time bound but time column references \
594 could not be extracted (expressions must be simple column refs)"
595 );
596 }
597 None
598 };
599 Ok((vec![], time_bound, time_cols))
600 }
601 Expr::BinaryOp {
603 left: _,
604 op:
605 BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
606 right,
607 } => {
608 let time_bound = extract_time_bound_from_expr(right).ok();
610 Ok((vec![], time_bound, None))
611 }
612 _ => Err(ParseError::StreamingError(format!(
613 "Unsupported join condition expression: {expr:?}"
614 ))),
615 }
616}
617
618fn extract_column_from_func_arg(arg: &FunctionArg) -> Option<String> {
620 let (FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))
621 | FunctionArg::Named {
622 arg: FunctionArgExpr::Expr(expr),
623 ..
624 }
625 | FunctionArg::ExprNamed {
626 arg: FunctionArgExpr::Expr(expr),
627 ..
628 }) = arg
629 else {
630 return None;
631 };
632 extract_column_ref(expr)
633}
634
635fn extract_column_ref(expr: &Expr) -> Option<String> {
637 match expr {
638 Expr::Identifier(ident) => Some(ident.value.clone()),
639 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
640 _ => None,
641 }
642}
643
644fn extract_qualified_column_ref(expr: &Expr) -> Option<(Option<String>, String)> {
645 match expr {
646 Expr::Identifier(ident) => Some((None, ident.value.clone())),
647 Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
648 Some((Some(parts[0].value.clone()), parts[1].value.clone()))
649 }
650 Expr::CompoundIdentifier(parts) => parts.last().map(|p| (None, p.value.clone())),
651 _ => None,
652 }
653}
654
655fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
657 match expr {
658 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
660 Expr::BinaryOp {
662 left: _,
663 op: BinaryOperator::Plus | BinaryOperator::Minus,
664 right,
665 } => extract_time_bound_from_expr(right),
666 Expr::Nested(inner) => extract_time_bound_from_expr(inner),
668 _ => Err(ParseError::StreamingError(format!(
669 "Cannot extract time bound from: {expr:?}"
670 ))),
671 }
672}
673
674fn analyze_asof_match_condition(
678 expr: &Expr,
679) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
680 if let Expr::BinaryOp {
681 left,
682 op: BinaryOperator::And,
683 right,
684 } = expr
685 {
686 let dir_result = analyze_asof_direction(left);
688 let tol_result = extract_asof_tolerance(right);
689
690 match (dir_result, tol_result) {
691 (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
692 (Ok((dir, lt, rt)), Err(_)) => {
693 let dir2 = analyze_asof_direction(right);
695 let tol2 = extract_asof_tolerance(left);
696 match (dir2, tol2) {
697 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
698 _ => Ok((dir, lt, rt, None)),
699 }
700 }
701 (Err(_), _) => {
702 let dir2 = analyze_asof_direction(right);
704 let tol2 = extract_asof_tolerance(left);
705 match (dir2, tol2) {
706 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
707 (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
708 _ => Err(ParseError::StreamingError(
709 "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
710 )),
711 }
712 }
713 }
714 } else {
715 let (dir, lt, rt) = analyze_asof_direction(expr)?;
716 Ok((dir, lt, rt, None))
717 }
718}
719
720fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
722 match expr {
723 Expr::BinaryOp {
724 left,
725 op: BinaryOperator::GtEq,
726 right,
727 } => {
728 let left_col = extract_column_ref(left).ok_or_else(|| {
729 ParseError::StreamingError(
730 "Cannot extract left time column from MATCH_CONDITION".to_string(),
731 )
732 })?;
733 let right_col = extract_column_ref(right).ok_or_else(|| {
734 ParseError::StreamingError(
735 "Cannot extract right time column from MATCH_CONDITION".to_string(),
736 )
737 })?;
738 Ok((AsofSqlDirection::Backward, left_col, right_col))
739 }
740 Expr::BinaryOp {
741 left,
742 op: BinaryOperator::LtEq,
743 right,
744 } => {
745 let left_col = extract_column_ref(left).ok_or_else(|| {
746 ParseError::StreamingError(
747 "Cannot extract left time column from MATCH_CONDITION".to_string(),
748 )
749 })?;
750 let right_col = extract_column_ref(right).ok_or_else(|| {
751 ParseError::StreamingError(
752 "Cannot extract right time column from MATCH_CONDITION".to_string(),
753 )
754 })?;
755 Ok((AsofSqlDirection::Forward, left_col, right_col))
756 }
757 Expr::Function(func) => {
759 let name = func.name.to_string().to_uppercase();
760 if name != "NEAREST" {
761 return Err(ParseError::StreamingError(format!(
762 "Unknown ASOF MATCH_CONDITION function: {name}"
763 )));
764 }
765 let args = match &func.args {
766 FunctionArguments::List(arg_list) => &arg_list.args,
767 _ => {
768 return Err(ParseError::StreamingError(
769 "NEAREST() requires exactly 2 column arguments".to_string(),
770 ))
771 }
772 };
773 if args.len() != 2 {
774 return Err(ParseError::StreamingError(format!(
775 "NEAREST() requires exactly 2 arguments, got {}",
776 args.len()
777 )));
778 }
779 let left_col = extract_column_from_func_arg(&args[0]).ok_or_else(|| {
780 ParseError::StreamingError(
781 "Cannot extract left time column from NEAREST()".to_string(),
782 )
783 })?;
784 let right_col = extract_column_from_func_arg(&args[1]).ok_or_else(|| {
785 ParseError::StreamingError(
786 "Cannot extract right time column from NEAREST()".to_string(),
787 )
788 })?;
789 Ok((AsofSqlDirection::Nearest, left_col, right_col))
790 }
791 _ => Err(ParseError::StreamingError(
792 "ASOF MATCH_CONDITION must be >= or <= comparison, or NEAREST()".to_string(),
793 )),
794 }
795}
796
797fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
801 match expr {
802 Expr::BinaryOp {
803 left: _,
804 op: BinaryOperator::LtEq,
805 right,
806 } => {
807 match right.as_ref() {
809 Expr::Value(v) => {
810 if let sqlparser::ast::Value::Number(n, _) = &v.value {
811 let ms: u64 = n.parse().map_err(|_| {
812 ParseError::StreamingError(format!(
813 "Cannot parse tolerance as number: {n}"
814 ))
815 })?;
816 Ok(Duration::from_millis(ms))
817 } else {
818 Err(ParseError::StreamingError(
819 "ASOF tolerance must be a number or INTERVAL".to_string(),
820 ))
821 }
822 }
823 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
824 _ => Err(ParseError::StreamingError(
825 "ASOF tolerance must be a number or INTERVAL".to_string(),
826 )),
827 }
828 }
829 _ => Err(ParseError::StreamingError(
830 "ASOF tolerance expression must be <= comparison".to_string(),
831 )),
832 }
833}
834
835fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
837 match constraint {
838 JoinConstraint::On(expr) => extract_equality_columns(expr),
839 JoinConstraint::Using(cols) => {
840 if cols.is_empty() {
841 return Err(ParseError::StreamingError(
842 "USING clause requires at least one column".to_string(),
843 ));
844 }
845 let col = cols[0].to_string();
846 Ok((col.clone(), col))
847 }
848 _ => Err(ParseError::StreamingError(
849 "ASOF JOIN requires ON or USING constraint".to_string(),
850 )),
851 }
852}
853
854fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
856 match expr {
857 Expr::BinaryOp {
858 left,
859 op: BinaryOperator::Eq,
860 right,
861 } => {
862 let left_col = extract_column_ref(left).ok_or_else(|| {
863 ParseError::StreamingError("Cannot extract left key column".to_string())
864 })?;
865 let right_col = extract_column_ref(right).ok_or_else(|| {
866 ParseError::StreamingError("Cannot extract right key column".to_string())
867 })?;
868 Ok((left_col, right_col))
869 }
870 Expr::BinaryOp {
872 left,
873 op: BinaryOperator::And,
874 right,
875 } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
876 _ => Err(ParseError::StreamingError(
877 "ASOF JOIN ON clause must contain an equality condition".to_string(),
878 )),
879 }
880}
881
882#[must_use]
884pub fn has_join(select: &Select) -> bool {
885 !select.from.is_empty() && !select.from[0].joins.is_empty()
886}
887
888#[must_use]
890pub fn count_joins(select: &Select) -> usize {
891 select
892 .from
893 .iter()
894 .map(|table_with_joins| table_with_joins.joins.len())
895 .sum()
896}
897
898#[derive(Debug, Clone)]
903pub struct MultiJoinAnalysis {
904 pub joins: Vec<JoinAnalysis>,
906 pub tables: Vec<String>,
908}
909
910impl MultiJoinAnalysis {
911 #[must_use]
913 pub fn len(&self) -> usize {
914 self.joins.len()
915 }
916
917 #[must_use]
919 pub fn is_empty(&self) -> bool {
920 self.joins.is_empty()
921 }
922
923 #[must_use]
925 pub fn is_single(&self) -> bool {
926 self.joins.len() == 1
927 }
928
929 #[must_use]
931 pub fn first(&self) -> Option<&JoinAnalysis> {
932 self.joins.first()
933 }
934}
935
936pub fn analyze_joins(select: &Select) -> Result<Option<MultiJoinAnalysis>, ParseError> {
947 let from = &select.from;
948 if from.is_empty() {
949 return Ok(None);
950 }
951
952 let first_table = &from[0];
953 if first_table.joins.is_empty() {
954 return Ok(None);
955 }
956
957 let base_table = extract_table_name(&first_table.relation)?;
959 let base_alias = extract_table_alias(&first_table.relation);
960
961 let mut join_steps = Vec::with_capacity(first_table.joins.len());
962 let mut tables = vec![base_table.clone()];
963
964 let mut prev_left_table = base_table;
966 let mut prev_left_alias = base_alias;
967
968 for join in &first_table.joins {
969 let right_table = extract_table_name(&join.relation)?;
970 let right_alias = extract_table_alias(&join.relation);
971 tables.push(right_table.clone());
972
973 let join_type = map_join_operator(&join.join_operator);
974
975 if let JoinOperator::AsOf {
977 match_condition,
978 constraint,
979 } = &join.join_operator
980 {
981 let (direction, left_time, right_time, tolerance) =
982 analyze_asof_match_condition(match_condition)?;
983 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
984
985 let mut analysis = JoinAnalysis::asof(
986 prev_left_table.clone(),
987 right_table.clone(),
988 left_key,
989 right_key,
990 direction,
991 left_time,
992 right_time,
993 tolerance,
994 );
995 analysis.left_alias.clone_from(&prev_left_alias);
996 analysis.right_alias = right_alias;
997 join_steps.push(analysis);
998 } else if let Some(version_col) = extract_temporal_version(&join.relation) {
999 let (left_key, right_key, additional, _, _) =
1001 analyze_join_constraint(&join.join_operator)?;
1002
1003 let mut analysis = JoinAnalysis::temporal(
1004 prev_left_table.clone(),
1005 right_table.clone(),
1006 left_key,
1007 right_key,
1008 version_col,
1009 join_type,
1010 );
1011 analysis.left_alias.clone_from(&prev_left_alias);
1012 analysis.right_alias = right_alias;
1013 analysis.additional_key_columns = additional;
1014 join_steps.push(analysis);
1015 } else {
1016 let (left_key, right_key, additional, time_bound, time_cols) =
1018 analyze_join_constraint(&join.join_operator)?;
1019
1020 let mut analysis = if let Some(tb) = time_bound {
1021 JoinAnalysis::stream_stream(
1022 prev_left_table.clone(),
1023 right_table.clone(),
1024 left_key,
1025 right_key,
1026 tb,
1027 join_type,
1028 )
1029 } else {
1030 JoinAnalysis::lookup(
1031 prev_left_table.clone(),
1032 right_table.clone(),
1033 left_key,
1034 right_key,
1035 join_type,
1036 )
1037 };
1038 analysis.left_alias.clone_from(&prev_left_alias);
1039 analysis.right_alias.clone_from(&right_alias);
1040 analysis.additional_key_columns = additional;
1041
1042 if let Some(ref raw) = time_cols {
1043 let (lt, rt) = resolve_time_cols(
1044 raw,
1045 &analysis.left_table,
1046 &analysis.right_table,
1047 prev_left_alias.as_deref(),
1048 right_alias.as_deref(),
1049 );
1050 analysis.left_time_column = Some(lt);
1051 analysis.right_time_column = Some(rt);
1052 }
1053 join_steps.push(analysis);
1054 }
1055
1056 prev_left_table = right_table;
1058 prev_left_alias = extract_table_alias(&join.relation);
1059 }
1060
1061 Ok(Some(MultiJoinAnalysis {
1062 joins: join_steps,
1063 tables,
1064 }))
1065}
1066
1067#[cfg(test)]
1068mod tests {
1069 use super::*;
1070 use sqlparser::ast::{SetExpr, Statement};
1071 use sqlparser::dialect::GenericDialect;
1072 use sqlparser::parser::Parser;
1073
1074 fn parse_select(sql: &str) -> Select {
1075 let dialect = GenericDialect {};
1076 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1077 if let Statement::Query(query) = &statements[0] {
1078 if let SetExpr::Select(select) = query.body.as_ref() {
1079 return *select.clone();
1080 }
1081 }
1082 panic!("Expected SELECT query");
1083 }
1084
1085 #[test]
1086 fn test_analyze_inner_join() {
1087 let sql = "SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id";
1088 let select = parse_select(sql);
1089
1090 let analysis = analyze_join(&select).unwrap().unwrap();
1091
1092 assert_eq!(analysis.join_type, JoinType::Inner);
1093 assert_eq!(analysis.left_table, "orders");
1094 assert_eq!(analysis.right_table, "payments");
1095 assert_eq!(analysis.left_key_column, "order_id");
1096 assert_eq!(analysis.right_key_column, "order_id");
1097 assert!(analysis.is_lookup_join); }
1099
1100 #[test]
1101 fn test_analyze_left_join() {
1102 let sql = "SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id";
1103 let select = parse_select(sql);
1104
1105 let analysis = analyze_join(&select).unwrap().unwrap();
1106
1107 assert_eq!(analysis.join_type, JoinType::Left);
1108 assert_eq!(analysis.left_key_column, "customer_id");
1109 assert_eq!(analysis.right_key_column, "id");
1110 }
1111
1112 #[test]
1113 fn test_analyze_join_using() {
1114 let sql = "SELECT * FROM orders o JOIN payments p USING (order_id)";
1115 let select = parse_select(sql);
1116
1117 let analysis = analyze_join(&select).unwrap().unwrap();
1118
1119 assert_eq!(analysis.left_key_column, "order_id");
1120 assert_eq!(analysis.right_key_column, "order_id");
1121 }
1122
1123 #[test]
1124 fn test_analyze_stream_stream_join_with_time_bound() {
1125 let sql = "SELECT * FROM orders o
1126 JOIN payments p ON o.order_id = p.order_id
1127 AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
1128 let select = parse_select(sql);
1129
1130 let analysis = analyze_join(&select).unwrap().unwrap();
1131
1132 assert!(!analysis.is_lookup_join);
1133 assert!(analysis.time_bound.is_some());
1134 assert_eq!(analysis.time_bound.unwrap(), Duration::from_secs(3600));
1135 }
1136
1137 #[test]
1138 fn test_no_join() {
1139 let sql = "SELECT * FROM orders";
1140 let select = parse_select(sql);
1141
1142 let analysis = analyze_join(&select).unwrap();
1143 assert!(analysis.is_none());
1144 }
1145
1146 #[test]
1147 fn test_has_join() {
1148 let sql_with_join = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
1149 let sql_without_join = "SELECT * FROM orders";
1150
1151 let select_with = parse_select(sql_with_join);
1152 let select_without = parse_select(sql_without_join);
1153
1154 assert!(has_join(&select_with));
1155 assert!(!has_join(&select_without));
1156 }
1157
1158 #[test]
1159 fn test_count_joins() {
1160 let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
1161 let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
1162 let sql_zero = "SELECT * FROM a";
1163
1164 assert_eq!(count_joins(&parse_select(sql_one)), 1);
1165 assert_eq!(count_joins(&parse_select(sql_two)), 2);
1166 assert_eq!(count_joins(&parse_select(sql_zero)), 0);
1167 }
1168
1169 #[test]
1170 fn test_aliases() {
1171 let sql = "SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id";
1172 let select = parse_select(sql);
1173
1174 let analysis = analyze_join(&select).unwrap().unwrap();
1175
1176 assert_eq!(analysis.left_alias, Some("o".to_string()));
1177 assert_eq!(analysis.right_alias, Some("p".to_string()));
1178 }
1179
1180 fn parse_select_snowflake(sql: &str) -> Select {
1183 let dialect = sqlparser::dialect::SnowflakeDialect {};
1184 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1185 if let Statement::Query(query) = &statements[0] {
1186 if let SetExpr::Select(select) = query.body.as_ref() {
1187 return *select.clone();
1188 }
1189 }
1190 panic!("Expected SELECT query");
1191 }
1192
1193 fn parse_select_laminar(sql: &str) -> Select {
1194 let dialect = crate::parser::dialect::LaminarDialect::default();
1195 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1196 if let Statement::Query(query) = &statements[0] {
1197 if let SetExpr::Select(select) = query.body.as_ref() {
1198 return *select.clone();
1199 }
1200 }
1201 panic!("Expected SELECT query");
1202 }
1203
1204 #[test]
1205 fn test_asof_join_backward() {
1206 let sql = "SELECT * FROM trades t \
1207 ASOF JOIN quotes q \
1208 MATCH_CONDITION(t.ts >= q.ts) \
1209 ON t.symbol = q.symbol";
1210 let select = parse_select_snowflake(sql);
1211 let analysis = analyze_join(&select).unwrap().unwrap();
1212
1213 assert!(analysis.is_asof_join);
1214 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1215 assert_eq!(analysis.join_type, JoinType::AsOf);
1216 assert!(analysis.asof_tolerance.is_none());
1217 }
1218
1219 #[test]
1220 fn test_asof_join_forward() {
1221 let sql = "SELECT * FROM trades t \
1222 ASOF JOIN quotes q \
1223 MATCH_CONDITION(t.ts <= q.ts) \
1224 ON t.symbol = q.symbol";
1225 let select = parse_select_snowflake(sql);
1226 let analysis = analyze_join(&select).unwrap().unwrap();
1227
1228 assert!(analysis.is_asof_join);
1229 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
1230 }
1231
1232 #[test]
1233 fn test_asof_join_nearest() {
1234 let sql = "SELECT * FROM trades t \
1235 ASOF JOIN quotes q \
1236 MATCH_CONDITION(NEAREST(t.ts, q.ts)) \
1237 ON t.symbol = q.symbol";
1238 let select = parse_select_snowflake(sql);
1239 let analysis = analyze_join(&select).unwrap().unwrap();
1240
1241 assert!(analysis.is_asof_join);
1242 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Nearest));
1243 assert_eq!(analysis.join_type, JoinType::AsOf);
1244 assert!(analysis.asof_tolerance.is_none());
1245 }
1246
1247 #[test]
1248 fn test_asof_join_with_tolerance() {
1249 let sql = "SELECT * FROM trades t \
1250 ASOF JOIN quotes q \
1251 MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
1252 ON t.symbol = q.symbol";
1253 let select = parse_select_snowflake(sql);
1254 let analysis = analyze_join(&select).unwrap().unwrap();
1255
1256 assert!(analysis.is_asof_join);
1257 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1258 assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
1259 }
1260
1261 #[test]
1262 fn test_asof_join_with_interval_tolerance() {
1263 let sql = "SELECT * FROM trades t \
1264 ASOF JOIN quotes q \
1265 MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
1266 ON t.symbol = q.symbol";
1267 let select = parse_select_snowflake(sql);
1268 let analysis = analyze_join(&select).unwrap().unwrap();
1269
1270 assert!(analysis.is_asof_join);
1271 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1272 assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
1273 }
1274
1275 #[test]
1276 fn test_asof_join_type_mapping() {
1277 let sql = "SELECT * FROM trades t \
1278 ASOF JOIN quotes q \
1279 MATCH_CONDITION(t.ts >= q.ts) \
1280 ON t.symbol = q.symbol";
1281 let select = parse_select_snowflake(sql);
1282 let analysis = analyze_join(&select).unwrap().unwrap();
1283
1284 assert_eq!(analysis.join_type, JoinType::AsOf);
1285 assert!(!analysis.is_lookup_join);
1286 }
1287
1288 #[test]
1289 fn test_asof_join_extracts_time_columns() {
1290 let sql = "SELECT * FROM trades t \
1291 ASOF JOIN quotes q \
1292 MATCH_CONDITION(t.ts >= q.ts) \
1293 ON t.symbol = q.symbol";
1294 let select = parse_select_snowflake(sql);
1295 let analysis = analyze_join(&select).unwrap().unwrap();
1296
1297 assert_eq!(analysis.left_time_column, Some("ts".to_string()));
1298 assert_eq!(analysis.right_time_column, Some("ts".to_string()));
1299 }
1300
1301 #[test]
1302 fn test_asof_join_extracts_key_columns() {
1303 let sql = "SELECT * FROM trades t \
1304 ASOF JOIN quotes q \
1305 MATCH_CONDITION(t.ts >= q.ts) \
1306 ON t.symbol = q.symbol";
1307 let select = parse_select_snowflake(sql);
1308 let analysis = analyze_join(&select).unwrap().unwrap();
1309
1310 assert_eq!(analysis.left_key_column, "symbol");
1311 assert_eq!(analysis.right_key_column, "symbol");
1312 }
1313
1314 #[test]
1315 fn test_asof_join_aliases() {
1316 let sql = "SELECT * FROM trades AS t \
1317 ASOF JOIN quotes AS q \
1318 MATCH_CONDITION(t.ts >= q.ts) \
1319 ON t.symbol = q.symbol";
1320 let select = parse_select_snowflake(sql);
1321 let analysis = analyze_join(&select).unwrap().unwrap();
1322
1323 assert_eq!(analysis.left_alias, Some("t".to_string()));
1324 assert_eq!(analysis.right_alias, Some("q".to_string()));
1325 assert_eq!(analysis.left_table, "trades");
1326 assert_eq!(analysis.right_table, "quotes");
1327 }
1328
1329 #[test]
1332 fn test_multi_join_single_backward_compat() {
1333 let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
1334 let select = parse_select(sql);
1335 let multi = analyze_joins(&select).unwrap().unwrap();
1336
1337 assert!(multi.is_single());
1338 assert_eq!(multi.len(), 1);
1339 assert!(!multi.is_empty());
1340 let first = multi.first().unwrap();
1341 assert_eq!(first.left_table, "orders");
1342 assert_eq!(first.right_table, "payments");
1343 }
1344
1345 #[test]
1346 fn test_multi_join_two_way() {
1347 let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
1348 let select = parse_select(sql);
1349 let multi = analyze_joins(&select).unwrap().unwrap();
1350
1351 assert_eq!(multi.len(), 2);
1352 assert!(!multi.is_single());
1353
1354 assert_eq!(multi.joins[0].left_table, "a");
1355 assert_eq!(multi.joins[0].right_table, "b");
1356 assert_eq!(multi.joins[0].left_key_column, "id");
1357 assert_eq!(multi.joins[0].right_key_column, "a_id");
1358
1359 assert_eq!(multi.joins[1].left_table, "b");
1360 assert_eq!(multi.joins[1].right_table, "c");
1361 assert_eq!(multi.joins[1].left_key_column, "id");
1362 assert_eq!(multi.joins[1].right_key_column, "b_id");
1363 }
1364
1365 #[test]
1366 fn test_multi_join_three_way() {
1367 let sql = "SELECT * FROM a \
1368 JOIN b ON a.id = b.a_id \
1369 JOIN c ON b.id = c.b_id \
1370 JOIN d ON c.id = d.c_id";
1371 let select = parse_select(sql);
1372 let multi = analyze_joins(&select).unwrap().unwrap();
1373
1374 assert_eq!(multi.len(), 3);
1375 assert_eq!(multi.tables.len(), 4);
1376 assert_eq!(multi.tables, vec!["a", "b", "c", "d"]);
1377 }
1378
1379 #[test]
1380 fn test_multi_join_mixed_asof_and_lookup() {
1381 let sql = "SELECT * FROM trades t \
1383 ASOF JOIN quotes q \
1384 MATCH_CONDITION(t.ts >= q.ts) \
1385 ON t.symbol = q.symbol \
1386 JOIN products p ON q.product_id = p.id";
1387 let select = parse_select_snowflake(sql);
1388 let multi = analyze_joins(&select).unwrap().unwrap();
1389
1390 assert_eq!(multi.len(), 2);
1391 assert!(multi.joins[0].is_asof_join);
1392 assert!(multi.joins[1].is_lookup_join);
1393 }
1394
1395 #[test]
1396 fn test_multi_join_stream_stream_and_lookup() {
1397 let sql = "SELECT * FROM orders o \
1398 JOIN payments p ON o.id = p.order_id \
1399 AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR \
1400 JOIN customers c ON o.customer_id = c.id";
1401 let select = parse_select(sql);
1402 let multi = analyze_joins(&select).unwrap().unwrap();
1403
1404 assert_eq!(multi.len(), 2);
1405 assert!(!multi.joins[0].is_lookup_join); assert!(multi.joins[0].time_bound.is_some());
1407 assert!(multi.joins[1].is_lookup_join); }
1409
1410 #[test]
1411 fn test_multi_join_tables_list() {
1412 let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
1413 let select = parse_select(sql);
1414 let multi = analyze_joins(&select).unwrap().unwrap();
1415
1416 assert_eq!(multi.tables, vec!["a", "b", "c"]);
1417 }
1418
1419 #[test]
1420 fn test_multi_join_aliases() {
1421 let sql = "SELECT * FROM orders AS o \
1422 JOIN payments AS p ON o.id = p.order_id \
1423 JOIN refunds AS r ON p.id = r.payment_id";
1424 let select = parse_select(sql);
1425 let multi = analyze_joins(&select).unwrap().unwrap();
1426
1427 assert_eq!(multi.joins[0].left_alias, Some("o".to_string()));
1428 assert_eq!(multi.joins[0].right_alias, Some("p".to_string()));
1429 assert_eq!(multi.joins[1].left_alias, Some("p".to_string()));
1430 assert_eq!(multi.joins[1].right_alias, Some("r".to_string()));
1431 }
1432
1433 #[test]
1434 fn test_multi_join_no_join_returns_none() {
1435 let sql = "SELECT * FROM orders";
1436 let select = parse_select(sql);
1437 let multi = analyze_joins(&select).unwrap();
1438 assert!(multi.is_none());
1439 }
1440
1441 #[test]
1444 fn test_temporal_join_detected() {
1445 let sql = "SELECT o.*, p.price \
1446 FROM orders o \
1447 JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
1448 ON o.product_id = p.id";
1449 let select = parse_select_laminar(sql);
1450 let analysis = analyze_join(&select).unwrap().unwrap();
1451
1452 assert!(analysis.is_temporal_join);
1453 assert_eq!(
1454 analysis.temporal_version_column,
1455 Some("order_time".to_string())
1456 );
1457 assert_eq!(analysis.left_table, "orders");
1458 assert_eq!(analysis.right_table, "products");
1459 assert_eq!(analysis.left_key_column, "product_id");
1460 assert_eq!(analysis.right_key_column, "id");
1461 assert!(!analysis.is_lookup_join);
1462 assert!(!analysis.is_asof_join);
1463 }
1464
1465 #[test]
1466 fn test_temporal_join_via_analyze_joins() {
1467 let sql = "SELECT o.*, p.price \
1468 FROM orders o \
1469 JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
1470 ON o.product_id = p.id";
1471 let select = parse_select_laminar(sql);
1472 let multi = analyze_joins(&select).unwrap().unwrap();
1473
1474 assert_eq!(multi.len(), 1);
1475 let first = multi.first().unwrap();
1476 assert!(first.is_temporal_join);
1477 assert_eq!(
1478 first.temporal_version_column,
1479 Some("order_time".to_string())
1480 );
1481 }
1482
1483 #[test]
1484 fn test_non_temporal_join_not_flagged() {
1485 let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
1486 let select = parse_select(sql);
1487 let analysis = analyze_join(&select).unwrap().unwrap();
1488
1489 assert!(!analysis.is_temporal_join);
1490 assert!(analysis.temporal_version_column.is_none());
1491 }
1492
1493 #[test]
1494 fn test_unqualified_anti_maps_to_left_anti() {
1495 let sql = "SELECT * FROM orders o ANTI JOIN returns r ON o.id = r.order_id";
1496 let select = parse_select(sql);
1497 let analysis = analyze_join(&select).unwrap().unwrap();
1498 assert_eq!(analysis.join_type, JoinType::LeftAnti);
1499 }
1500
1501 #[test]
1502 fn test_unqualified_semi_maps_to_left_semi() {
1503 let sql = "SELECT * FROM orders o SEMI JOIN payments p ON o.id = p.order_id";
1504 let select = parse_select(sql);
1505 let analysis = analyze_join(&select).unwrap().unwrap();
1506 assert_eq!(analysis.join_type, JoinType::LeftSemi);
1507 }
1508
1509 #[test]
1510 fn test_composite_join_keys() {
1511 let sql = "SELECT * FROM orders o \
1512 JOIN shipments s \
1513 ON o.order_id = s.order_id AND o.region = s.region";
1514 let select = parse_select(sql);
1515 let analysis = analyze_join(&select).unwrap().unwrap();
1516
1517 assert_eq!(analysis.left_key_column, "order_id");
1519 assert_eq!(analysis.right_key_column, "order_id");
1520
1521 assert_eq!(
1523 analysis.additional_key_columns.len(),
1524 1,
1525 "Should have 1 additional key pair"
1526 );
1527 assert_eq!(analysis.additional_key_columns[0].0, "region");
1528 assert_eq!(analysis.additional_key_columns[0].1, "region");
1529 }
1530
1531 #[test]
1532 fn test_composite_using_clause() {
1533 let sql = "SELECT * FROM orders o JOIN shipments s USING (order_id, region)";
1534 let select = parse_select(sql);
1535 let analysis = analyze_join(&select).unwrap().unwrap();
1536
1537 assert_eq!(analysis.left_key_column, "order_id");
1539 assert_eq!(analysis.right_key_column, "order_id");
1540
1541 assert_eq!(
1543 analysis.additional_key_columns.len(),
1544 1,
1545 "USING(order_id, region) should have 1 additional key"
1546 );
1547 assert_eq!(analysis.additional_key_columns[0].0, "region");
1548 assert_eq!(analysis.additional_key_columns[0].1, "region");
1549 }
1550}