Skip to main content

laminar_core/compiler/
fold.rs

1//! Constant folding pre-pass for expression compilation.
2//!
3//! [`fold_constants`] rewrites an expression tree, evaluating constant
4//! sub-expressions at compile time before Cranelift code generation.
5
6use datafusion_common::ScalarValue;
7use datafusion_expr::{BinaryExpr, Expr, Operator};
8
9/// Folds constant sub-expressions in a `DataFusion` [`Expr`] tree.
10///
11/// Handles:
12/// - Arithmetic on literal pairs (`Literal op Literal`)
13/// - Boolean identity rules (`true AND x` → `x`, `false OR x` → `x`, etc.)
14/// - Recursive descent into `BinaryExpr`, `Not` children
15///
16/// Returns a (possibly simplified) clone of the input expression.
17#[must_use]
18pub fn fold_constants(expr: &Expr) -> Expr {
19    match expr {
20        Expr::BinaryExpr(binary) => fold_binary(binary),
21        Expr::Not(inner) => fold_not(inner),
22        _ => expr.clone(),
23    }
24}
25
26/// Folds a binary expression, first recursing into children.
27fn fold_binary(binary: &BinaryExpr) -> Expr {
28    let left = fold_constants(&binary.left);
29    let right = fold_constants(&binary.right);
30
31    // Try boolean identity rules first.
32    if let Some(simplified) = try_boolean_identity(&left, &right, binary.op) {
33        return simplified;
34    }
35
36    // Try literal-literal folding.
37    if let (Expr::Literal(lv, _), Expr::Literal(rv, _)) = (&left, &right) {
38        if let Some(result) = fold_literal_pair(lv, rv, binary.op) {
39            return Expr::Literal(result, None);
40        }
41    }
42
43    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), binary.op, Box::new(right)))
44}
45
46/// Folds `NOT` expressions: `NOT true` → `false`, `NOT false` → `true`.
47fn fold_not(inner: &Expr) -> Expr {
48    let folded = fold_constants(inner);
49    match &folded {
50        Expr::Literal(ScalarValue::Boolean(Some(b)), _) => {
51            Expr::Literal(ScalarValue::Boolean(Some(!b)), None)
52        }
53        _ => Expr::Not(Box::new(folded)),
54    }
55}
56
57/// Applies boolean identity simplifications.
58///
59/// - `true AND x` → `x`
60/// - `false AND x` → `false`
61/// - `x AND true` → `x`
62/// - `x AND false` → `false`
63/// - `true OR x` → `true`
64/// - `false OR x` → `x`
65/// - `x OR true` → `true`
66/// - `x OR false` → `x`
67fn try_boolean_identity(left: &Expr, right: &Expr, op: Operator) -> Option<Expr> {
68    match op {
69        Operator::And => {
70            if let Expr::Literal(ScalarValue::Boolean(Some(b)), _) = left {
71                return Some(if *b { right.clone() } else { left.clone() });
72            }
73            if let Expr::Literal(ScalarValue::Boolean(Some(b)), _) = right {
74                return Some(if *b { left.clone() } else { right.clone() });
75            }
76            None
77        }
78        Operator::Or => {
79            if let Expr::Literal(ScalarValue::Boolean(Some(b)), _) = left {
80                return Some(if *b { left.clone() } else { right.clone() });
81            }
82            if let Expr::Literal(ScalarValue::Boolean(Some(b)), _) = right {
83                return Some(if *b { right.clone() } else { left.clone() });
84            }
85            None
86        }
87        _ => None,
88    }
89}
90
91/// Evaluates `lhs op rhs` for two scalar literals.
92fn fold_literal_pair(lhs: &ScalarValue, rhs: &ScalarValue, op: Operator) -> Option<ScalarValue> {
93    // i64 arithmetic
94    if let (ScalarValue::Int64(Some(l)), ScalarValue::Int64(Some(r))) = (lhs, rhs) {
95        return fold_i64(*l, *r, op);
96    }
97    // f64 arithmetic
98    if let (ScalarValue::Float64(Some(l)), ScalarValue::Float64(Some(r))) = (lhs, rhs) {
99        return fold_f64(*l, *r, op);
100    }
101    // Boolean logic
102    if let (ScalarValue::Boolean(Some(l)), ScalarValue::Boolean(Some(r))) = (lhs, rhs) {
103        return fold_bool(*l, *r, op);
104    }
105    None
106}
107
108fn fold_i64(l: i64, r: i64, op: Operator) -> Option<ScalarValue> {
109    match op {
110        Operator::Plus => l.checked_add(r).map(|v| ScalarValue::Int64(Some(v))),
111        Operator::Minus => l.checked_sub(r).map(|v| ScalarValue::Int64(Some(v))),
112        Operator::Multiply => l.checked_mul(r).map(|v| ScalarValue::Int64(Some(v))),
113        Operator::Divide if r != 0 => l.checked_div(r).map(|v| ScalarValue::Int64(Some(v))),
114        Operator::Modulo if r != 0 => l.checked_rem(r).map(|v| ScalarValue::Int64(Some(v))),
115        Operator::Eq => Some(ScalarValue::Boolean(Some(l == r))),
116        Operator::NotEq => Some(ScalarValue::Boolean(Some(l != r))),
117        Operator::Lt => Some(ScalarValue::Boolean(Some(l < r))),
118        Operator::LtEq => Some(ScalarValue::Boolean(Some(l <= r))),
119        Operator::Gt => Some(ScalarValue::Boolean(Some(l > r))),
120        Operator::GtEq => Some(ScalarValue::Boolean(Some(l >= r))),
121        _ => None,
122    }
123}
124
125fn fold_f64(l: f64, r: f64, op: Operator) -> Option<ScalarValue> {
126    match op {
127        Operator::Plus => Some(ScalarValue::Float64(Some(l + r))),
128        Operator::Minus => Some(ScalarValue::Float64(Some(l - r))),
129        Operator::Multiply => Some(ScalarValue::Float64(Some(l * r))),
130        Operator::Divide if r != 0.0 => Some(ScalarValue::Float64(Some(l / r))),
131        _ => None,
132    }
133}
134
135fn fold_bool(l: bool, r: bool, op: Operator) -> Option<ScalarValue> {
136    match op {
137        Operator::And => Some(ScalarValue::Boolean(Some(l && r))),
138        Operator::Or => Some(ScalarValue::Boolean(Some(l || r))),
139        Operator::Eq => Some(ScalarValue::Boolean(Some(l == r))),
140        Operator::NotEq => Some(ScalarValue::Boolean(Some(l != r))),
141        _ => None,
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::col;

    fn lit_i64(v: i64) -> Expr {
        Expr::Literal(ScalarValue::Int64(Some(v)), None)
    }

    fn lit_f64(v: f64) -> Expr {
        Expr::Literal(ScalarValue::Float64(Some(v)), None)
    }

    fn lit_bool(v: bool) -> Expr {
        Expr::Literal(ScalarValue::Boolean(Some(v)), None)
    }

    /// Builds `left op right` without any folding.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
    }

    #[test]
    fn fold_i64_addition() {
        let expr = binary(lit_i64(10), Operator::Plus, lit_i64(20));
        assert_eq!(fold_constants(&expr), lit_i64(30));
    }

    #[test]
    fn fold_i64_subtraction() {
        let expr = binary(lit_i64(50), Operator::Minus, lit_i64(8));
        assert_eq!(fold_constants(&expr), lit_i64(42));
    }

    #[test]
    fn fold_i64_multiplication() {
        let expr = binary(lit_i64(6), Operator::Multiply, lit_i64(7));
        assert_eq!(fold_constants(&expr), lit_i64(42));
    }

    #[test]
    fn fold_i64_modulo() {
        let expr = binary(lit_i64(17), Operator::Modulo, lit_i64(5));
        assert_eq!(fold_constants(&expr), lit_i64(2));
    }

    #[test]
    fn fold_i64_overflow_no_fold() {
        // Overflowing arithmetic must be left for the runtime, not wrapped or panicked on.
        for expr in [
            binary(lit_i64(i64::MAX), Operator::Plus, lit_i64(1)),
            binary(lit_i64(i64::MIN), Operator::Minus, lit_i64(1)),
            binary(lit_i64(i64::MAX), Operator::Multiply, lit_i64(2)),
            binary(lit_i64(i64::MIN), Operator::Divide, lit_i64(-1)),
            binary(lit_i64(i64::MIN), Operator::Modulo, lit_i64(-1)),
        ] {
            assert_eq!(fold_constants(&expr), expr);
        }
    }

    #[test]
    fn fold_f64_arithmetic() {
        let expr = binary(lit_f64(1.5), Operator::Plus, lit_f64(2.5));
        assert_eq!(fold_constants(&expr), lit_f64(4.0));
    }

    #[test]
    fn fold_nested_constants() {
        // (2 + 3) * 4 → 5 * 4 → 20
        let inner = binary(lit_i64(2), Operator::Plus, lit_i64(3));
        let expr = binary(inner, Operator::Multiply, lit_i64(4));
        assert_eq!(fold_constants(&expr), lit_i64(20));
    }

    #[test]
    fn fold_not_literal() {
        assert_eq!(
            fold_constants(&Expr::Not(Box::new(lit_bool(true)))),
            lit_bool(false)
        );
        assert_eq!(
            fold_constants(&Expr::Not(Box::new(lit_bool(false)))),
            lit_bool(true)
        );
    }

    #[test]
    fn fold_not_of_folded_comparison() {
        // NOT (1 < 2) → NOT true → false
        let cmp = binary(lit_i64(1), Operator::Lt, lit_i64(2));
        assert_eq!(fold_constants(&Expr::Not(Box::new(cmp))), lit_bool(false));
    }

    #[test]
    fn not_on_column_preserved() {
        let expr = Expr::Not(Box::new(col("x")));
        assert_eq!(fold_constants(&expr), expr);
    }

    #[test]
    fn fold_boolean_identity_and() {
        let x = col("x");
        // true AND x → x
        let expr = binary(lit_bool(true), Operator::And, x.clone());
        assert_eq!(fold_constants(&expr), x);

        // false AND x → false
        let expr = binary(lit_bool(false), Operator::And, col("x"));
        assert_eq!(fold_constants(&expr), lit_bool(false));
    }

    #[test]
    fn fold_boolean_identity_or() {
        let x = col("x");
        // false OR x → x
        let expr = binary(lit_bool(false), Operator::Or, x.clone());
        assert_eq!(fold_constants(&expr), x);

        // true OR x → true
        let expr = binary(lit_bool(true), Operator::Or, col("x"));
        assert_eq!(fold_constants(&expr), lit_bool(true));
    }

    #[test]
    fn fold_boolean_identity_right_literal() {
        let x = col("x");
        assert_eq!(
            fold_constants(&binary(x.clone(), Operator::And, lit_bool(true))),
            x
        );
        assert_eq!(
            fold_constants(&binary(x.clone(), Operator::And, lit_bool(false))),
            lit_bool(false)
        );
        assert_eq!(
            fold_constants(&binary(x.clone(), Operator::Or, lit_bool(true))),
            lit_bool(true)
        );
        assert_eq!(
            fold_constants(&binary(x.clone(), Operator::Or, lit_bool(false))),
            x
        );
    }

    #[test]
    fn folded_comparison_feeds_identity() {
        // (1 < 2) AND x → true AND x → x
        let cmp = binary(lit_i64(1), Operator::Lt, lit_i64(2));
        let x = col("x");
        assert_eq!(fold_constants(&binary(cmp, Operator::And, x.clone())), x);
    }

    #[test]
    fn fold_bool_equality() {
        let expr = binary(lit_bool(true), Operator::Eq, lit_bool(false));
        assert_eq!(fold_constants(&expr), lit_bool(false));
    }

    #[test]
    fn null_literals_not_folded() {
        let null_i64 = Expr::Literal(ScalarValue::Int64(None), None);
        let expr = binary(null_i64, Operator::Plus, lit_i64(1));
        assert_eq!(fold_constants(&expr), expr);

        // `NULL AND x` has no identity rule: its value depends on `x`.
        let null_bool = Expr::Literal(ScalarValue::Boolean(None), None);
        let expr = binary(null_bool, Operator::And, col("x"));
        assert_eq!(fold_constants(&expr), expr);
    }

    #[test]
    fn mixed_literal_types_not_folded() {
        let expr = binary(lit_i64(1), Operator::Plus, lit_f64(2.0));
        assert_eq!(fold_constants(&expr), expr);
    }

    #[test]
    fn no_fold_on_column_expr() {
        // A column operand cannot be folded; the tree is rebuilt unchanged.
        let expr = binary(col("a"), Operator::Plus, lit_i64(1));
        assert_eq!(fold_constants(&expr), expr);
    }

    #[test]
    fn fold_division_by_zero_no_fold() {
        // Division by zero — don't fold, leave as binary expr.
        let expr = binary(lit_i64(10), Operator::Divide, lit_i64(0));
        assert_eq!(fold_constants(&expr), expr);
    }

    #[test]
    fn fold_modulo_by_zero_no_fold() {
        let expr = binary(lit_i64(10), Operator::Modulo, lit_i64(0));
        assert_eq!(fold_constants(&expr), expr);
    }

    #[test]
    fn fold_f64_division_by_zero_no_fold() {
        let expr = binary(lit_f64(1.0), Operator::Divide, lit_f64(0.0));
        assert_eq!(fold_constants(&expr), expr);
    }

    #[test]
    fn fold_comparison() {
        let expr = binary(lit_i64(5), Operator::Gt, lit_i64(3));
        assert_eq!(fold_constants(&expr), lit_bool(true));
    }
}