Skip to main content

laminar_sql/parser/
interval_rewriter.rs

1//! Interval arithmetic rewriter for BIGINT timestamp columns.
2//!
3//! LaminarDB uses BIGINT millisecond timestamps for event time. DataFusion
4//! cannot natively evaluate `Int64 ± INTERVAL`, so this module rewrites
5//! INTERVAL expressions in arithmetic operations to equivalent millisecond
6//! integer literals before the SQL reaches DataFusion.
7//!
8//! # Example
9//!
10//! ```sql
11//! -- Before rewrite:
12//! SELECT * FROM trades t
13//! INNER JOIN orders o ON t.symbol = o.symbol
14//!   AND o.ts BETWEEN t.ts - INTERVAL '10' SECOND AND t.ts + INTERVAL '10' SECOND
15//!
16//! -- After rewrite:
17//! SELECT * FROM trades t
18//! INNER JOIN orders o ON t.symbol = o.symbol
19//!   AND o.ts BETWEEN t.ts - 10000 AND t.ts + 10000
20//! ```
21
22use sqlparser::ast::{
23    BinaryOperator, DateTimeField, Expr, JoinConstraint, JoinOperator, Query, Select, SelectItem,
24    SetExpr, Statement, Value,
25};
26
27/// Convert an [`Interval`](sqlparser::ast::Interval) to its equivalent milliseconds value.
28///
29/// Returns `None` if the interval cannot be converted (unsupported unit or
30/// non-numeric value).
31fn interval_to_millis(interval: &sqlparser::ast::Interval) -> Option<i64> {
32    let value = extract_interval_numeric(&interval.value)?;
33    let unit = interval
34        .leading_field
35        .clone()
36        .unwrap_or(DateTimeField::Second);
37
38    let millis = match unit {
39        DateTimeField::Millisecond | DateTimeField::Milliseconds => value,
40        DateTimeField::Second | DateTimeField::Seconds => value.checked_mul(1_000)?,
41        DateTimeField::Minute | DateTimeField::Minutes => value.checked_mul(60_000)?,
42        DateTimeField::Hour | DateTimeField::Hours => value.checked_mul(3_600_000)?,
43        DateTimeField::Day | DateTimeField::Days => value.checked_mul(86_400_000)?,
44        _ => return None,
45    };
46
47    Some(millis)
48}
49
50/// Extract a numeric value from an interval's value expression.
51fn extract_interval_numeric(expr: &Expr) -> Option<i64> {
52    match expr {
53        Expr::Value(vws) => match &vws.value {
54            Value::Number(n, _) => n.parse().ok(),
55            Value::SingleQuotedString(s) => s.split_whitespace().next()?.parse().ok(),
56            _ => None,
57        },
58        _ => None,
59    }
60}
61
62/// Create a numeric literal `Expr` from a milliseconds value.
63///
64/// Uses sqlparser's own parser to construct the AST node, ensuring
65/// correct internal representation.
66fn make_number_expr(n: i64) -> Option<Expr> {
67    use sqlparser::dialect::GenericDialect;
68    let s = n.to_string();
69    let dialect = GenericDialect {};
70    sqlparser::parser::Parser::new(&dialect)
71        .try_with_sql(&s)
72        .ok()?
73        .parse_expr()
74        .ok()
75}
76
77// ---------------------------------------------------------------------------
78// Expression rewriter
79// ---------------------------------------------------------------------------
80
81/// Rewrite INTERVAL arithmetic in an expression tree, in place.
82///
83/// Converts patterns like `col ± INTERVAL 'N' UNIT` to `col ± <millis>` so
84/// that DataFusion can evaluate the expression when the column is `Int64`.
85pub fn rewrite_expr_mut(expr: &mut Expr) {
86    if let Expr::BinaryOp { left, op, right } = expr {
87        let is_add_sub = matches!(*op, BinaryOperator::Plus | BinaryOperator::Minus);
88
89        if is_add_sub {
90            // Check right side for INTERVAL: col ± INTERVAL → col ± millis
91            let right_ms: Option<i64> = if let Expr::Interval(interval) = right.as_ref() {
92                interval_to_millis(interval)
93            } else {
94                None
95            };
96
97            if let Some(ms) = right_ms {
98                if let Some(num_expr) = make_number_expr(ms) {
99                    **right = num_expr;
100                    rewrite_expr_mut(left);
101                    return;
102                }
103            }
104
105            // Check left side: INTERVAL + col → millis + col (only addition)
106            if matches!(*op, BinaryOperator::Plus) {
107                let left_ms: Option<i64> = if let Expr::Interval(interval) = left.as_ref() {
108                    interval_to_millis(interval)
109                } else {
110                    None
111                };
112
113                if let Some(ms) = left_ms {
114                    if let Some(num_expr) = make_number_expr(ms) {
115                        **left = num_expr;
116                        rewrite_expr_mut(right);
117                        return;
118                    }
119                }
120            }
121        }
122
123        // Default: recurse into both sides
124        rewrite_expr_mut(left);
125        rewrite_expr_mut(right);
126        return;
127    }
128
129    // Handle other expression types that can contain sub-expressions
130    match expr {
131        Expr::Between {
132            expr: e, low, high, ..
133        } => {
134            rewrite_expr_mut(e);
135            rewrite_expr_mut(low);
136            rewrite_expr_mut(high);
137        }
138        Expr::InList { expr: e, list, .. } => {
139            rewrite_expr_mut(e);
140            for item in list {
141                rewrite_expr_mut(item);
142            }
143        }
144        Expr::Nested(inner)
145        | Expr::UnaryOp { expr: inner, .. }
146        | Expr::Cast { expr: inner, .. }
147        | Expr::IsNull(inner)
148        | Expr::IsNotNull(inner)
149        | Expr::IsFalse(inner)
150        | Expr::IsNotFalse(inner)
151        | Expr::IsTrue(inner)
152        | Expr::IsNotTrue(inner) => rewrite_expr_mut(inner),
153        _ => {}
154    }
155}
156
157// ---------------------------------------------------------------------------
158// Statement / query walker
159// ---------------------------------------------------------------------------
160
161/// Rewrite all INTERVAL arithmetic in a SQL [`Statement`].
162///
163/// Walks the full AST and converts `expr ± INTERVAL 'N' UNIT` patterns
164/// to `expr ± <milliseconds>` for BIGINT timestamp compatibility.
165pub fn rewrite_interval_arithmetic(stmt: &mut Statement) {
166    if let Statement::Query(query) = stmt {
167        rewrite_query(query);
168    }
169}
170
171fn rewrite_query(query: &mut Query) {
172    rewrite_set_expr(&mut query.body);
173}
174
175fn rewrite_set_expr(body: &mut SetExpr) {
176    match body {
177        SetExpr::Select(select) => rewrite_select(select),
178        SetExpr::Query(query) => rewrite_query(query),
179        SetExpr::SetOperation { left, right, .. } => {
180            rewrite_set_expr(left);
181            rewrite_set_expr(right);
182        }
183        _ => {}
184    }
185}
186
187fn rewrite_select(select: &mut Select) {
188    // Rewrite SELECT projection expressions
189    for item in &mut select.projection {
190        match item {
191            SelectItem::UnnamedExpr(ref mut expr)
192            | SelectItem::ExprWithAlias { ref mut expr, .. } => {
193                rewrite_expr_mut(expr);
194            }
195            _ => {}
196        }
197    }
198
199    // Rewrite WHERE clause
200    if let Some(ref mut where_expr) = select.selection {
201        rewrite_expr_mut(where_expr);
202    }
203
204    // Rewrite HAVING clause
205    if let Some(ref mut having) = select.having {
206        rewrite_expr_mut(having);
207    }
208
209    // Rewrite JOIN ON conditions
210    for table_with_joins in &mut select.from {
211        for join in &mut table_with_joins.joins {
212            rewrite_join_operator(&mut join.join_operator);
213        }
214    }
215}
216
217fn rewrite_join_operator(jo: &mut JoinOperator) {
218    let (JoinOperator::Join(constraint)
219    | JoinOperator::Inner(constraint)
220    | JoinOperator::Left(constraint)
221    | JoinOperator::LeftOuter(constraint)
222    | JoinOperator::Right(constraint)
223    | JoinOperator::RightOuter(constraint)
224    | JoinOperator::FullOuter(constraint)
225    | JoinOperator::StraightJoin(constraint)
226    | JoinOperator::LeftSemi(constraint)
227    | JoinOperator::RightSemi(constraint)
228    | JoinOperator::LeftAnti(constraint)
229    | JoinOperator::RightAnti(constraint)
230    | JoinOperator::Semi(constraint)
231    | JoinOperator::Anti(constraint)) = jo
232    else {
233        return;
234    };
235    if let JoinConstraint::On(expr) = constraint {
236        rewrite_expr_mut(expr);
237    }
238}
239
240// ---------------------------------------------------------------------------
241// Tests
242// ---------------------------------------------------------------------------
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::dialect::LaminarDialect;

    /// Helper: parse SQL, rewrite intervals, return the rewritten SQL string.
    fn rewrite(sql: &str) -> String {
        let dialect = LaminarDialect::default();
        let mut stmts = sqlparser::parser::Parser::parse_sql(&dialect, sql).unwrap();
        assert!(!stmts.is_empty());
        rewrite_interval_arithmetic(&mut stmts[0]);
        stmts[0].to_string()
    }

    // -- Basic arithmetic --
    //
    // Assertions include the trailing ` FROM` so that a wrongly scaled value
    // (e.g. `ts - 100000` instead of `ts - 10000`) cannot satisfy them.

    #[test]
    fn test_subtract_interval_seconds() {
        let result = rewrite("SELECT ts - INTERVAL '10' SECOND FROM events");
        assert!(result.contains("ts - 10000 FROM"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    #[test]
    fn test_add_interval_seconds() {
        let result = rewrite("SELECT ts + INTERVAL '5' SECOND FROM events");
        assert!(result.contains("ts + 5000 FROM"), "got: {result}");
    }

    #[test]
    fn test_interval_minutes() {
        let result = rewrite("SELECT ts - INTERVAL '2' MINUTE FROM events");
        assert!(result.contains("ts - 120000 FROM"), "got: {result}");
    }

    #[test]
    fn test_interval_hours() {
        let result = rewrite("SELECT ts + INTERVAL '1' HOUR FROM events");
        assert!(result.contains("ts + 3600000 FROM"), "got: {result}");
    }

    #[test]
    fn test_interval_days() {
        let result = rewrite("SELECT ts - INTERVAL '1' DAY FROM events");
        assert!(result.contains("ts - 86400000 FROM"), "got: {result}");
    }

    #[test]
    fn test_interval_milliseconds() {
        let result = rewrite("SELECT ts - INTERVAL '100' MILLISECOND FROM events");
        // `contains("ts - 100")` alone would also accept `ts - 100000`.
        assert!(result.contains("ts - 100 FROM"), "got: {result}");
    }

    // -- WHERE clause --

    #[test]
    fn test_where_clause_interval() {
        let result = rewrite("SELECT * FROM events WHERE ts > ts2 - INTERVAL '10' SECOND");
        assert!(result.contains("ts2 - 10000"), "got: {result}");
    }

    // -- BETWEEN (from issue example) --

    #[test]
    fn test_between_interval() {
        let result = rewrite(
            "SELECT * FROM trades t \
             INNER JOIN orders o ON t.symbol = o.symbol \
             AND o.ts BETWEEN t.ts - INTERVAL '10' SECOND AND t.ts + INTERVAL '10' SECOND",
        );
        assert!(result.contains("t.ts - 10000"), "got: {result}");
        assert!(result.contains("t.ts + 10000"), "got: {result}");
        assert!(!result.contains("INTERVAL"), "got: {result}");
    }

    // -- JOIN ON condition --

    #[test]
    fn test_join_on_interval() {
        let result = rewrite(
            "SELECT * FROM a JOIN b ON a.id = b.id \
             AND b.ts BETWEEN a.ts - INTERVAL '5' MINUTE AND a.ts + INTERVAL '5' MINUTE",
        );
        assert!(result.contains("a.ts - 300000"), "got: {result}");
        assert!(result.contains("a.ts + 300000"), "got: {result}");
    }

    // -- Nested expressions --

    #[test]
    fn test_nested_parens() {
        let result = rewrite("SELECT * FROM e WHERE (ts - INTERVAL '1' SECOND) > 0");
        assert!(result.contains("(ts - 1000) > 0"), "got: {result}");
    }

    // -- Left-side INTERVAL (INTERVAL + col) --

    #[test]
    fn test_interval_on_left_side() {
        let result = rewrite("SELECT INTERVAL '10' SECOND + ts FROM events");
        assert!(result.contains("10000 + ts"), "got: {result}");
    }

    // -- No-op cases (should not be modified) --

    #[test]
    fn test_no_interval_unchanged() {
        let result = rewrite("SELECT ts - 10000 FROM events");
        assert!(result.contains("ts - 10000"), "got: {result}");
    }

    #[test]
    fn test_interval_default_unit_is_second() {
        // Uses an explicit SECOND qualifier, so this pins the seconds
        // conversion. An interval written with no unit at all is also
        // converted as seconds by the rewriter.
        let result = rewrite("SELECT ts - INTERVAL '5' SECOND FROM events");
        assert!(result.contains("ts - 5000 FROM"), "got: {result}");
    }

    // -- Multiple intervals in same query --

    #[test]
    fn test_multiple_intervals() {
        let result = rewrite(
            "SELECT * FROM events \
             WHERE ts > start_ts - INTERVAL '10' SECOND \
             AND ts < end_ts + INTERVAL '30' SECOND",
        );
        assert!(result.contains("start_ts - 10000"), "got: {result}");
        assert!(result.contains("end_ts + 30000"), "got: {result}");
    }

    // -- HAVING clause --

    #[test]
    fn test_having_clause_interval() {
        let result = rewrite(
            "SELECT symbol, COUNT(*) FROM trades \
             GROUP BY symbol \
             HAVING MAX(ts) - MIN(ts) > INTERVAL '1' HOUR",
        );
        // The INTERVAL is an operand of `>`, a comparison rather than `+`/`-`,
        // so it must be left exactly as written.
        assert!(result.contains("HAVING"), "got: {result}");
        assert!(result.contains("> INTERVAL '1' HOUR"), "got: {result}");
    }
}