Skip to main content

laminar_sql/parser/
analytic_parser.rs

1//! Analytic window function detection and extraction
2//!
3//! Analyzes SQL queries for analytic functions like LAG, LEAD, FIRST_VALUE,
4//! LAST_VALUE, and NTH_VALUE with OVER clauses. These are per-row window
5//! functions (distinct from GROUP BY aggregate windows like TUMBLE/HOP/SESSION).
6
7use sqlparser::ast::{Expr, SelectItem, SetExpr, Statement};
8
/// Kinds of per-row analytic window functions recognised by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnalyticFunctionType {
    /// `LAG(col, offset, default)`: the value `offset` rows earlier in the partition.
    Lag,
    /// `LEAD(col, offset, default)`: the value `offset` rows later in the partition.
    Lead,
    /// `FIRST_VALUE(col) OVER (...)`: the first value of the window frame.
    FirstValue,
    /// `LAST_VALUE(col) OVER (...)`: the last value of the window frame.
    LastValue,
    /// `NTH_VALUE(col, n) OVER (...)`: the `n`-th value of the window frame.
    NthValue,
}

impl AnalyticFunctionType {
    /// Upper-case SQL spelling of the function, e.g. `"FIRST_VALUE"`.
    #[must_use]
    pub fn sql_name(&self) -> &'static str {
        match *self {
            Self::Lag => "LAG",
            Self::Lead => "LEAD",
            Self::FirstValue => "FIRST_VALUE",
            Self::LastValue => "LAST_VALUE",
            Self::NthValue => "NTH_VALUE",
        }
    }

    /// Whether producing a result needs rows that arrive *after* the current
    /// one (so future events must be buffered). Only `LEAD` does.
    #[must_use]
    pub fn requires_lookahead(&self) -> bool {
        *self == Self::Lead
    }
}
43
44/// Information about a single analytic function call.
45#[derive(Debug, Clone, PartialEq, Eq)]
46pub struct AnalyticFunctionInfo {
47    /// Type of analytic function
48    pub function_type: AnalyticFunctionType,
49    /// Column being referenced (first argument)
50    pub column: String,
51    /// Offset for LAG/LEAD (default 1), or N for NTH_VALUE
52    pub offset: usize,
53    /// Default value expression as string (for LAG/LEAD third argument)
54    pub default_value: Option<String>,
55    /// Output alias (AS name)
56    pub alias: Option<String>,
57}
58
59/// Result of analyzing analytic functions in a query.
60#[derive(Debug, Clone, PartialEq, Eq)]
61pub struct AnalyticWindowAnalysis {
62    /// Analytic functions found in the query
63    pub functions: Vec<AnalyticFunctionInfo>,
64    /// PARTITION BY columns from the OVER clause
65    pub partition_columns: Vec<String>,
66    /// ORDER BY columns from the OVER clause
67    pub order_columns: Vec<String>,
68}
69
70impl AnalyticWindowAnalysis {
71    /// Returns true if any function requires lookahead (LEAD).
72    #[must_use]
73    pub fn has_lookahead(&self) -> bool {
74        self.functions
75            .iter()
76            .any(|f| f.function_type.requires_lookahead())
77    }
78
79    /// Returns the maximum offset across all functions.
80    #[must_use]
81    pub fn max_offset(&self) -> usize {
82        self.functions.iter().map(|f| f.offset).max().unwrap_or(0)
83    }
84}
85
86/// Analyzes a SQL statement for analytic window functions.
87///
88/// Walks SELECT items looking for functions with OVER clauses that match
89/// LAG, LEAD, FIRST_VALUE, LAST_VALUE, or NTH_VALUE. Returns `None` if
90/// no analytic functions are found.
91///
92/// # Arguments
93///
94/// * `stmt` - The SQL statement to analyze
95///
96/// # Returns
97///
98/// An `AnalyticWindowAnalysis` if analytic functions are found, or `None`.
99#[must_use]
100pub fn analyze_analytic_functions(stmt: &Statement) -> Option<AnalyticWindowAnalysis> {
101    let Statement::Query(query) = stmt else {
102        return None;
103    };
104
105    let SetExpr::Select(select) = query.body.as_ref() else {
106        return None;
107    };
108
109    let mut functions = Vec::new();
110    let mut partition_columns = Vec::new();
111    let mut order_columns = Vec::new();
112    let mut first_window = true;
113
114    for item in &select.projection {
115        let (expr, alias) = match item {
116            SelectItem::UnnamedExpr(expr) => (expr, None),
117            SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
118            _ => continue,
119        };
120
121        if let Some(info) = extract_analytic_function(expr, alias, &mut |spec| {
122            if first_window {
123                partition_columns = spec
124                    .partition_by
125                    .iter()
126                    .filter_map(extract_column_name)
127                    .collect();
128                order_columns = spec
129                    .order_by
130                    .iter()
131                    .filter_map(|ob| extract_column_name(&ob.expr))
132                    .collect();
133                first_window = false;
134            }
135        }) {
136            functions.push(info);
137        }
138    }
139
140    if functions.is_empty() {
141        return None;
142    }
143
144    Some(AnalyticWindowAnalysis {
145        functions,
146        partition_columns,
147        order_columns,
148    })
149}
150
151/// Extracts an analytic function from an expression.
152///
153/// Returns function info if the expression is a recognized analytic function
154/// with an OVER clause. Calls `on_window_spec` with the window spec from the
155/// first function found so the caller can extract partition/order columns.
156fn extract_analytic_function(
157    expr: &Expr,
158    alias: Option<String>,
159    on_window_spec: &mut dyn FnMut(&sqlparser::ast::WindowSpec),
160) -> Option<AnalyticFunctionInfo> {
161    let Expr::Function(func) = expr else {
162        return None;
163    };
164
165    let name = func.name.to_string().to_uppercase();
166    let function_type = match name.as_str() {
167        "LAG" => AnalyticFunctionType::Lag,
168        "LEAD" => AnalyticFunctionType::Lead,
169        "FIRST_VALUE" => AnalyticFunctionType::FirstValue,
170        "LAST_VALUE" => AnalyticFunctionType::LastValue,
171        "NTH_VALUE" => AnalyticFunctionType::NthValue,
172        _ => return None,
173    };
174
175    // Must have an OVER clause to be an analytic function
176    let window_spec = func.over.as_ref()?;
177    match window_spec {
178        sqlparser::ast::WindowType::WindowSpec(spec) => {
179            on_window_spec(spec);
180        }
181        sqlparser::ast::WindowType::NamedWindow(_) => {}
182    }
183
184    // Extract arguments
185    let args = extract_function_args(func);
186
187    // First arg is the column
188    let column = args.first().cloned().unwrap_or_default();
189
190    // Second arg is offset (for LAG/LEAD) or N (for NTH_VALUE), default 1
191    let offset = args
192        .get(1)
193        .and_then(|s| s.parse::<usize>().ok())
194        .unwrap_or(1);
195
196    // Third arg is default value (for LAG/LEAD only)
197    let default_value = if matches!(
198        function_type,
199        AnalyticFunctionType::Lag | AnalyticFunctionType::Lead
200    ) {
201        args.get(2).cloned()
202    } else {
203        None
204    };
205
206    Some(AnalyticFunctionInfo {
207        function_type,
208        column,
209        offset,
210        default_value,
211        alias,
212    })
213}
214
215/// Extracts function argument expressions as strings.
216fn extract_function_args(func: &sqlparser::ast::Function) -> Vec<String> {
217    match &func.args {
218        sqlparser::ast::FunctionArguments::List(list) => list
219            .args
220            .iter()
221            .filter_map(|arg| match arg {
222                sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(
223                    expr,
224                )) => Some(expr_to_string(expr)),
225                _ => None,
226            })
227            .collect(),
228        _ => vec![],
229    }
230}
231
232/// Converts an expression to its string representation.
233fn expr_to_string(expr: &Expr) -> String {
234    match expr {
235        Expr::Identifier(ident) => ident.value.clone(),
236        Expr::CompoundIdentifier(parts) => parts.last().map_or(String::new(), |p| p.value.clone()),
237        Expr::Value(value_with_span) => match &value_with_span.value {
238            sqlparser::ast::Value::Number(n, _) => n.clone(),
239            sqlparser::ast::Value::SingleQuotedString(s) => s.clone(),
240            sqlparser::ast::Value::Null => "NULL".to_string(),
241            _ => format!("{}", value_with_span.value),
242        },
243        Expr::UnaryOp {
244            op: sqlparser::ast::UnaryOperator::Minus,
245            expr: inner,
246        } => format!("-{}", expr_to_string(inner)),
247        _ => expr.to_string(),
248    }
249}
250
251/// Extracts a simple column name from an expression.
252fn extract_column_name(expr: &Expr) -> Option<String> {
253    match expr {
254        Expr::Identifier(ident) => Some(ident.value.clone()),
255        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
256        _ => None,
257    }
258}
259
260// --- Window Frame types ---
261
/// Functions that can be evaluated over an explicit `ROWS`/`RANGE` window frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowFrameFunction {
    /// `AVG(col) OVER (... ROWS BETWEEN ...)`: mean over the frame.
    Avg,
    /// `SUM(col) OVER (... ROWS BETWEEN ...)`: sum over the frame.
    Sum,
    /// `MIN(col) OVER (... ROWS BETWEEN ...)`: minimum over the frame.
    Min,
    /// `MAX(col) OVER (... ROWS BETWEEN ...)`: maximum over the frame.
    Max,
    /// `COUNT(*) OVER (... ROWS BETWEEN ...)`: row count of the frame.
    Count,
    /// `FIRST_VALUE(col) OVER (... ROWS BETWEEN ...)`: first value in the frame.
    FirstValue,
    /// `LAST_VALUE(col) OVER (... ROWS BETWEEN ...)`: last value in the frame.
    LastValue,
    /// `CORR(x, y) OVER (... ROWS BETWEEN ...)`: Pearson correlation over the
    /// frame. Bivariate; the recorded `column` holds only the first argument.
    Corr,
    /// `COVAR_SAMP(x, y) OVER (... ROWS BETWEEN ...)`: sample covariance.
    CovarSamp,
    /// `COVAR_POP(x, y) OVER (... ROWS BETWEEN ...)`: population covariance.
    CovarPop,
}

