Skip to main content

laminar_sql/parser/
join_parser.rs

//! Join query analysis and extraction
//!
//! This module analyzes JOIN clauses to extract:
//! - Join type (INNER, LEFT, RIGHT, FULL, SEMI, ANTI, ASOF)
//! - Key columns for the join condition
//! - Time bounds for stream-stream joins
//! - Detection of lookup, stream-stream, ASOF and temporal joins
8
9use std::time::Duration;
10
11use sqlparser::ast::{
12    BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, JoinConstraint,
13    JoinOperator, Select, TableFactor, TableVersion,
14};
15
16use super::window_rewriter::WindowRewriter;
17use super::ParseError;
18
/// Classification of the join operator found in a query.
///
/// Produced by mapping the parser's join operator onto the subset of join
/// kinds this module distinguishes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `INNER JOIN`
    Inner,
    /// `LEFT \[OUTER\] JOIN`
    Left,
    /// `RIGHT \[OUTER\] JOIN`
    Right,
    /// `FULL \[OUTER\] JOIN`
    Full,
    /// `LEFT SEMI JOIN`: left rows that have at least one match are emitted.
    LeftSemi,
    /// `LEFT ANTI JOIN`: left rows that have no match are emitted.
    LeftAnti,
    /// `RIGHT SEMI JOIN`: right rows that have at least one match are emitted.
    RightSemi,
    /// `RIGHT ANTI JOIN`: right rows that have no match are emitted.
    RightAnti,
    /// `ASOF JOIN`
    AsOf,
}
41
/// Which right-side row an ASOF JOIN picks relative to the left timestamp.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// `left.ts >= right.ts`: the most recent right row is chosen.
    Backward,
    /// `left.ts <= right.ts`: the next right row is chosen.
    Forward,
    /// The right row with the smallest absolute time difference is chosen.
    Nearest,
}

impl std::fmt::Display for AsofSqlDirection {
    /// Writes the direction as its upper-case keyword
    /// (`BACKWARD`, `FORWARD` or `NEAREST`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Backward => "BACKWARD",
            Self::Forward => "FORWARD",
            Self::Nearest => "NEAREST",
        };
        f.write_str(keyword)
    }
}
62
/// Time column references lifted from a `BETWEEN` clause before it is known
/// which one belongs to which side of the join.
#[derive(Debug, Clone)]
struct RawTimeCols {
    /// Table qualifier of the tested expression (`x` in `x.col BETWEEN ...`), if any.
    expr_qualifier: Option<String>,
    /// Column name of the tested expression.
    expr_col: String,
    /// Table qualifier of the lower bound, if any.
    low_qualifier: Option<String>,
    /// Column name of the lower bound.
    low_col: String,
}

