1use datafusion_common::ScalarValue;
7use datafusion_expr::{BinaryExpr, Expr, Operator};
8
9#[must_use]
18pub fn fold_constants(expr: &Expr) -> Expr {
19 match expr {
20 Expr::BinaryExpr(binary) => fold_binary(binary),
21 Expr::Not(inner) => fold_not(inner),
22 _ => expr.clone(),
23 }
24}
25
26fn fold_binary(binary: &BinaryExpr) -> Expr {
28 let left = fold_constants(&binary.left);
29 let right = fold_constants(&binary.right);
30
31 if let Some(simplified) = try_boolean_identity(&left, &right, binary.op) {
33 return simplified;
34 }
35
36 if let (Expr::Literal(lv, _), Expr::Literal(rv, _)) = (&left, &right) {
38 if let Some(result) = fold_literal_pair(lv, rv, binary.op) {
39 return Expr::Literal(result, None);
40 }
41 }
42
43 Expr::BinaryExpr(BinaryExpr::new(Box::new(left), binary.op, Box::new(right)))
44}
45
46fn fold_not(inner: &Expr) -> Expr {
48 let folded = fold_constants(inner);
49 match &folded {
50 Expr::Literal(ScalarValue::Boolean(Some(b)), _) => {
51 Expr::Literal(ScalarValue::Boolean(Some(!b)), None)
52 }
53 _ => Expr::Not(Box::new(folded)),
54 }
55}
56
57fn try_boolean_identity(left: &Expr, right: &Expr, op: Operator) -> Option<Expr> {
68 match op {
69 Operator::And => {
70 if let Expr::Literal(ScalarValue::Boolean(Some(b)), _) = left {
71 return Some(if *b { right.clone() } else { left.clone() });
72 }
73 if let Expr::Literal(ScalarValue::Boolean(Some(b)), _) = right {
74 return Some(if *b { left.clone() } else { right.clone() });
75 }
76 None
77 }
78 Operator::Or => {
79 if let Expr::Literal(ScalarValue::Boolean(Some(b)), _) = left {
80 return Some(if *b { left.clone() } else { right.clone() });
81 }
82 if let Expr::Literal(ScalarValue::Boolean(Some(b)), _) = right {
83 return Some(if *b { right.clone() } else { left.clone() });
84 }
85 None
86 }
87 _ => None,
88 }
89}
90
91fn fold_literal_pair(lhs: &ScalarValue, rhs: &ScalarValue, op: Operator) -> Option<ScalarValue> {
93 if let (ScalarValue::Int64(Some(l)), ScalarValue::Int64(Some(r))) = (lhs, rhs) {
95 return fold_i64(*l, *r, op);
96 }
97 if let (ScalarValue::Float64(Some(l)), ScalarValue::Float64(Some(r))) = (lhs, rhs) {
99 return fold_f64(*l, *r, op);
100 }
101 if let (ScalarValue::Boolean(Some(l)), ScalarValue::Boolean(Some(r))) = (lhs, rhs) {
103 return fold_bool(*l, *r, op);
104 }
105 None
106}
107
108fn fold_i64(l: i64, r: i64, op: Operator) -> Option<ScalarValue> {
109 match op {
110 Operator::Plus => l.checked_add(r).map(|v| ScalarValue::Int64(Some(v))),
111 Operator::Minus => l.checked_sub(r).map(|v| ScalarValue::Int64(Some(v))),
112 Operator::Multiply => l.checked_mul(r).map(|v| ScalarValue::Int64(Some(v))),
113 Operator::Divide if r != 0 => l.checked_div(r).map(|v| ScalarValue::Int64(Some(v))),
114 Operator::Modulo if r != 0 => l.checked_rem(r).map(|v| ScalarValue::Int64(Some(v))),
115 Operator::Eq => Some(ScalarValue::Boolean(Some(l == r))),
116 Operator::NotEq => Some(ScalarValue::Boolean(Some(l != r))),
117 Operator::Lt => Some(ScalarValue::Boolean(Some(l < r))),
118 Operator::LtEq => Some(ScalarValue::Boolean(Some(l <= r))),
119 Operator::Gt => Some(ScalarValue::Boolean(Some(l > r))),
120 Operator::GtEq => Some(ScalarValue::Boolean(Some(l >= r))),
121 _ => None,
122 }
123}
124
125fn fold_f64(l: f64, r: f64, op: Operator) -> Option<ScalarValue> {
126 match op {
127 Operator::Plus => Some(ScalarValue::Float64(Some(l + r))),
128 Operator::Minus => Some(ScalarValue::Float64(Some(l - r))),
129 Operator::Multiply => Some(ScalarValue::Float64(Some(l * r))),
130 Operator::Divide if r != 0.0 => Some(ScalarValue::Float64(Some(l / r))),
131 _ => None,
132 }
133}
134
135fn fold_bool(l: bool, r: bool, op: Operator) -> Option<ScalarValue> {
136 match op {
137 Operator::And => Some(ScalarValue::Boolean(Some(l && r))),
138 Operator::Or => Some(ScalarValue::Boolean(Some(l || r))),
139 Operator::Eq => Some(ScalarValue::Boolean(Some(l == r))),
140 Operator::NotEq => Some(ScalarValue::Boolean(Some(l != r))),
141 _ => None,
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::col;

    fn lit_i64(v: i64) -> Expr {
        Expr::Literal(ScalarValue::Int64(Some(v)), None)
    }

    fn lit_f64(v: f64) -> Expr {
        Expr::Literal(ScalarValue::Float64(Some(v)), None)
    }

    fn lit_bool(v: bool) -> Expr {
        Expr::Literal(ScalarValue::Boolean(Some(v)), None)
    }

    /// Builds the expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
    }

    /// Asserts that folding `expr` leaves a binary expression, i.e. nothing was folded
    /// into a literal.
    fn assert_not_folded(expr: &Expr) {
        let folded = fold_constants(expr);
        assert!(
            matches!(folded, Expr::BinaryExpr(_)),
            "unexpectedly folded: {folded:?}"
        );
    }

    #[test]
    fn fold_i64_addition() {
        let folded = fold_constants(&binary(lit_i64(10), Operator::Plus, lit_i64(20)));
        assert_eq!(folded, lit_i64(30));
    }

    #[test]
    fn fold_i64_subtraction() {
        let expr = binary(lit_i64(50), Operator::Minus, lit_i64(8));
        assert_eq!(fold_constants(&expr), lit_i64(42));
    }

    #[test]
    fn fold_i64_multiplication() {
        let expr = binary(lit_i64(6), Operator::Multiply, lit_i64(7));
        assert_eq!(fold_constants(&expr), lit_i64(42));
    }

    #[test]
    fn fold_i64_division_and_modulo() {
        let div = binary(lit_i64(7), Operator::Divide, lit_i64(2));
        assert_eq!(fold_constants(&div), lit_i64(3));

        let rem = binary(lit_i64(7), Operator::Modulo, lit_i64(3));
        assert_eq!(fold_constants(&rem), lit_i64(1));
    }

    #[test]
    fn fold_i64_overflow_no_fold() {
        assert_not_folded(&binary(lit_i64(i64::MAX), Operator::Plus, lit_i64(1)));
        assert_not_folded(&binary(lit_i64(i64::MIN), Operator::Minus, lit_i64(1)));
        assert_not_folded(&binary(lit_i64(i64::MAX), Operator::Multiply, lit_i64(2)));
        assert_not_folded(&binary(lit_i64(i64::MIN), Operator::Divide, lit_i64(-1)));
        assert_not_folded(&binary(lit_i64(i64::MIN), Operator::Modulo, lit_i64(-1)));
    }

    #[test]
    fn fold_f64_arithmetic() {
        let folded = fold_constants(&binary(lit_f64(1.5), Operator::Plus, lit_f64(2.5)));
        assert_eq!(folded, lit_f64(4.0));
    }

    #[test]
    fn fold_nested_constants() {
        // (2 + 3) * 4
        let inner = binary(lit_i64(2), Operator::Plus, lit_i64(3));
        let expr = binary(inner, Operator::Multiply, lit_i64(4));
        assert_eq!(fold_constants(&expr), lit_i64(20));
    }

    #[test]
    fn fold_not_literal() {
        assert_eq!(
            fold_constants(&Expr::Not(Box::new(lit_bool(true)))),
            lit_bool(false)
        );
        assert_eq!(
            fold_constants(&Expr::Not(Box::new(lit_bool(false)))),
            lit_bool(true)
        );
    }

    #[test]
    fn fold_not_of_folded_comparison() {
        // NOT (1 < 2)
        let cmp = binary(lit_i64(1), Operator::Lt, lit_i64(2));
        assert_eq!(fold_constants(&Expr::Not(Box::new(cmp))), lit_bool(false));
    }

    #[test]
    fn fold_boolean_identity_and() {
        let x = col("x");
        let expr = binary(lit_bool(true), Operator::And, x.clone());
        assert_eq!(fold_constants(&expr), x);

        let expr = binary(lit_bool(false), Operator::And, col("x"));
        assert_eq!(fold_constants(&expr), lit_bool(false));
    }

    #[test]
    fn fold_boolean_identity_or() {
        let x = col("x");
        let expr = binary(lit_bool(false), Operator::Or, x.clone());
        assert_eq!(fold_constants(&expr), x);

        let expr = binary(lit_bool(true), Operator::Or, col("x"));
        assert_eq!(fold_constants(&expr), lit_bool(true));
    }

    #[test]
    fn fold_boolean_identity_right_operand() {
        let x = col("x");
        assert_eq!(
            fold_constants(&binary(x.clone(), Operator::And, lit_bool(true))),
            x
        );
        assert_eq!(
            fold_constants(&binary(x.clone(), Operator::And, lit_bool(false))),
            lit_bool(false)
        );
        assert_eq!(
            fold_constants(&binary(x.clone(), Operator::Or, lit_bool(false))),
            x
        );
        assert_eq!(
            fold_constants(&binary(x, Operator::Or, lit_bool(true))),
            lit_bool(true)
        );
    }

    #[test]
    fn no_fold_on_column_expr() {
        assert_not_folded(&binary(col("a"), Operator::Plus, lit_i64(1)));
    }

    #[test]
    fn fold_division_by_zero_no_fold() {
        assert_not_folded(&binary(lit_i64(10), Operator::Divide, lit_i64(0)));
        assert_not_folded(&binary(lit_i64(10), Operator::Modulo, lit_i64(0)));
        assert_not_folded(&binary(lit_f64(1.0), Operator::Divide, lit_f64(0.0)));
        assert_not_folded(&binary(lit_f64(1.0), Operator::Divide, lit_f64(-0.0)));
    }

    #[test]
    fn null_literals_not_folded() {
        let null_i64 = Expr::Literal(ScalarValue::Int64(None), None);
        assert_not_folded(&binary(null_i64, Operator::Plus, lit_i64(1)));

        let null_bool = Expr::Literal(ScalarValue::Boolean(None), None);
        let folded = fold_constants(&Expr::Not(Box::new(null_bool)));
        assert!(matches!(folded, Expr::Not(_)), "unexpectedly folded: {folded:?}");
    }

    #[test]
    fn mixed_literal_types_not_folded() {
        assert_not_folded(&binary(lit_i64(1), Operator::Plus, lit_f64(2.0)));
    }

    #[test]
    fn fold_comparison() {
        let expr = binary(lit_i64(5), Operator::Gt, lit_i64(3));
        assert_eq!(fold_constants(&expr), lit_bool(true));
    }
}