impl WindowFrameFunction {
    /// Canonical upper-case SQL name of the function, e.g. `"COVAR_SAMP"`.
    #[must_use]
    pub fn sql_name(&self) -> &'static str {
        match *self {
            Self::Avg => "AVG",
            Self::Sum => "SUM",
            Self::Min => "MIN",
            Self::Max => "MAX",
            Self::Count => "COUNT",
            Self::FirstValue => "FIRST_VALUE",
            Self::LastValue => "LAST_VALUE",
            Self::Corr => "CORR",
            Self::CovarSamp => "COVAR_SAMP",
            Self::CovarPop => "COVAR_POP",
        }
    }
}
306
/// Unit in which a window frame's extent is measured.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameUnits {
    /// `ROWS`: bounds count physical rows relative to the current row.
    Rows,
    /// `RANGE`: bounds describe a logical range of ORDER BY values.
    Range,
}
315
/// One edge (start or end) of a window frame.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FrameBound {
    /// `UNBOUNDED PRECEDING`: the start of the partition.
    UnboundedPreceding,
    /// `N PRECEDING`: `N` units before the current row.
    Preceding(u64),
    /// `CURRENT ROW`.
    CurrentRow,
    /// `N FOLLOWING`: `N` units after the current row.
    Following(u64),
    /// `UNBOUNDED FOLLOWING`: the end of the partition.
    UnboundedFollowing,
}
330
331/// Information about a single window frame function call.
332#[derive(Debug, Clone, PartialEq, Eq)]
333pub struct WindowFrameInfo {
334    /// Type of aggregate function
335    pub function_type: WindowFrameFunction,
336    /// Column being aggregated (or "*" for COUNT(*))
337    pub column: String,
338    /// Frame unit type (ROWS or RANGE)
339    pub units: FrameUnits,
340    /// Start bound of the frame
341    pub start_bound: FrameBound,
342    /// End bound of the frame
343    pub end_bound: FrameBound,
344    /// Output alias (AS name)
345    pub alias: Option<String>,
346}
347
348/// Result of analyzing window frame functions in a query.
349#[derive(Debug, Clone, PartialEq, Eq)]
350pub struct WindowFrameAnalysis {
351    /// Window frame functions found in the query
352    pub functions: Vec<WindowFrameInfo>,
353    /// PARTITION BY columns from the OVER clause
354    pub partition_columns: Vec<String>,
355    /// ORDER BY columns from the OVER clause
356    pub order_columns: Vec<String>,
357}
358
359impl WindowFrameAnalysis {
360    /// Returns true if any frame uses FOLLOWING bounds.
361    #[must_use]
362    pub fn has_following(&self) -> bool {
363        self.functions.iter().any(|f| {
364            matches!(
365                f.end_bound,
366                FrameBound::Following(_) | FrameBound::UnboundedFollowing
367            ) || matches!(
368                f.start_bound,
369                FrameBound::Following(_) | FrameBound::UnboundedFollowing
370            )
371        })
372    }
373
374    /// Returns the maximum preceding offset across all functions.
375    #[must_use]
376    pub fn max_preceding(&self) -> u64 {
377        self.functions
378            .iter()
379            .filter_map(|f| match &f.start_bound {
380                FrameBound::Preceding(n) => Some(*n),
381                _ => None,
382            })
383            .max()
384            .unwrap_or(0)
385    }
386}
387
388/// Analyzes a SQL statement for window frame aggregate functions.
389///
390/// Walks SELECT items looking for aggregate functions (AVG, SUM, MIN, MAX,
391/// COUNT, FIRST_VALUE, LAST_VALUE) with OVER clauses that contain explicit
392/// ROWS/RANGE frame specifications. Returns `None` if no such functions
393/// are found.
394///
395/// This is distinct from `analyze_analytic_functions()` which handles
396/// per-row offset functions (LAG/LEAD). Window frame functions compute
397/// aggregates over a sliding frame of rows.
398#[must_use]
399pub fn analyze_window_frames(stmt: &Statement) -> Option<WindowFrameAnalysis> {
400    let Statement::Query(query) = stmt else {
401        return None;
402    };
403
404    let SetExpr::Select(select) = query.body.as_ref() else {
405        return None;
406    };
407
408    let mut functions = Vec::new();
409    let mut partition_columns = Vec::new();
410    let mut order_columns = Vec::new();
411    let mut first_window = true;
412
413    for item in &select.projection {
414        let (expr, alias) = match item {
415            SelectItem::UnnamedExpr(expr) => (expr, None),
416            SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
417            _ => continue,
418        };
419
420        if let Some(info) = extract_window_frame_function(expr, alias, &mut |spec| {
421            if first_window {
422                partition_columns = spec
423                    .partition_by
424                    .iter()
425                    .filter_map(extract_column_name)
426                    .collect();
427                order_columns = spec
428                    .order_by
429                    .iter()
430                    .filter_map(|ob| extract_column_name(&ob.expr))
431                    .collect();
432                first_window = false;
433            }
434        }) {
435            functions.push(info);
436        }
437    }
438
439    if functions.is_empty() {
440        return None;
441    }
442
443    Some(WindowFrameAnalysis {
444        functions,
445        partition_columns,
446        order_columns,
447    })
448}
449
450/// Extracts a window frame aggregate function from an expression.
451fn extract_window_frame_function(
452    expr: &Expr,
453    alias: Option<String>,
454    on_window_spec: &mut dyn FnMut(&sqlparser::ast::WindowSpec),
455) -> Option<WindowFrameInfo> {
456    let Expr::Function(func) = expr else {
457        return None;
458    };
459
460    let name = func.name.to_string().to_uppercase();
461    let function_type = match name.as_str() {
462        "AVG" => WindowFrameFunction::Avg,
463        "SUM" => WindowFrameFunction::Sum,
464        "MIN" => WindowFrameFunction::Min,
465        "MAX" => WindowFrameFunction::Max,
466        "COUNT" => WindowFrameFunction::Count,
467        "FIRST_VALUE" => WindowFrameFunction::FirstValue,
468        "LAST_VALUE" => WindowFrameFunction::LastValue,
469        "CORR" => WindowFrameFunction::Corr,
470        "COVAR_SAMP" | "COVAR" => WindowFrameFunction::CovarSamp,
471        "COVAR_POP" => WindowFrameFunction::CovarPop,
472        _ => return None,
473    };
474
475    // Must have an OVER clause with an explicit window frame
476    let window_type = func.over.as_ref()?;
477    let spec = match window_type {
478        sqlparser::ast::WindowType::WindowSpec(spec) => spec,
479        sqlparser::ast::WindowType::NamedWindow(_) => return None,
480    };
481
482    // Only match functions with explicit ROWS/RANGE frame specs
483    let frame = spec.window_frame.as_ref()?;
484
485    on_window_spec(spec);
486
487    let units = match frame.units {
488        sqlparser::ast::WindowFrameUnits::Rows => FrameUnits::Rows,
489        sqlparser::ast::WindowFrameUnits::Range => FrameUnits::Range,
490        sqlparser::ast::WindowFrameUnits::Groups => return None,
491    };
492
493    let start_bound = convert_frame_bound(&frame.start_bound);
494    let end_bound = frame
495        .end_bound
496        .as_ref()
497        .map_or(FrameBound::CurrentRow, convert_frame_bound);
498
499    // Extract the column argument
500    let args = extract_function_args(func);
501    let column = args.first().cloned().unwrap_or_else(|| "*".to_string());
502
503    Some(WindowFrameInfo {
504        function_type,
505        column,
506        units,
507        start_bound,
508        end_bound,
509        alias,
510    })
511}
512
513/// Converts a sqlparser `WindowFrameBound` to our `FrameBound`.
514fn convert_frame_bound(bound: &sqlparser::ast::WindowFrameBound) -> FrameBound {
515    match bound {
516        sqlparser::ast::WindowFrameBound::CurrentRow => FrameBound::CurrentRow,
517        sqlparser::ast::WindowFrameBound::Preceding(None) => FrameBound::UnboundedPreceding,
518        sqlparser::ast::WindowFrameBound::Preceding(Some(expr)) => {
519            let n = expr_to_u64(expr).unwrap_or(0);
520            FrameBound::Preceding(n)
521        }
522        sqlparser::ast::WindowFrameBound::Following(None) => FrameBound::UnboundedFollowing,
523        sqlparser::ast::WindowFrameBound::Following(Some(expr)) => {
524            let n = expr_to_u64(expr).unwrap_or(0);
525            FrameBound::Following(n)
526        }
527    }
528}
529
530/// Extracts a u64 value from an expression (numeric literal).
531fn expr_to_u64(expr: &Expr) -> Option<u64> {
532    match expr {
533        Expr::Value(value_with_span) => match &value_with_span.value {
534            sqlparser::ast::Value::Number(n, _) => n.parse().ok(),
535            _ => None,
536        },
537        _ => None,
538    }
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_stmt(sql: &str) -> Statement {
        Parser::parse_sql(&GenericDialect {}, sql)
            .unwrap()
            .into_iter()
            .next()
            .expect("SQL should contain at least one statement")
    }

    /// Runs the analytic-function analysis and requires a match.
    fn analytic(sql: &str) -> AnalyticWindowAnalysis {
        analyze_analytic_functions(&parse_stmt(sql)).expect("expected analytic functions")
    }

    /// Runs the window-frame analysis and requires a match.
    fn frames(sql: &str) -> WindowFrameAnalysis {
        analyze_window_frames(&parse_stmt(sql)).expect("expected window frame functions")
    }

    // --- Analytic function tests ---

    #[test]
    fn test_lag_basic() {
        let sql = "SELECT price, LAG(price) OVER (ORDER BY ts) AS prev_price FROM trades";
        let analysis = analytic(sql);
        assert_eq!(analysis.functions.len(), 1);
        let lag = &analysis.functions[0];
        assert_eq!(lag.function_type, AnalyticFunctionType::Lag);
        assert_eq!(lag.column, "price");
        assert_eq!(lag.offset, 1);
        assert_eq!(lag.alias.as_deref(), Some("prev_price"));
    }

    #[test]
    fn test_lag_with_offset() {
        let analysis = analytic("SELECT LAG(price, 3) OVER (ORDER BY ts) AS prev3 FROM trades");
        assert_eq!(analysis.functions[0].offset, 3);
    }

    #[test]
    fn test_lag_with_default() {
        let analysis = analytic("SELECT LAG(price, 1, 0) OVER (ORDER BY ts) AS prev FROM trades");
        let lag = &analysis.functions[0];
        assert_eq!(lag.offset, 1);
        assert_eq!(lag.default_value.as_deref(), Some("0"));
    }

    #[test]
    fn test_lead_basic() {
        let analysis = analytic("SELECT LEAD(price) OVER (ORDER BY ts) AS next_price FROM trades");
        assert_eq!(analysis.functions[0].function_type, AnalyticFunctionType::Lead);
        assert!(analysis.has_lookahead());
    }

    #[test]
    fn test_lead_with_offset_and_default() {
        let sql = "SELECT LEAD(price, 2, -1) OVER (ORDER BY ts) AS next2 FROM trades";
        let analysis = analytic(sql);
        let lead = &analysis.functions[0];
        assert_eq!(lead.offset, 2);
        assert_eq!(lead.default_value.as_deref(), Some("-1"));
    }

    #[test]
    fn test_partition_by_extraction() {
        let sql = "SELECT symbol, LAG(price) OVER (PARTITION BY symbol ORDER BY ts) FROM trades";
        let analysis = analytic(sql);
        assert_eq!(analysis.partition_columns, ["symbol"]);
        assert_eq!(analysis.order_columns, ["ts"]);
    }

    #[test]
    fn test_multiple_analytic_functions() {
        let sql = "SELECT
            LAG(price) OVER (ORDER BY ts) AS prev,
            LEAD(price) OVER (ORDER BY ts) AS next
            FROM trades";
        let analysis = analytic(sql);
        let kinds: Vec<AnalyticFunctionType> =
            analysis.functions.iter().map(|f| f.function_type).collect();
        assert_eq!(kinds, [AnalyticFunctionType::Lag, AnalyticFunctionType::Lead]);
    }

    #[test]
    fn test_first_value() {
        let sql =
            "SELECT FIRST_VALUE(price) OVER (PARTITION BY symbol ORDER BY ts) AS first FROM trades";
        let analysis = analytic(sql);
        let first = &analysis.functions[0];
        assert_eq!(first.function_type, AnalyticFunctionType::FirstValue);
        assert_eq!(first.column, "price");
    }

    #[test]
    fn test_last_value() {
        let analysis = analytic("SELECT LAST_VALUE(price) OVER (ORDER BY ts) FROM trades");
        assert_eq!(analysis.functions[0].function_type, AnalyticFunctionType::LastValue);
    }

    #[test]
    fn test_no_analytic_functions() {
        let stmt = parse_stmt("SELECT price, volume FROM trades WHERE price > 100");
        assert!(analyze_analytic_functions(&stmt).is_none());
    }

    #[test]
    fn test_max_offset() {
        let sql = "SELECT
            LAG(price, 1) OVER (ORDER BY ts) AS p1,
            LAG(price, 5) OVER (ORDER BY ts) AS p5,
            LEAD(price, 3) OVER (ORDER BY ts) AS n3
            FROM trades";
        assert_eq!(analytic(sql).max_offset(), 5);
    }

    // --- Window Frame tests ---

    #[test]
    fn test_frame_rows_preceding_current() {
        let sql = "SELECT AVG(price) OVER (ORDER BY ts \
                    ROWS BETWEEN 9 PRECEDING AND CURRENT ROW) AS ma \
                    FROM trades";
        let analysis = frames(sql);
        assert_eq!(analysis.functions.len(), 1);
        let avg = &analysis.functions[0];
        assert_eq!(avg.function_type, WindowFrameFunction::Avg);
        assert_eq!(avg.column, "price");
        assert_eq!(avg.units, FrameUnits::Rows);
        assert_eq!(avg.start_bound, FrameBound::Preceding(9));
        assert_eq!(avg.end_bound, FrameBound::CurrentRow);
        assert_eq!(avg.alias.as_deref(), Some("ma"));
    }

    #[test]
    fn test_frame_rows_preceding_following() {
        let sql = "SELECT SUM(amount) OVER (ORDER BY id \
                    ROWS BETWEEN 5 PRECEDING AND 3 FOLLOWING) AS total \
                    FROM orders";
        let analysis = frames(sql);
        let sum = &analysis.functions[0];
        assert_eq!(sum.function_type, WindowFrameFunction::Sum);
        assert_eq!(sum.start_bound, FrameBound::Preceding(5));
        assert_eq!(sum.end_bound, FrameBound::Following(3));
    }

    #[test]
    fn test_frame_unbounded_preceding_running_sum() {
        let sql = "SELECT SUM(amount) OVER (ORDER BY id \
                    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running \
                    FROM orders";
        let analysis = frames(sql);
        let running = &analysis.functions[0];
        assert_eq!(running.start_bound, FrameBound::UnboundedPreceding);
        assert_eq!(running.end_bound, FrameBound::CurrentRow);
    }

    #[test]
    fn test_frame_range_units() {
        let sql = "SELECT AVG(price) OVER (ORDER BY ts \
                    RANGE BETWEEN 10 PRECEDING AND CURRENT ROW) AS ra \
                    FROM trades";
        let analysis = frames(sql);
        let ranged = &analysis.functions[0];
        assert_eq!(ranged.units, FrameUnits::Range);
        assert_eq!(ranged.start_bound, FrameBound::Preceding(10));
    }

    #[test]
    fn test_frame_partition_order_columns() {
        let sql = "SELECT AVG(price) OVER (PARTITION BY symbol ORDER BY ts \
                    ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) AS ma \
                    FROM trades";
        let analysis = frames(sql);
        assert_eq!(analysis.partition_columns, ["symbol"]);
        assert_eq!(analysis.order_columns, ["ts"]);
    }

    #[test]
    fn test_frame_multiple_functions() {
        let sql = "SELECT \
                    AVG(price) OVER (ORDER BY ts ROWS BETWEEN 9 PRECEDING AND CURRENT ROW) AS ma, \
                    SUM(volume) OVER (ORDER BY ts ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) AS sv \
                    FROM trades";
        let analysis = frames(sql);
        let summary: Vec<(WindowFrameFunction, &str)> = analysis
            .functions
            .iter()
            .map(|f| (f.function_type, f.column.as_str()))
            .collect();
        assert_eq!(
            summary,
            [
                (WindowFrameFunction::Avg, "price"),
                (WindowFrameFunction::Sum, "volume"),
            ]
        );
    }

    #[test]
    fn test_frame_no_frame_returns_none() {
        // AVG with OVER but no explicit frame → None
        let stmt = parse_stmt("SELECT AVG(price) OVER (ORDER BY ts) FROM trades");
        assert!(analyze_window_frames(&stmt).is_none());
    }

    #[test]
    fn test_frame_unbounded_following() {
        let sql = "SELECT SUM(amount) OVER (ORDER BY id \
                    ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS rest \
                    FROM orders";
        let analysis = frames(sql);
        let rest = &analysis.functions[0];
        assert_eq!(rest.start_bound, FrameBound::CurrentRow);
        assert_eq!(rest.end_bound, FrameBound::UnboundedFollowing);
        assert!(analysis.has_following());
    }

    #[test]
    fn test_frame_all_function_types() {
        let sql = "SELECT \
                    AVG(a) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f1, \
                    SUM(b) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f2, \
                    MIN(c) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f3, \
                    MAX(d) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f4, \
                    COUNT(e) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS f5 \
                    FROM t";
        let analysis = frames(sql);
        let kinds: Vec<WindowFrameFunction> =
            analysis.functions.iter().map(|f| f.function_type).collect();
        assert_eq!(
            kinds,
            [
                WindowFrameFunction::Avg,
                WindowFrameFunction::Sum,
                WindowFrameFunction::Min,
                WindowFrameFunction::Max,
                WindowFrameFunction::Count,
            ]
        );
    }

    #[test]
    fn test_frame_corr_bivariate() {
        let sql = "SELECT CORR(price, sentiment) OVER (ORDER BY bucket \
                    ROWS 30 PRECEDING) AS c FROM joined";
        let analysis = frames(sql);
        let corr = &analysis.functions[0];
        assert_eq!(corr.function_type, WindowFrameFunction::Corr);
        assert_eq!(corr.start_bound, FrameBound::Preceding(30));
        assert_eq!(analysis.order_columns, ["bucket"]);
    }

    #[test]
    fn test_frame_max_preceding_helper() {
        let sql = "SELECT \
                    AVG(a) OVER (ORDER BY id ROWS BETWEEN 3 PRECEDING AND CURRENT ROW) AS f1, \
                    SUM(b) OVER (ORDER BY id ROWS BETWEEN 10 PRECEDING AND CURRENT ROW) AS f2 \
                    FROM t";
        let analysis = frames(sql);
        assert_eq!(analysis.max_preceding(), 10);
        assert!(!analysis.has_following());
    }
}