/// Decide which `BETWEEN` column is the left and which the right time column.
///
/// A qualifier refers to a side when it equals that side's table name or its
/// alias. The result is `(left_time_col, right_time_col)`:
///
/// - the tested expression is reported as the left column only when it
///   refers to the left side, the lower bound refers to the right side, and
///   the opposite assignment does not also hold;
/// - in every other case (unqualified, unknown or ambiguous qualifiers) the
///   positional default applies: lower bound = left, tested expression = right.
fn resolve_time_cols(
    raw: &RawTimeCols,
    left_table: &str,
    right_table: &str,
    left_alias: Option<&str>,
    right_alias: Option<&str>,
) -> (String, String) {
    let on_left =
        |q: Option<&str>| q.is_some_and(|name| name == left_table || left_alias == Some(name));
    let on_right =
        |q: Option<&str>| q.is_some_and(|name| name == right_table || right_alias == Some(name));

    let expr_q = raw.expr_qualifier.as_deref();
    let low_q = raw.low_qualifier.as_deref();

    // The positional reading wins whenever it is consistent with the
    // qualifiers, even if the swapped reading would be consistent too.
    let positional_fits = on_right(expr_q) && on_left(low_q);
    let swapped_fits = on_left(expr_q) && on_right(low_q);

    if swapped_fits && !positional_fits {
        (raw.expr_col.clone(), raw.low_col.clone())
    } else {
        (raw.low_col.clone(), raw.expr_col.clone())
    }
}
98
99/// Analysis result for a JOIN clause
100#[derive(Debug, Clone)]
101pub struct JoinAnalysis {
102    /// Type of join (inner, left, right, full)
103    pub join_type: JoinType,
104    /// Left side table name
105    pub left_table: String,
106    /// Right side table name
107    pub right_table: String,
108    /// Left side key column
109    pub left_key_column: String,
110    /// Right side key column
111    pub right_key_column: String,
112    /// Time bound for stream-stream joins (None for lookup joins)
113    pub time_bound: Option<Duration>,
114    /// Whether this is a lookup join (no time bound)
115    pub is_lookup_join: bool,
116    /// Left side alias (if any)
117    pub left_alias: Option<String>,
118    /// Right side alias (if any)
119    pub right_alias: Option<String>,
120    /// Whether this is an ASOF join
121    pub is_asof_join: bool,
122    /// ASOF join direction (Backward or Forward)
123    pub asof_direction: Option<AsofSqlDirection>,
124    /// Left side time column for ASOF join
125    pub left_time_column: Option<String>,
126    /// Right side time column for ASOF join
127    pub right_time_column: Option<String>,
128    /// ASOF join tolerance (max time difference)
129    pub asof_tolerance: Option<Duration>,
130    /// Whether this is a temporal join (FOR SYSTEM_TIME AS OF)
131    pub is_temporal_join: bool,
132    /// The version column from FOR SYSTEM_TIME AS OF (e.g., `order_time`)
133    pub temporal_version_column: Option<String>,
134    /// Additional key columns for composite join keys (beyond the primary key pair)
135    pub additional_key_columns: Vec<(String, String)>,
136}
137
138impl JoinAnalysis {
139    /// Create a stream-stream join analysis
140    #[must_use]
141    pub fn stream_stream(
142        left_table: String,
143        right_table: String,
144        left_key: String,
145        right_key: String,
146        time_bound: Duration,
147        join_type: JoinType,
148    ) -> Self {
149        Self {
150            join_type,
151            left_table,
152            right_table,
153            left_key_column: left_key,
154            right_key_column: right_key,
155            time_bound: Some(time_bound),
156            is_lookup_join: false,
157            left_alias: None,
158            right_alias: None,
159            is_asof_join: false,
160            asof_direction: None,
161            left_time_column: None,
162            right_time_column: None,
163            asof_tolerance: None,
164            is_temporal_join: false,
165            temporal_version_column: None,
166            additional_key_columns: vec![],
167        }
168    }
169
170    /// Create a lookup join analysis
171    #[must_use]
172    pub fn lookup(
173        left_table: String,
174        right_table: String,
175        left_key: String,
176        right_key: String,
177        join_type: JoinType,
178    ) -> Self {
179        Self {
180            join_type,
181            left_table,
182            right_table,
183            left_key_column: left_key,
184            right_key_column: right_key,
185            time_bound: None,
186            is_lookup_join: true,
187            left_alias: None,
188            right_alias: None,
189            is_asof_join: false,
190            asof_direction: None,
191            left_time_column: None,
192            right_time_column: None,
193            asof_tolerance: None,
194            is_temporal_join: false,
195            temporal_version_column: None,
196            additional_key_columns: vec![],
197        }
198    }
199
200    /// Create an ASOF join analysis
201    #[must_use]
202    #[allow(clippy::too_many_arguments)]
203    pub fn asof(
204        left_table: String,
205        right_table: String,
206        left_key: String,
207        right_key: String,
208        direction: AsofSqlDirection,
209        left_time_col: String,
210        right_time_col: String,
211        tolerance: Option<Duration>,
212    ) -> Self {
213        Self {
214            join_type: JoinType::AsOf,
215            left_table,
216            right_table,
217            left_key_column: left_key,
218            right_key_column: right_key,
219            time_bound: None,
220            is_lookup_join: false,
221            left_alias: None,
222            right_alias: None,
223            is_asof_join: true,
224            asof_direction: Some(direction),
225            left_time_column: Some(left_time_col),
226            right_time_column: Some(right_time_col),
227            asof_tolerance: tolerance,
228            is_temporal_join: false,
229            temporal_version_column: None,
230            additional_key_columns: vec![],
231        }
232    }
233
234    /// Create a temporal join analysis (FOR SYSTEM_TIME AS OF).
235    #[must_use]
236    pub fn temporal(
237        left_table: String,
238        right_table: String,
239        left_key: String,
240        right_key: String,
241        version_column: String,
242        join_type: JoinType,
243    ) -> Self {
244        Self {
245            join_type,
246            left_table,
247            right_table,
248            left_key_column: left_key,
249            right_key_column: right_key,
250            time_bound: None,
251            is_lookup_join: false,
252            left_alias: None,
253            right_alias: None,
254            is_asof_join: false,
255            asof_direction: None,
256            left_time_column: None,
257            right_time_column: None,
258            asof_tolerance: None,
259            is_temporal_join: true,
260            temporal_version_column: Some(version_column),
261            additional_key_columns: vec![],
262        }
263    }
264
265    /// True if this step has any kind of temporal bound — a `BETWEEN`-derived
266    /// time bound, ASOF match condition, or `FOR SYSTEM_TIME AS OF`.
267    #[must_use]
268    pub fn is_bounded(&self) -> bool {
269        self.time_bound.is_some() || self.is_asof_join || self.is_temporal_join
270    }
271}
272
273/// Analyze a SELECT statement for join information.
274///
275/// # Errors
276///
277/// Returns `ParseError::StreamingError` if:
278/// - Join constraint is not supported
279/// - Cannot extract key columns
280pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
281    let from = &select.from;
282    if from.is_empty() {
283        return Ok(None);
284    }
285
286    let first_table = &from[0];
287    if first_table.joins.is_empty() {
288        return Ok(None);
289    }
290
291    // Extract left table information
292    let left_table = extract_table_name(&first_table.relation)?;
293    let left_alias = extract_table_alias(&first_table.relation);
294
295    // Analyze the first join
296    let join = &first_table.joins[0];
297    let right_table = extract_table_name(&join.relation)?;
298    let right_alias = extract_table_alias(&join.relation);
299
300    let join_type = map_join_operator(&join.join_operator);
301
302    // Handle ASOF JOIN specially
303    if let JoinOperator::AsOf {
304        match_condition,
305        constraint,
306    } = &join.join_operator
307    {
308        let (direction, left_time, right_time, tolerance) =
309            analyze_asof_match_condition(match_condition)?;
310
311        // Extract key columns from the ON constraint
312        let (left_key, right_key) = analyze_asof_constraint(constraint)?;
313
314        let mut analysis = JoinAnalysis::asof(
315            left_table,
316            right_table,
317            left_key,
318            right_key,
319            direction,
320            left_time,
321            right_time,
322            tolerance,
323        );
324        analysis.left_alias = left_alias;
325        analysis.right_alias = right_alias;
326        return Ok(Some(analysis));
327    }
328
329    // Check for temporal join (FOR SYSTEM_TIME AS OF)
330    if let Some(version_col) = extract_temporal_version(&join.relation) {
331        let (left_key, right_key, additional, _, _) = analyze_join_constraint(&join.join_operator)?;
332        let mut analysis = JoinAnalysis::temporal(
333            left_table,
334            right_table,
335            left_key,
336            right_key,
337            version_col,
338            join_type,
339        );
340        analysis.left_alias = left_alias;
341        analysis.right_alias = right_alias;
342        analysis.additional_key_columns = additional;
343        return Ok(Some(analysis));
344    }
345
346    // Analyze the join constraint
347    let (left_key, right_key, additional, time_bound, time_cols) =
348        analyze_join_constraint(&join.join_operator)?;
349
350    let mut analysis = if let Some(tb) = time_bound {
351        JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
352    } else {
353        JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
354    };
355
356    analysis.left_alias.clone_from(&left_alias);
357    analysis.right_alias.clone_from(&right_alias);
358    analysis.additional_key_columns = additional;
359
360    if let Some(ref raw) = time_cols {
361        let (lt, rt) = resolve_time_cols(
362            raw,
363            &analysis.left_table,
364            &analysis.right_table,
365            left_alias.as_deref(),
366            right_alias.as_deref(),
367        );
368        analysis.left_time_column = Some(lt);
369        analysis.right_time_column = Some(rt);
370    }
371
372    Ok(Some(analysis))
373}
374
375/// Extract table name from a TableFactor.
376fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
377    match factor {
378        TableFactor::Table { name, .. } => Ok(name.to_string()),
379        TableFactor::Derived { alias, .. } => {
380            if let Some(alias) = alias {
381                Ok(alias.name.value.clone())
382            } else {
383                Err(ParseError::StreamingError(
384                    "Derived table without alias not supported".to_string(),
385                ))
386            }
387        }
388        _ => Err(ParseError::StreamingError(
389            "Unsupported table factor type".to_string(),
390        )),
391    }
392}
393
394/// Extract the version column from a temporal join's `FOR SYSTEM_TIME AS OF` clause.
395///
396/// Returns `Some(column_name)` if the table factor has a temporal version qualifier,
397/// `None` otherwise.
398fn extract_temporal_version(factor: &TableFactor) -> Option<String> {
399    if let TableFactor::Table {
400        version: Some(TableVersion::ForSystemTimeAsOf(expr)),
401        ..
402    } = factor
403    {
404        Some(extract_column_name_from_expr(expr))
405    } else {
406        None
407    }
408}
409
410/// Extract a column name from an expression (e.g., `o.order_time` → `order_time`).
411///
412/// Falls back to the full expression string for complex expressions.
413fn extract_column_name_from_expr(expr: &Expr) -> String {
414    match expr {
415        Expr::Identifier(ident) => ident.value.clone(),
416        Expr::CompoundIdentifier(parts) => parts
417            .last()
418            .map_or_else(|| expr.to_string(), |p| p.value.clone()),
419        _ => expr.to_string(),
420    }
421}
422
423/// Extract table alias from a TableFactor.
424fn extract_table_alias(factor: &TableFactor) -> Option<String> {
425    match factor {
426        TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
427        TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
428        _ => None,
429    }
430}
431
432/// Map sqlparser `JoinOperator` to our `JoinType`.
433fn map_join_operator(op: &JoinOperator) -> JoinType {
434    match op {
435        JoinOperator::Inner(_) | JoinOperator::Join(_) | JoinOperator::StraightJoin(_) => {
436            JoinType::Inner
437        }
438        JoinOperator::Left(_) | JoinOperator::LeftOuter(_) => JoinType::Left,
439        JoinOperator::LeftSemi(_) | JoinOperator::Semi(_) => JoinType::LeftSemi,
440        JoinOperator::LeftAnti(_) | JoinOperator::Anti(_) => JoinType::LeftAnti,
441        JoinOperator::AsOf { .. } => JoinType::AsOf,
442        JoinOperator::Right(_) | JoinOperator::RightOuter(_) => JoinType::Right,
443        JoinOperator::RightSemi(_) => JoinType::RightSemi,
444        JoinOperator::RightAnti(_) => JoinType::RightAnti,
445        JoinOperator::FullOuter(_) => JoinType::Full,
446        // CrossJoin, CrossApply, OuterApply are rejected by get_join_constraint()
447        _ => JoinType::Inner,
448    }
449}
450
451/// Analyze join constraint to extract key columns, additional key columns,
452/// time bound, and optional time column pair.
453#[allow(clippy::type_complexity)]
454fn analyze_join_constraint(
455    op: &JoinOperator,
456) -> Result<
457    (
458        String,
459        String,
460        Vec<(String, String)>,
461        Option<Duration>,
462        Option<RawTimeCols>,
463    ),
464    ParseError,
465> {
466    let constraint = get_join_constraint(op)?;
467
468    match constraint {
469        JoinConstraint::On(expr) => {
470            let (key_pairs, time_bound, time_cols) = analyze_on_expression(expr)?;
471            if key_pairs.is_empty() {
472                return Ok((String::new(), String::new(), vec![], time_bound, time_cols));
473            }
474            let (first_left, first_right) = key_pairs[0].clone();
475            let additional = key_pairs[1..].to_vec();
476            Ok((first_left, first_right, additional, time_bound, time_cols))
477        }
478        JoinConstraint::Using(cols) => {
479            if cols.is_empty() {
480                return Err(ParseError::StreamingError(
481                    "USING clause requires at least one column".to_string(),
482                ));
483            }
484            // First column is the primary key pair
485            let first_col = cols[0].to_string();
486            // Remaining columns are additional key pairs
487            let additional: Vec<(String, String)> = cols[1..]
488                .iter()
489                .map(|c| {
490                    let col = c.to_string();
491                    (col.clone(), col)
492                })
493                .collect();
494            Ok((first_col.clone(), first_col, additional, None, None))
495        }
496        JoinConstraint::Natural => Err(ParseError::StreamingError(
497            "NATURAL JOIN not supported for streaming".to_string(),
498        )),
499        JoinConstraint::None => Err(ParseError::StreamingError(
500            "JOIN without condition not supported for streaming".to_string(),
501        )),
502    }
503}
504
505/// Get the JoinConstraint from a JoinOperator.
506fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
507    match op {
508        JoinOperator::Inner(constraint)
509        | JoinOperator::Join(constraint)
510        | JoinOperator::Left(constraint)
511        | JoinOperator::LeftOuter(constraint)
512        | JoinOperator::Right(constraint)
513        | JoinOperator::RightOuter(constraint)
514        | JoinOperator::FullOuter(constraint)
515        | JoinOperator::LeftSemi(constraint)
516        | JoinOperator::RightSemi(constraint)
517        | JoinOperator::LeftAnti(constraint)
518        | JoinOperator::RightAnti(constraint)
519        | JoinOperator::Semi(constraint)
520        | JoinOperator::Anti(constraint)
521        | JoinOperator::StraightJoin(constraint)
522        | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
523        JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
524            ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
525        ),
526    }
527}
528
529/// Analyze ON expression to extract all key column pairs, time bound,
530/// and optional time column pair for stream-stream joins.
531#[allow(clippy::type_complexity)]
532fn analyze_on_expression(
533    expr: &Expr,
534) -> Result<(Vec<(String, String)>, Option<Duration>, Option<RawTimeCols>), ParseError> {
535    // Handle compound expressions (AND)
536    match expr {
537        Expr::BinaryOp {
538            left,
539            op: BinaryOperator::And,
540            right,
541        } => {
542            // Recursively analyze both sides
543            let left_result = analyze_on_expression(left);
544            let right_result = analyze_on_expression(right);
545
546            // Combine results - collect all key pairs and time bounds
547            match (left_result, right_result) {
548                (Ok((mut lk, lt, ltc)), Ok((rk, rt, rtc))) => {
549                    lk.extend(rk);
550                    Ok((lk, lt.or(rt), ltc.or(rtc)))
551                }
552                (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
553                (Err(e), Err(_)) => Err(e),
554            }
555        }
556        // Equality condition: a.col = b.col
557        Expr::BinaryOp {
558            left,
559            op: BinaryOperator::Eq,
560            right,
561        } => {
562            let left_col = extract_column_ref(left);
563            let right_col = extract_column_ref(right);
564
565            match (left_col, right_col) {
566                (Some(l), Some(r)) => Ok((vec![(l, r)], None, None)),
567                _ => Err(ParseError::StreamingError(
568                    "Cannot extract column references from equality condition".to_string(),
569                )),
570            }
571        }
572        // BETWEEN clause for time bound: p.ts BETWEEN o.ts AND o.ts + INTERVAL
573        Expr::Between {
574            expr: between_expr,
575            low,
576            high,
577            ..
578        } => {
579            // Try to extract time bound from high expression
580            let time_bound = extract_time_bound_from_expr(high).ok();
581            let between_col = extract_qualified_column_ref(between_expr);
582            let low_col = extract_qualified_column_ref(low);
583            let time_cols = if let (Some((bt, bc)), Some((lt, lc))) = (between_col, low_col) {
584                Some(RawTimeCols {
585                    expr_qualifier: bt,
586                    expr_col: bc,
587                    low_qualifier: lt,
588                    low_col: lc,
589                })
590            } else {
591                if time_bound.is_some() {
592                    tracing::warn!(
593                        "BETWEEN clause has time bound but time column references \
594                         could not be extracted (expressions must be simple column refs)"
595                    );
596                }
597                None
598            };
599            Ok((vec![], time_bound, time_cols))
600        }
601        // Comparison operators for time bounds
602        Expr::BinaryOp {
603            left: _,
604            op:
605                BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
606            right,
607        } => {
608            // Try to extract time bound from right side
609            let time_bound = extract_time_bound_from_expr(right).ok();
610            Ok((vec![], time_bound, None))
611        }
612        _ => Err(ParseError::StreamingError(format!(
613            "Unsupported join condition expression: {expr:?}"
614        ))),
615    }
616}
617
618/// Extract column reference from a function argument.
619fn extract_column_from_func_arg(arg: &FunctionArg) -> Option<String> {
620    let (FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))
621    | FunctionArg::Named {
622        arg: FunctionArgExpr::Expr(expr),
623        ..
624    }
625    | FunctionArg::ExprNamed {
626        arg: FunctionArgExpr::Expr(expr),
627        ..
628    }) = arg
629    else {
630        return None;
631    };
632    extract_column_ref(expr)
633}
634
635/// Extract column reference from expression (e.g., "a.id" -> "id")
636fn extract_column_ref(expr: &Expr) -> Option<String> {
637    match expr {
638        Expr::Identifier(ident) => Some(ident.value.clone()),
639        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
640        _ => None,
641    }
642}
643
644fn extract_qualified_column_ref(expr: &Expr) -> Option<(Option<String>, String)> {
645    match expr {
646        Expr::Identifier(ident) => Some((None, ident.value.clone())),
647        Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
648            Some((Some(parts[0].value.clone()), parts[1].value.clone()))
649        }
650        Expr::CompoundIdentifier(parts) => parts.last().map(|p| (None, p.value.clone())),
651        _ => None,
652    }
653}
654
655/// Extract time bound from an expression like "o.ts + INTERVAL '1' HOUR"
656fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
657    match expr {
658        // Direct interval
659        Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
660        // Addition or subtraction: col +/- INTERVAL
661        Expr::BinaryOp {
662            left: _,
663            op: BinaryOperator::Plus | BinaryOperator::Minus,
664            right,
665        } => extract_time_bound_from_expr(right),
666        // Nested expression
667        Expr::Nested(inner) => extract_time_bound_from_expr(inner),
668        _ => Err(ParseError::StreamingError(format!(
669            "Cannot extract time bound from: {expr:?}"
670        ))),
671    }
672}
673
674/// Analyze ASOF JOIN MATCH_CONDITION expression.
675///
676/// Extracts direction, time column names, and optional tolerance.
677fn analyze_asof_match_condition(
678    expr: &Expr,
679) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
680    if let Expr::BinaryOp {
681        left,
682        op: BinaryOperator::And,
683        right,
684    } = expr
685    {
686        // Try to get direction from left, tolerance from right
687        let dir_result = analyze_asof_direction(left);
688        let tol_result = extract_asof_tolerance(right);
689
690        match (dir_result, tol_result) {
691            (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
692            (Ok((dir, lt, rt)), Err(_)) => {
693                // Maybe tolerance is on left and direction on right
694                let dir2 = analyze_asof_direction(right);
695                let tol2 = extract_asof_tolerance(left);
696                match (dir2, tol2) {
697                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
698                    _ => Ok((dir, lt, rt, None)),
699                }
700            }
701            (Err(_), _) => {
702                // Try reversed
703                let dir2 = analyze_asof_direction(right);
704                let tol2 = extract_asof_tolerance(left);
705                match (dir2, tol2) {
706                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
707                    (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
708                    _ => Err(ParseError::StreamingError(
709                        "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
710                    )),
711                }
712            }
713        }
714    } else {
715        let (dir, lt, rt) = analyze_asof_direction(expr)?;
716        Ok((dir, lt, rt, None))
717    }
718}
719
720/// Extract ASOF direction and time columns from a comparison expression.
721fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
722    match expr {
723        Expr::BinaryOp {
724            left,
725            op: BinaryOperator::GtEq,
726            right,
727        } => {
728            let left_col = extract_column_ref(left).ok_or_else(|| {
729                ParseError::StreamingError(
730                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
731                )
732            })?;
733            let right_col = extract_column_ref(right).ok_or_else(|| {
734                ParseError::StreamingError(
735                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
736                )
737            })?;
738            Ok((AsofSqlDirection::Backward, left_col, right_col))
739        }
740        Expr::BinaryOp {
741            left,
742            op: BinaryOperator::LtEq,
743            right,
744        } => {
745            let left_col = extract_column_ref(left).ok_or_else(|| {
746                ParseError::StreamingError(
747                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
748                )
749            })?;
750            let right_col = extract_column_ref(right).ok_or_else(|| {
751                ParseError::StreamingError(
752                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
753                )
754            })?;
755            Ok((AsofSqlDirection::Forward, left_col, right_col))
756        }
757        // NEAREST(left_col, right_col) — function-style syntax
758        Expr::Function(func) => {
759            let name = func.name.to_string().to_uppercase();
760            if name != "NEAREST" {
761                return Err(ParseError::StreamingError(format!(
762                    "Unknown ASOF MATCH_CONDITION function: {name}"
763                )));
764            }
765            let args = match &func.args {
766                FunctionArguments::List(arg_list) => &arg_list.args,
767                _ => {
768                    return Err(ParseError::StreamingError(
769                        "NEAREST() requires exactly 2 column arguments".to_string(),
770                    ))
771                }
772            };
773            if args.len() != 2 {
774                return Err(ParseError::StreamingError(format!(
775                    "NEAREST() requires exactly 2 arguments, got {}",
776                    args.len()
777                )));
778            }
779            let left_col = extract_column_from_func_arg(&args[0]).ok_or_else(|| {
780                ParseError::StreamingError(
781                    "Cannot extract left time column from NEAREST()".to_string(),
782                )
783            })?;
784            let right_col = extract_column_from_func_arg(&args[1]).ok_or_else(|| {
785                ParseError::StreamingError(
786                    "Cannot extract right time column from NEAREST()".to_string(),
787                )
788            })?;
789            Ok((AsofSqlDirection::Nearest, left_col, right_col))
790        }
791        _ => Err(ParseError::StreamingError(
792            "ASOF MATCH_CONDITION must be >= or <= comparison, or NEAREST()".to_string(),
793        )),
794    }
795}
796
797/// Extract tolerance duration from an ASOF tolerance expression.
798///
799/// Handles: `left - right <= value` or `left - right <= INTERVAL '...'`
800fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
801    match expr {
802        Expr::BinaryOp {
803            left: _,
804            op: BinaryOperator::LtEq,
805            right,
806        } => {
807            // right side is either a literal number or INTERVAL
808            match right.as_ref() {
809                Expr::Value(v) => {
810                    if let sqlparser::ast::Value::Number(n, _) = &v.value {
811                        let ms: u64 = n.parse().map_err(|_| {
812                            ParseError::StreamingError(format!(
813                                "Cannot parse tolerance as number: {n}"
814                            ))
815                        })?;
816                        Ok(Duration::from_millis(ms))
817                    } else {
818                        Err(ParseError::StreamingError(
819                            "ASOF tolerance must be a number or INTERVAL".to_string(),
820                        ))
821                    }
822                }
823                Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
824                _ => Err(ParseError::StreamingError(
825                    "ASOF tolerance must be a number or INTERVAL".to_string(),
826                )),
827            }
828        }
829        _ => Err(ParseError::StreamingError(
830            "ASOF tolerance expression must be <= comparison".to_string(),
831        )),
832    }
833}
834
835/// Extract key columns from an ASOF JOIN constraint (ON clause).
836fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
837    match constraint {
838        JoinConstraint::On(expr) => extract_equality_columns(expr),
839        JoinConstraint::Using(cols) => {
840            if cols.is_empty() {
841                return Err(ParseError::StreamingError(
842                    "USING clause requires at least one column".to_string(),
843                ));
844            }
845            let col = cols[0].to_string();
846            Ok((col.clone(), col))
847        }
848        _ => Err(ParseError::StreamingError(
849            "ASOF JOIN requires ON or USING constraint".to_string(),
850        )),
851    }
852}
853
854/// Extract left and right column names from an equality expression.
855fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
856    match expr {
857        Expr::BinaryOp {
858            left,
859            op: BinaryOperator::Eq,
860            right,
861        } => {
862            let left_col = extract_column_ref(left).ok_or_else(|| {
863                ParseError::StreamingError("Cannot extract left key column".to_string())
864            })?;
865            let right_col = extract_column_ref(right).ok_or_else(|| {
866                ParseError::StreamingError("Cannot extract right key column".to_string())
867            })?;
868            Ok((left_col, right_col))
869        }
870        // If there's an AND, find the equality part
871        Expr::BinaryOp {
872            left,
873            op: BinaryOperator::And,
874            right,
875        } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
876        _ => Err(ParseError::StreamingError(
877            "ASOF JOIN ON clause must contain an equality condition".to_string(),
878        )),
879    }
880}
881
882/// Check if a SELECT contains a join.
883#[must_use]
884pub fn has_join(select: &Select) -> bool {
885    !select.from.is_empty() && !select.from[0].joins.is_empty()
886}
887
888/// Count the number of joins in a SELECT.
889#[must_use]
890pub fn count_joins(select: &Select) -> usize {
891    select
892        .from
893        .iter()
894        .map(|table_with_joins| table_with_joins.joins.len())
895        .sum()
896}
897
898/// Analysis result for multi-way JOINs (e.g., `A JOIN B ... JOIN C ...`).
899///
900/// Each step represents one left-deep join: step 0 joins the base table with
901/// the first right table, step 1 joins the result with the next right table, etc.
902#[derive(Debug, Clone)]
903pub struct MultiJoinAnalysis {
904    /// Ordered join steps (left-to-right)
905    pub joins: Vec<JoinAnalysis>,
906    /// All referenced tables in order (base table first, then each right table)
907    pub tables: Vec<String>,
908}
909
910impl MultiJoinAnalysis {
911    /// Number of join steps.
912    #[must_use]
913    pub fn len(&self) -> usize {
914        self.joins.len()
915    }
916
917    /// Whether there are no join steps.
918    #[must_use]
919    pub fn is_empty(&self) -> bool {
920        self.joins.is_empty()
921    }
922
923    /// Whether this is a single join (backward-compatible case).
924    #[must_use]
925    pub fn is_single(&self) -> bool {
926        self.joins.len() == 1
927    }
928
929    /// The first join step (convenience for single-join queries).
930    #[must_use]
931    pub fn first(&self) -> Option<&JoinAnalysis> {
932        self.joins.first()
933    }
934}
935
936/// Analyze a SELECT statement for all join steps (multi-way).
937///
938/// Returns `None` if the query has no joins. For a single join this
939/// returns a `MultiJoinAnalysis` with one step, making it backward
940/// compatible with `analyze_join()`.
941///
942/// # Errors
943///
944/// Returns `ParseError::StreamingError` if any join constraint is
945/// not supported or key columns cannot be extracted.
946pub fn analyze_joins(select: &Select) -> Result<Option<MultiJoinAnalysis>, ParseError> {
947    let from = &select.from;
948    if from.is_empty() {
949        return Ok(None);
950    }
951
952    let first_table = &from[0];
953    if first_table.joins.is_empty() {
954        return Ok(None);
955    }
956
957    // Extract base table
958    let base_table = extract_table_name(&first_table.relation)?;
959    let base_alias = extract_table_alias(&first_table.relation);
960
961    let mut join_steps = Vec::with_capacity(first_table.joins.len());
962    let mut tables = vec![base_table.clone()];
963
964    // Track the left table name for left-deep chaining
965    let mut prev_left_table = base_table;
966    let mut prev_left_alias = base_alias;
967
968    for join in &first_table.joins {
969        let right_table = extract_table_name(&join.relation)?;
970        let right_alias = extract_table_alias(&join.relation);
971        tables.push(right_table.clone());
972
973        let join_type = map_join_operator(&join.join_operator);
974
975        // Handle ASOF JOIN
976        if let JoinOperator::AsOf {
977            match_condition,
978            constraint,
979        } = &join.join_operator
980        {
981            let (direction, left_time, right_time, tolerance) =
982                analyze_asof_match_condition(match_condition)?;
983            let (left_key, right_key) = analyze_asof_constraint(constraint)?;
984
985            let mut analysis = JoinAnalysis::asof(
986                prev_left_table.clone(),
987                right_table.clone(),
988                left_key,
989                right_key,
990                direction,
991                left_time,
992                right_time,
993                tolerance,
994            );
995            analysis.left_alias.clone_from(&prev_left_alias);
996            analysis.right_alias = right_alias;
997            join_steps.push(analysis);
998        } else if let Some(version_col) = extract_temporal_version(&join.relation) {
999            // Temporal join: right side has FOR SYSTEM_TIME AS OF
1000            let (left_key, right_key, additional, _, _) =
1001                analyze_join_constraint(&join.join_operator)?;
1002
1003            let mut analysis = JoinAnalysis::temporal(
1004                prev_left_table.clone(),
1005                right_table.clone(),
1006                left_key,
1007                right_key,
1008                version_col,
1009                join_type,
1010            );
1011            analysis.left_alias.clone_from(&prev_left_alias);
1012            analysis.right_alias = right_alias;
1013            analysis.additional_key_columns = additional;
1014            join_steps.push(analysis);
1015        } else {
1016            // Regular join (inner, left, right, full)
1017            let (left_key, right_key, additional, time_bound, time_cols) =
1018                analyze_join_constraint(&join.join_operator)?;
1019
1020            let mut analysis = if let Some(tb) = time_bound {
1021                JoinAnalysis::stream_stream(
1022                    prev_left_table.clone(),
1023                    right_table.clone(),
1024                    left_key,
1025                    right_key,
1026                    tb,
1027                    join_type,
1028                )
1029            } else {
1030                JoinAnalysis::lookup(
1031                    prev_left_table.clone(),
1032                    right_table.clone(),
1033                    left_key,
1034                    right_key,
1035                    join_type,
1036                )
1037            };
1038            analysis.left_alias.clone_from(&prev_left_alias);
1039            analysis.right_alias.clone_from(&right_alias);
1040            analysis.additional_key_columns = additional;
1041
1042            if let Some(ref raw) = time_cols {
1043                let (lt, rt) = resolve_time_cols(
1044                    raw,
1045                    &analysis.left_table,
1046                    &analysis.right_table,
1047                    prev_left_alias.as_deref(),
1048                    right_alias.as_deref(),
1049                );
1050                analysis.left_time_column = Some(lt);
1051                analysis.right_time_column = Some(rt);
1052            }
1053            join_steps.push(analysis);
1054        }
1055
1056        // Next step's left table is this step's right table (left-deep)
1057        prev_left_table = right_table;
1058        prev_left_alias = extract_table_alias(&join.relation);
1059    }
1060
1061    Ok(Some(MultiJoinAnalysis {
1062        joins: join_steps,
1063        tables,
1064    }))
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::ast::{SetExpr, Statement};
    use sqlparser::dialect::{Dialect, GenericDialect, SnowflakeDialect};
    use sqlparser::parser::Parser;

    /// Parse `sql` with the given dialect and return the SELECT body of the
    /// first statement. Panics if parsing fails or the first statement is
    /// not a plain SELECT query.
    fn parse_select_with(dialect: &dyn Dialect, sql: &str) -> Select {
        let statements = Parser::parse_sql(dialect, sql).unwrap();
        if let Statement::Query(query) = &statements[0] {
            if let SetExpr::Select(select) = query.body.as_ref() {
                return *select.clone();
            }
        }
        panic!("Expected SELECT query");
    }

    /// Parse with the generic SQL dialect.
    fn parse_select(sql: &str) -> Select {
        parse_select_with(&GenericDialect {}, sql)
    }

    /// Parse with the Snowflake dialect (used for `ASOF JOIN` syntax).
    fn parse_select_snowflake(sql: &str) -> Select {
        parse_select_with(&SnowflakeDialect {}, sql)
    }

    /// Parse with the project's own dialect (used for `FOR SYSTEM_TIME AS OF`).
    fn parse_select_laminar(sql: &str) -> Select {
        parse_select_with(&crate::parser::dialect::LaminarDialect::default(), sql)
    }

    #[test]
    fn test_analyze_inner_join() {
        let sql = "SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Inner);
        assert_eq!(analysis.left_table, "orders");
        assert_eq!(analysis.right_table, "payments");
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
        assert!(analysis.is_lookup_join); // No time bound = lookup join
    }

    #[test]
    fn test_analyze_left_join() {
        let sql = "SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Left);
        assert_eq!(analysis.left_key_column, "customer_id");
        assert_eq!(analysis.right_key_column, "id");
    }

    #[test]
    fn test_analyze_join_using() {
        let sql = "SELECT * FROM orders o JOIN payments p USING (order_id)";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
    }

    #[test]
    fn test_analyze_stream_stream_join_with_time_bound() {
        let sql = "SELECT * FROM orders o
                   JOIN payments p ON o.order_id = p.order_id
                   AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(!analysis.is_lookup_join);
        assert!(analysis.time_bound.is_some());
        assert_eq!(analysis.time_bound.unwrap(), Duration::from_secs(3600));
    }

    #[test]
    fn test_no_join() {
        let sql = "SELECT * FROM orders";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap();
        assert!(analysis.is_none());
    }

    #[test]
    fn test_has_join() {
        let sql_with_join = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let sql_without_join = "SELECT * FROM orders";

        let select_with = parse_select(sql_with_join);
        let select_without = parse_select(sql_without_join);

        assert!(has_join(&select_with));
        assert!(!has_join(&select_without));
    }

    #[test]
    fn test_count_joins() {
        let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
        let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
        let sql_zero = "SELECT * FROM a";

        assert_eq!(count_joins(&parse_select(sql_one)), 1);
        assert_eq!(count_joins(&parse_select(sql_two)), 2);
        assert_eq!(count_joins(&parse_select(sql_zero)), 0);
    }

    #[test]
    fn test_aliases() {
        let sql = "SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_alias, Some("o".to_string()));
        assert_eq!(analysis.right_alias, Some("p".to_string()));
    }

    // -- ASOF JOIN tests --

    #[test]
    fn test_asof_join_backward() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(analysis.asof_tolerance.is_none());
    }

    #[test]
    fn test_asof_join_forward() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts <= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
    }

    #[test]
    fn test_asof_join_nearest() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(NEAREST(t.ts, q.ts)) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Nearest));
        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(analysis.asof_tolerance.is_none());
    }

    #[test]
    fn test_asof_join_with_tolerance() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_asof_join_with_interval_tolerance() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_asof_join_type_mapping() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(!analysis.is_lookup_join);
    }

    #[test]
    fn test_asof_join_extracts_time_columns() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_time_column, Some("ts".to_string()));
        assert_eq!(analysis.right_time_column, Some("ts".to_string()));
    }

    #[test]
    fn test_asof_join_extracts_key_columns() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_key_column, "symbol");
        assert_eq!(analysis.right_key_column, "symbol");
    }

    #[test]
    fn test_asof_join_aliases() {
        let sql = "SELECT * FROM trades AS t \
                    ASOF JOIN quotes AS q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_alias, Some("t".to_string()));
        assert_eq!(analysis.right_alias, Some("q".to_string()));
        assert_eq!(analysis.left_table, "trades");
        assert_eq!(analysis.right_table, "quotes");
    }

    // -- Multi-way JOIN tests --

    #[test]
    fn test_multi_join_single_backward_compat() {
        let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert!(multi.is_single());
        assert_eq!(multi.len(), 1);
        assert!(!multi.is_empty());
        let first = multi.first().unwrap();
        assert_eq!(first.left_table, "orders");
        assert_eq!(first.right_table, "payments");
    }

    #[test]
    fn test_multi_join_two_way() {
        let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(!multi.is_single());

        assert_eq!(multi.joins[0].left_table, "a");
        assert_eq!(multi.joins[0].right_table, "b");
        assert_eq!(multi.joins[0].left_key_column, "id");
        assert_eq!(multi.joins[0].right_key_column, "a_id");

        assert_eq!(multi.joins[1].left_table, "b");
        assert_eq!(multi.joins[1].right_table, "c");
        assert_eq!(multi.joins[1].left_key_column, "id");
        assert_eq!(multi.joins[1].right_key_column, "b_id");
    }

    #[test]
    fn test_multi_join_three_way() {
        let sql = "SELECT * FROM a \
                    JOIN b ON a.id = b.a_id \
                    JOIN c ON b.id = c.b_id \
                    JOIN d ON c.id = d.c_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 3);
        assert_eq!(multi.tables.len(), 4);
        assert_eq!(multi.tables, vec!["a", "b", "c", "d"]);
    }

    #[test]
    fn test_multi_join_mixed_asof_and_lookup() {
        // ASOF first, then lookup (use Snowflake dialect for ASOF)
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol \
                    JOIN products p ON q.product_id = p.id";
        let select = parse_select_snowflake(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(multi.joins[0].is_asof_join);
        assert!(multi.joins[1].is_lookup_join);
    }

    #[test]
    fn test_multi_join_stream_stream_and_lookup() {
        let sql = "SELECT * FROM orders o \
                    JOIN payments p ON o.id = p.order_id \
                        AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR \
                    JOIN customers c ON o.customer_id = c.id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(!multi.joins[0].is_lookup_join); // stream-stream
        assert!(multi.joins[0].time_bound.is_some());
        assert!(multi.joins[1].is_lookup_join); // lookup
    }

    #[test]
    fn test_multi_join_tables_list() {
        let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.tables, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_multi_join_aliases() {
        let sql = "SELECT * FROM orders AS o \
                    JOIN payments AS p ON o.id = p.order_id \
                    JOIN refunds AS r ON p.id = r.payment_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.joins[0].left_alias, Some("o".to_string()));
        assert_eq!(multi.joins[0].right_alias, Some("p".to_string()));
        assert_eq!(multi.joins[1].left_alias, Some("p".to_string()));
        assert_eq!(multi.joins[1].right_alias, Some("r".to_string()));
    }

    #[test]
    fn test_multi_join_no_join_returns_none() {
        let sql = "SELECT * FROM orders";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap();
        assert!(multi.is_none());
    }

    // -- Temporal JOIN tests (FOR SYSTEM_TIME AS OF) --

    #[test]
    fn test_temporal_join_detected() {
        let sql = "SELECT o.*, p.price \
                    FROM orders o \
                    JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                    ON o.product_id = p.id";
        let select = parse_select_laminar(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_temporal_join);
        assert_eq!(
            analysis.temporal_version_column,
            Some("order_time".to_string())
        );
        assert_eq!(analysis.left_table, "orders");
        assert_eq!(analysis.right_table, "products");
        assert_eq!(analysis.left_key_column, "product_id");
        assert_eq!(analysis.right_key_column, "id");
        assert!(!analysis.is_lookup_join);
        assert!(!analysis.is_asof_join);
    }

    #[test]
    fn test_temporal_join_via_analyze_joins() {
        let sql = "SELECT o.*, p.price \
                    FROM orders o \
                    JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                    ON o.product_id = p.id";
        let select = parse_select_laminar(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 1);
        let first = multi.first().unwrap();
        assert!(first.is_temporal_join);
        assert_eq!(
            first.temporal_version_column,
            Some("order_time".to_string())
        );
    }

    #[test]
    fn test_non_temporal_join_not_flagged() {
        let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(!analysis.is_temporal_join);
        assert!(analysis.temporal_version_column.is_none());
    }

    #[test]
    fn test_unqualified_anti_maps_to_left_anti() {
        let sql = "SELECT * FROM orders o ANTI JOIN returns r ON o.id = r.order_id";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();
        assert_eq!(analysis.join_type, JoinType::LeftAnti);
    }

    #[test]
    fn test_unqualified_semi_maps_to_left_semi() {
        let sql = "SELECT * FROM orders o SEMI JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();
        assert_eq!(analysis.join_type, JoinType::LeftSemi);
    }

    #[test]
    fn test_composite_join_keys() {
        let sql = "SELECT * FROM orders o \
                    JOIN shipments s \
                    ON o.order_id = s.order_id AND o.region = s.region";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        // First key pair is the primary key
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");

        // Second key pair should be in additional_key_columns
        assert_eq!(
            analysis.additional_key_columns.len(),
            1,
            "Should have 1 additional key pair"
        );
        assert_eq!(analysis.additional_key_columns[0].0, "region");
        assert_eq!(analysis.additional_key_columns[0].1, "region");
    }

    #[test]
    fn test_composite_using_clause() {
        let sql = "SELECT * FROM orders o JOIN shipments s USING (order_id, region)";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        // First column becomes primary key
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");

        // Additional columns
        assert_eq!(
            analysis.additional_key_columns.len(),
            1,
            "USING(order_id, region) should have 1 additional key"
        );
        assert_eq!(analysis.additional_key_columns[0].0, "region");
        assert_eq!(analysis.additional_key_columns[0].1, "region");
    }
}