Skip to main content

laminar_core/compiler/
expr.rs

1//! Compiled expression evaluator using Cranelift JIT.
2//!
3//! [`ExprCompiler`] translates `DataFusion` [`Expr`] trees into native machine code
4//! via Cranelift. The generated functions operate directly on [`super::EventRow`] byte
5//! buffers using pointer arithmetic — zero allocations, zero virtual dispatch.
6//!
7//! # Supported Expressions
8//!
9//! - Column references (load from row pointer + field offset)
10//! - Literals (`i64`, `f64`, `i32`, `bool`, null)
11//! - Binary arithmetic (`+`, `-`, `*`, `/`, `%`) for integer and float types
12//! - Comparisons (`=`, `!=`, `<`, `<=`, `>`, `>=`)
13//! - Boolean logic (`AND`, `OR`) with short-circuit evaluation
14//! - `NOT`, `IS NULL`, `IS NOT NULL`
15//! - `CAST` between numeric types
16//! - `CASE WHEN ... THEN ... ELSE ... END`
17//!
18//! # Null Handling
19//!
20//! Null propagation follows SQL semantics: any arithmetic or comparison with a
21//! NULL operand produces NULL. `IS NULL` / `IS NOT NULL` inspect the null flag
22//! directly.
23
24use cranelift_codegen::ir::condcodes::{FloatCC, IntCC};
25use cranelift_codegen::ir::types::{self as cl_types};
26use cranelift_codegen::ir::{
27    AbiParam, BlockArg, Function, InstBuilder, MemFlags, Type as CraneliftType, UserFuncName, Value,
28};
29use cranelift_codegen::Context;
30use cranelift_frontend::FunctionBuilder;
31use cranelift_module::Module;
32use datafusion_common::ScalarValue;
33use datafusion_expr::{BinaryExpr, Expr, Operator};
34
35use super::error::{CompileError, FilterFn, ScalarFn};
36use super::fold::fold_constants;
37use super::jit::JitContext;
38use super::row::{FieldType, RowSchema};
39
40/// Pointer type for the target architecture.
41const PTR_TYPE: CraneliftType = cl_types::I64;
42
43/// Tracks a compiled SSA value plus its null state.
44pub(crate) struct CompiledValue {
45    /// The SSA value in Cranelift IR.
46    pub(crate) value: Value,
47    /// Whether this value can be null.
48    pub(crate) is_nullable: bool,
49    /// SSA value holding the null flag (i8: 0 = valid, nonzero = null).
50    pub(crate) null_flag: Option<Value>,
51    /// The [`FieldType`] of this value.
52    pub(crate) value_type: FieldType,
53}
54
55/// Compiles `DataFusion` expressions into native functions via Cranelift.
56///
57/// Borrows a [`JitContext`] (for the module/builder) and a [`RowSchema`]
58/// (for field layout resolution). Each `compile_*` call produces one native
59/// function.
60pub struct ExprCompiler<'a> {
61    jit: &'a mut JitContext,
62    schema: &'a RowSchema,
63}
64
65impl<'a> ExprCompiler<'a> {
66    /// Creates a new compiler for the given JIT context and row schema.
67    pub fn new(jit: &'a mut JitContext, schema: &'a RowSchema) -> Self {
68        Self { jit, schema }
69    }
70
71    /// Compiles a filter expression into a native `FilterFn`.
72    ///
73    /// The generated function has signature `fn(*const u8) -> u8` where the
74    /// argument is a pointer to an `EventRow` byte buffer. Returns 1 if the
75    /// row passes the filter, 0 otherwise. NULL filter results are treated
76    /// as false (row is rejected).
77    ///
78    /// # Errors
79    ///
80    /// Returns [`CompileError`] if the expression contains unsupported nodes.
81    ///
82    /// # Panics
83    ///
84    /// Panics if Cranelift fails to finalize function definitions (internal error).
85    pub fn compile_filter(&mut self, expr: &Expr) -> Result<FilterFn, CompileError> {
86        let expr = fold_constants(expr);
87        let func_name = self.jit.next_func_name("filter");
88
89        let mut sig = self.jit.module().make_signature();
90        sig.params.push(AbiParam::new(PTR_TYPE));
91        sig.returns.push(AbiParam::new(cl_types::I8));
92
93        let func_id = self.jit.module().declare_function(
94            &func_name,
95            cranelift_module::Linkage::Local,
96            &sig,
97        )?;
98
99        let mut func = Function::with_name_signature(UserFuncName::testcase(&func_name), sig);
100
101        {
102            let builder_ctx = self.jit.builder_ctx();
103            let mut builder = FunctionBuilder::new(&mut func, builder_ctx);
104            let entry = builder.create_block();
105            builder.append_block_params_for_function_params(entry);
106            builder.switch_to_block(entry);
107            builder.seal_block(entry);
108
109            let row_ptr = builder.block_params(entry)[0];
110            let compiled = compile_expr_inner(&mut builder, self.schema, &expr, row_ptr)?;
111
112            // NULL filter result → 0 (reject row)
113            let result = if compiled.is_nullable {
114                if let Some(null_flag) = compiled.null_flag {
115                    let zero = builder.ins().iconst(cl_types::I8, 0);
116                    let is_null = builder.ins().icmp_imm(IntCC::NotEqual, null_flag, 0);
117                    builder.ins().select(is_null, zero, compiled.value)
118                } else {
119                    compiled.value
120                }
121            } else {
122                compiled.value
123            };
124
125            builder.ins().return_(&[result]);
126            builder.finalize();
127        }
128
129        let mut ctx = Context::for_function(func);
130        self.jit
131            .module()
132            .define_function(func_id, &mut ctx)
133            .map_err(|e| CompileError::Cranelift(Box::new(e)))?;
134        self.jit.module().finalize_definitions().unwrap();
135
136        let code_ptr = self.jit.module().get_finalized_function(func_id);
137        // SAFETY: `get_finalized_function` returns a pointer to machine code compiled by
138        // Cranelift with the ABI signature declared via `declare_function` (one i64 param,
139        // one i8 return). The pointer remains valid for the lifetime of the JIT module
140        // (held by `JitContext`), which outlives every compiled function reference.
141        Ok(unsafe { std::mem::transmute::<*const u8, FilterFn>(code_ptr) })
142    }
143
144    /// Compiles a scalar expression into a native `ScalarFn`.
145    ///
146    /// The generated function has signature `fn(*const u8, *mut u8) -> u8`
147    /// where the first argument is a pointer to an `EventRow` byte buffer and
148    /// the second is a pointer to the output buffer. Returns 1 if the result
149    /// is null, 0 if valid.
150    ///
151    /// # Errors
152    ///
153    /// Returns [`CompileError`] if the expression contains unsupported nodes.
154    ///
155    /// # Panics
156    ///
157    /// Panics if Cranelift fails to finalize function definitions (internal error).
158    pub fn compile_scalar(
159        &mut self,
160        expr: &Expr,
161        output_type: &FieldType,
162    ) -> Result<ScalarFn, CompileError> {
163        let expr = fold_constants(expr);
164        let func_name = self.jit.next_func_name("scalar");
165
166        let mut sig = self.jit.module().make_signature();
167        sig.params.push(AbiParam::new(PTR_TYPE));
168        sig.params.push(AbiParam::new(PTR_TYPE));
169        sig.returns.push(AbiParam::new(cl_types::I8));
170
171        let func_id = self.jit.module().declare_function(
172            &func_name,
173            cranelift_module::Linkage::Local,
174            &sig,
175        )?;
176
177        let mut func = Function::with_name_signature(UserFuncName::testcase(&func_name), sig);
178
179        {
180            let builder_ctx = self.jit.builder_ctx();
181            let mut builder = FunctionBuilder::new(&mut func, builder_ctx);
182            let entry = builder.create_block();
183            builder.append_block_params_for_function_params(entry);
184            builder.switch_to_block(entry);
185            builder.seal_block(entry);
186
187            let row_ptr = builder.block_params(entry)[0];
188            let out_ptr = builder.block_params(entry)[1];
189
190            let compiled = compile_expr_inner(&mut builder, self.schema, &expr, row_ptr)?;
191
192            // Write the value to the output pointer.
193            let mem_flags = MemFlags::trusted();
194            builder.ins().store(mem_flags, compiled.value, out_ptr, 0);
195
196            // Return is_null flag.
197            let is_null = if compiled.is_nullable {
198                compiled
199                    .null_flag
200                    .unwrap_or_else(|| builder.ins().iconst(cl_types::I8, 0))
201            } else {
202                builder.ins().iconst(cl_types::I8, 0)
203            };
204
205            let is_null_i8 = if builder.func.dfg.value_type(is_null) == cl_types::I8 {
206                is_null
207            } else {
208                builder.ins().ireduce(cl_types::I8, is_null)
209            };
210
211            _ = output_type;
212            builder.ins().return_(&[is_null_i8]);
213            builder.finalize();
214        }
215
216        let mut ctx = Context::for_function(func);
217        self.jit
218            .module()
219            .define_function(func_id, &mut ctx)
220            .map_err(|e| CompileError::Cranelift(Box::new(e)))?;
221        self.jit.module().finalize_definitions().unwrap();
222
223        let code_ptr = self.jit.module().get_finalized_function(func_id);
224        // SAFETY: `get_finalized_function` returns a pointer to machine code compiled by
225        // Cranelift with the ABI signature declared via `declare_function` (two pointer
226        // params, one i8 return). The pointer remains valid for the lifetime of the JIT
227        // module (held by `JitContext`), which outlives every compiled function reference.
228        Ok(unsafe { std::mem::transmute::<*const u8, ScalarFn>(code_ptr) })
229    }
230}
231
232// ---------------------------------------------------------------------------
233// Free functions for IR generation — decoupled from JitContext borrow
234// ---------------------------------------------------------------------------
235
236/// Recursively compiles an expression node into Cranelift IR.
237pub(crate) fn compile_expr_inner(
238    builder: &mut FunctionBuilder,
239    schema: &RowSchema,
240    expr: &Expr,
241    row_ptr: Value,
242) -> Result<CompiledValue, CompileError> {
243    match expr {
244        Expr::Column(col) => compile_column(builder, schema, &col.name, row_ptr),
245        Expr::Literal(scalar, _) => compile_literal(builder, scalar),
246        Expr::BinaryExpr(binary) => compile_binary(builder, schema, binary, row_ptr),
247        Expr::Not(inner) => compile_not(builder, schema, inner, row_ptr),
248        Expr::IsNull(inner) => compile_is_null(builder, schema, inner, row_ptr, false),
249        Expr::IsNotNull(inner) => compile_is_null(builder, schema, inner, row_ptr, true),
250        Expr::Cast(cast) => compile_cast(builder, schema, &cast.expr, &cast.data_type, row_ptr),
251        Expr::Case(case) => compile_case(builder, schema, case, row_ptr),
252        other => Err(CompileError::UnsupportedExpr(format!("{other}"))),
253    }
254}
255
256/// Compiles a column reference — loads the field value from the row pointer.
257fn compile_column(
258    builder: &mut FunctionBuilder,
259    schema: &RowSchema,
260    name: &str,
261    row_ptr: Value,
262) -> Result<CompiledValue, CompileError> {
263    let field_idx = schema
264        .arrow_schema()
265        .index_of(name)
266        .map_err(|_| CompileError::ColumnNotFound(name.to_string()))?;
267
268    let layout = schema.field(field_idx);
269    let cl_type = field_type_to_cranelift(layout.field_type);
270    let mem_flags = MemFlags::trusted();
271
272    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
273    let offset = layout.offset as i32;
274    let value = builder.ins().load(cl_type, mem_flags, row_ptr, offset);
275
276    let arrow_field = &schema.arrow_schema().fields()[field_idx];
277    let is_nullable = arrow_field.is_nullable();
278
279    let null_flag = if is_nullable {
280        Some(emit_null_check(builder, schema, field_idx, row_ptr))
281    } else {
282        None
283    };
284
285    Ok(CompiledValue {
286        value,
287        is_nullable,
288        null_flag,
289        value_type: layout.field_type,
290    })
291}
292
293/// Compiles a scalar literal into a Cranelift constant.
294fn compile_literal(
295    builder: &mut FunctionBuilder,
296    scalar: &ScalarValue,
297) -> Result<CompiledValue, CompileError> {
298    match scalar {
299        ScalarValue::Boolean(Some(b)) => Ok(CompiledValue {
300            value: builder.ins().iconst(cl_types::I8, i64::from(*b)),
301            is_nullable: false,
302            null_flag: None,
303            value_type: FieldType::Bool,
304        }),
305        ScalarValue::Int8(Some(v)) => Ok(CompiledValue {
306            value: builder.ins().iconst(cl_types::I8, i64::from(*v)),
307            is_nullable: false,
308            null_flag: None,
309            value_type: FieldType::Int8,
310        }),
311        ScalarValue::Int16(Some(v)) => Ok(CompiledValue {
312            value: builder.ins().iconst(cl_types::I16, i64::from(*v)),
313            is_nullable: false,
314            null_flag: None,
315            value_type: FieldType::Int16,
316        }),
317        ScalarValue::Int32(Some(v)) => Ok(CompiledValue {
318            value: builder.ins().iconst(cl_types::I32, i64::from(*v)),
319            is_nullable: false,
320            null_flag: None,
321            value_type: FieldType::Int32,
322        }),
323        ScalarValue::Int64(Some(v)) => Ok(CompiledValue {
324            value: builder.ins().iconst(cl_types::I64, *v),
325            is_nullable: false,
326            null_flag: None,
327            value_type: FieldType::Int64,
328        }),
329        ScalarValue::UInt8(Some(v)) => Ok(CompiledValue {
330            value: builder.ins().iconst(cl_types::I8, i64::from(*v)),
331            is_nullable: false,
332            null_flag: None,
333            value_type: FieldType::UInt8,
334        }),
335        ScalarValue::UInt16(Some(v)) => Ok(CompiledValue {
336            value: builder.ins().iconst(cl_types::I16, i64::from(*v)),
337            is_nullable: false,
338            null_flag: None,
339            value_type: FieldType::UInt16,
340        }),
341        ScalarValue::UInt32(Some(v)) => Ok(CompiledValue {
342            value: builder.ins().iconst(cl_types::I32, i64::from(*v)),
343            is_nullable: false,
344            null_flag: None,
345            value_type: FieldType::UInt32,
346        }),
347        ScalarValue::UInt64(Some(v)) => Ok(CompiledValue {
348            #[allow(clippy::cast_possible_wrap)]
349            value: builder.ins().iconst(cl_types::I64, (*v).cast_signed()),
350            is_nullable: false,
351            null_flag: None,
352            value_type: FieldType::UInt64,
353        }),
354        ScalarValue::Float32(Some(v)) => Ok(CompiledValue {
355            value: builder.ins().f32const(*v),
356            is_nullable: false,
357            null_flag: None,
358            value_type: FieldType::Float32,
359        }),
360        ScalarValue::Float64(Some(v)) => Ok(CompiledValue {
361            value: builder.ins().f64const(*v),
362            is_nullable: false,
363            null_flag: None,
364            value_type: FieldType::Float64,
365        }),
366        // Null literals: produce a zero value with null_flag = 1
367        ScalarValue::Boolean(None) => Ok(null_value(builder, FieldType::Bool)),
368        ScalarValue::Int64(None) => Ok(null_value(builder, FieldType::Int64)),
369        ScalarValue::Float64(None) => Ok(null_value(builder, FieldType::Float64)),
370        ScalarValue::Int32(None) => Ok(null_value(builder, FieldType::Int32)),
371        _ => Err(CompileError::UnsupportedLiteral),
372    }
373}
374
375/// Creates a null value of the given type.
376pub(crate) fn null_value(builder: &mut FunctionBuilder, field_type: FieldType) -> CompiledValue {
377    let cl_type = field_type_to_cranelift(field_type);
378    let value = if cl_type.is_float() {
379        if cl_type == cl_types::F32 {
380            builder.ins().f32const(0.0)
381        } else {
382            builder.ins().f64const(0.0)
383        }
384    } else {
385        builder.ins().iconst(cl_type, 0)
386    };
387    let null_flag = builder.ins().iconst(cl_types::I8, 1);
388    CompiledValue {
389        value,
390        is_nullable: true,
391        null_flag: Some(null_flag),
392        value_type: field_type,
393    }
394}
395
396/// Compiles a binary expression (arithmetic, comparison, boolean logic).
397fn compile_binary(
398    builder: &mut FunctionBuilder,
399    schema: &RowSchema,
400    binary: &BinaryExpr,
401    row_ptr: Value,
402) -> Result<CompiledValue, CompileError> {
403    match binary.op {
404        Operator::And => return compile_short_circuit_and(builder, schema, binary, row_ptr),
405        Operator::Or => return compile_short_circuit_or(builder, schema, binary, row_ptr),
406        _ => {}
407    }
408
409    let lhs = compile_expr_inner(builder, schema, &binary.left, row_ptr)?;
410    let rhs = compile_expr_inner(builder, schema, &binary.right, row_ptr)?;
411
412    let (is_nullable, merged_null) = merge_null_flags(builder, &lhs, &rhs);
413    let result_value = emit_binary_op(builder, &lhs, &rhs, binary.op)?;
414
415    Ok(CompiledValue {
416        value: result_value.value,
417        is_nullable,
418        null_flag: merged_null,
419        value_type: result_value.value_type,
420    })
421}
422
423/// Emits the binary operation IR for the given operands and operator.
424fn emit_binary_op(
425    builder: &mut FunctionBuilder,
426    lhs: &CompiledValue,
427    rhs: &CompiledValue,
428    op: Operator,
429) -> Result<CompiledValue, CompileError> {
430    let lhs_type = lhs.value_type;
431    let cl_type = field_type_to_cranelift(lhs_type);
432
433    if cl_type.is_int() {
434        return emit_int_binary(builder, lhs, rhs, op);
435    }
436    if cl_type.is_float() {
437        return emit_float_binary(builder, lhs, rhs, op);
438    }
439
440    Err(CompileError::UnsupportedBinaryOp(lhs_type, op))
441}
442
443/// Emits integer arithmetic and comparison operations.
444fn emit_int_binary(
445    builder: &mut FunctionBuilder,
446    lhs: &CompiledValue,
447    rhs: &CompiledValue,
448    op: Operator,
449) -> Result<CompiledValue, CompileError> {
450    let lhs_type = lhs.value_type;
451    let is_signed = is_signed_type(lhs_type);
452
453    let value = match op {
454        Operator::Plus => builder.ins().iadd(lhs.value, rhs.value),
455        Operator::Minus => builder.ins().isub(lhs.value, rhs.value),
456        Operator::Multiply => builder.ins().imul(lhs.value, rhs.value),
457        Operator::Divide => {
458            if is_signed {
459                builder.ins().sdiv(lhs.value, rhs.value)
460            } else {
461                builder.ins().udiv(lhs.value, rhs.value)
462            }
463        }
464        Operator::Modulo => {
465            if is_signed {
466                builder.ins().srem(lhs.value, rhs.value)
467            } else {
468                builder.ins().urem(lhs.value, rhs.value)
469            }
470        }
471        Operator::Eq
472        | Operator::NotEq
473        | Operator::Lt
474        | Operator::LtEq
475        | Operator::Gt
476        | Operator::GtEq => {
477            let cc = int_cmp_cond(op, is_signed);
478            let cmp_val = builder.ins().icmp(cc, lhs.value, rhs.value);
479            return Ok(CompiledValue {
480                value: cmp_val,
481                is_nullable: false,
482                null_flag: None,
483                value_type: FieldType::Bool,
484            });
485        }
486        _ => return Err(CompileError::UnsupportedBinaryOp(lhs_type, op)),
487    };
488
489    Ok(CompiledValue {
490        value,
491        is_nullable: false,
492        null_flag: None,
493        value_type: lhs_type,
494    })
495}
496
497/// Emits floating-point arithmetic and comparison operations.
498fn emit_float_binary(
499    builder: &mut FunctionBuilder,
500    lhs: &CompiledValue,
501    rhs: &CompiledValue,
502    op: Operator,
503) -> Result<CompiledValue, CompileError> {
504    let lhs_type = lhs.value_type;
505
506    let value = match op {
507        Operator::Plus => builder.ins().fadd(lhs.value, rhs.value),
508        Operator::Minus => builder.ins().fsub(lhs.value, rhs.value),
509        Operator::Multiply => builder.ins().fmul(lhs.value, rhs.value),
510        Operator::Divide => builder.ins().fdiv(lhs.value, rhs.value),
511        Operator::Eq
512        | Operator::NotEq
513        | Operator::Lt
514        | Operator::LtEq
515        | Operator::Gt
516        | Operator::GtEq => {
517            let cc = float_cmp_cond(op);
518            let cmp_val = builder.ins().fcmp(cc, lhs.value, rhs.value);
519            return Ok(CompiledValue {
520                value: cmp_val,
521                is_nullable: false,
522                null_flag: None,
523                value_type: FieldType::Bool,
524            });
525        }
526        _ => return Err(CompileError::UnsupportedBinaryOp(lhs_type, op)),
527    };
528
529    Ok(CompiledValue {
530        value,
531        is_nullable: false,
532        null_flag: None,
533        value_type: lhs_type,
534    })
535}
536
537/// Compiles `NOT expr` — flips a boolean i8 value.
538fn compile_not(
539    builder: &mut FunctionBuilder,
540    schema: &RowSchema,
541    inner: &Expr,
542    row_ptr: Value,
543) -> Result<CompiledValue, CompileError> {
544    let compiled = compile_expr_inner(builder, schema, inner, row_ptr)?;
545    let one = builder.ins().iconst(cl_types::I8, 1);
546    let flipped = builder.ins().bxor(compiled.value, one);
547    Ok(CompiledValue {
548        value: flipped,
549        is_nullable: compiled.is_nullable,
550        null_flag: compiled.null_flag,
551        value_type: FieldType::Bool,
552    })
553}
554
555/// Compiles `IS NULL` / `IS NOT NULL`.
556fn compile_is_null(
557    builder: &mut FunctionBuilder,
558    schema: &RowSchema,
559    inner: &Expr,
560    row_ptr: Value,
561    invert: bool,
562) -> Result<CompiledValue, CompileError> {
563    let compiled = compile_expr_inner(builder, schema, inner, row_ptr)?;
564
565    let result = if let Some(nf) = compiled.null_flag {
566        if invert {
567            builder.ins().icmp_imm(IntCC::Equal, nf, 0)
568        } else {
569            builder.ins().icmp_imm(IntCC::NotEqual, nf, 0)
570        }
571    } else {
572        builder.ins().iconst(cl_types::I8, i64::from(invert))
573    };
574
575    Ok(CompiledValue {
576        value: result,
577        is_nullable: false,
578        null_flag: None,
579        value_type: FieldType::Bool,
580    })
581}
582
583/// Compiles a CAST expression between numeric types.
584fn compile_cast(
585    builder: &mut FunctionBuilder,
586    schema: &RowSchema,
587    inner: &Expr,
588    target_dt: &arrow_schema::DataType,
589    row_ptr: Value,
590) -> Result<CompiledValue, CompileError> {
591    let compiled = compile_expr_inner(builder, schema, inner, row_ptr)?;
592    let out_type = FieldType::from_arrow(target_dt)
593        .ok_or_else(|| CompileError::UnsupportedExpr(format!("CAST to {target_dt}")))?;
594    let target_cl = field_type_to_cranelift(out_type);
595    let source_cl = field_type_to_cranelift(compiled.value_type);
596
597    if source_cl == target_cl {
598        return Ok(CompiledValue {
599            value_type: out_type,
600            ..compiled
601        });
602    }
603
604    let value = if source_cl.is_int() && target_cl.is_int() {
605        if target_cl.bits() > source_cl.bits() {
606            if is_signed_type(compiled.value_type) {
607                builder.ins().sextend(target_cl, compiled.value)
608            } else {
609                builder.ins().uextend(target_cl, compiled.value)
610            }
611        } else {
612            builder.ins().ireduce(target_cl, compiled.value)
613        }
614    } else if source_cl.is_int() && target_cl.is_float() {
615        if is_signed_type(compiled.value_type) {
616            builder.ins().fcvt_from_sint(target_cl, compiled.value)
617        } else {
618            builder.ins().fcvt_from_uint(target_cl, compiled.value)
619        }
620    } else if source_cl.is_float() && target_cl.is_int() {
621        if is_signed_type(out_type) {
622            builder.ins().fcvt_to_sint(target_cl, compiled.value)
623        } else {
624            builder.ins().fcvt_to_uint(target_cl, compiled.value)
625        }
626    } else if source_cl.is_float() && target_cl.is_float() {
627        if target_cl.bits() > source_cl.bits() {
628            builder.ins().fpromote(target_cl, compiled.value)
629        } else {
630            builder.ins().fdemote(target_cl, compiled.value)
631        }
632    } else {
633        return Err(CompileError::UnsupportedExpr(format!(
634            "CAST from {source_cl} to {target_cl}"
635        )));
636    };
637
638    Ok(CompiledValue {
639        value,
640        is_nullable: compiled.is_nullable,
641        null_flag: compiled.null_flag,
642        value_type: out_type,
643    })
644}
645
646/// Compiles a CASE WHEN expression.
647fn compile_case(
648    builder: &mut FunctionBuilder,
649    schema: &RowSchema,
650    case: &datafusion_expr::Case,
651    row_ptr: Value,
652) -> Result<CompiledValue, CompileError> {
653    let merge_block = builder.create_block();
654
655    if case.when_then_expr.is_empty() {
656        return Err(CompileError::UnsupportedExpr("empty CASE".to_string()));
657    }
658
659    let first_then = &case.when_then_expr[0].1;
660    let result_type = infer_expr_type(schema, first_then)?;
661    let cl_type = field_type_to_cranelift(result_type);
662    builder.append_block_param(merge_block, cl_type);
663    builder.append_block_param(merge_block, cl_types::I8);
664
665    for (when_expr, then_expr) in &case.when_then_expr {
666        let then_block = builder.create_block();
667        let else_block = builder.create_block();
668
669        let cond = compile_expr_inner(builder, schema, when_expr, row_ptr)?;
670        builder
671            .ins()
672            .brif(cond.value, then_block, &[], else_block, &[]);
673
674        builder.switch_to_block(then_block);
675        builder.seal_block(then_block);
676        let then_val = compile_expr_inner(builder, schema, then_expr, row_ptr)?;
677        let then_null = then_val
678            .null_flag
679            .unwrap_or_else(|| builder.ins().iconst(cl_types::I8, 0));
680        builder.ins().jump(
681            merge_block,
682            &[BlockArg::Value(then_val.value), BlockArg::Value(then_null)],
683        );
684
685        builder.switch_to_block(else_block);
686        builder.seal_block(else_block);
687    }
688
689    let else_val = if let Some(else_expr) = &case.else_expr {
690        compile_expr_inner(builder, schema, else_expr, row_ptr)?
691    } else {
692        null_value(builder, result_type)
693    };
694    let else_null = else_val
695        .null_flag
696        .unwrap_or_else(|| builder.ins().iconst(cl_types::I8, 0));
697    builder.ins().jump(
698        merge_block,
699        &[BlockArg::Value(else_val.value), BlockArg::Value(else_null)],
700    );
701
702    builder.switch_to_block(merge_block);
703    builder.seal_block(merge_block);
704
705    let result_value = builder.block_params(merge_block)[0];
706    let result_null = builder.block_params(merge_block)[1];
707
708    Ok(CompiledValue {
709        value: result_value,
710        is_nullable: true,
711        null_flag: Some(result_null),
712        value_type: result_type,
713    })
714}
715
716/// Compiles AND with short-circuit: if LHS is false, skip RHS.
717fn compile_short_circuit_and(
718    builder: &mut FunctionBuilder,
719    schema: &RowSchema,
720    binary: &BinaryExpr,
721    row_ptr: Value,
722) -> Result<CompiledValue, CompileError> {
723    let merge_block = builder.create_block();
724    builder.append_block_param(merge_block, cl_types::I8);
725    let rhs_block = builder.create_block();
726
727    let lhs = compile_expr_inner(builder, schema, &binary.left, row_ptr)?;
728
729    let false_val = builder.ins().iconst(cl_types::I8, 0);
730    builder.ins().brif(
731        lhs.value,
732        rhs_block,
733        &[],
734        merge_block,
735        &[BlockArg::Value(false_val)],
736    );
737
738    builder.switch_to_block(rhs_block);
739    builder.seal_block(rhs_block);
740    let rhs = compile_expr_inner(builder, schema, &binary.right, row_ptr)?;
741    builder
742        .ins()
743        .jump(merge_block, &[BlockArg::Value(rhs.value)]);
744
745    builder.switch_to_block(merge_block);
746    builder.seal_block(merge_block);
747
748    let result = builder.block_params(merge_block)[0];
749    Ok(CompiledValue {
750        value: result,
751        is_nullable: false,
752        null_flag: None,
753        value_type: FieldType::Bool,
754    })
755}
756
757/// Compiles OR with short-circuit: if LHS is true, skip RHS.
758fn compile_short_circuit_or(
759    builder: &mut FunctionBuilder,
760    schema: &RowSchema,
761    binary: &BinaryExpr,
762    row_ptr: Value,
763) -> Result<CompiledValue, CompileError> {
764    let merge_block = builder.create_block();
765    builder.append_block_param(merge_block, cl_types::I8);
766    let rhs_block = builder.create_block();
767
768    let lhs = compile_expr_inner(builder, schema, &binary.left, row_ptr)?;
769
770    let true_val = builder.ins().iconst(cl_types::I8, 1);
771    builder.ins().brif(
772        lhs.value,
773        merge_block,
774        &[BlockArg::Value(true_val)],
775        rhs_block,
776        &[],
777    );
778
779    builder.switch_to_block(rhs_block);
780    builder.seal_block(rhs_block);
781    let rhs = compile_expr_inner(builder, schema, &binary.right, row_ptr)?;
782    builder
783        .ins()
784        .jump(merge_block, &[BlockArg::Value(rhs.value)]);
785
786    builder.switch_to_block(merge_block);
787    builder.seal_block(merge_block);
788
789    let result = builder.block_params(merge_block)[0];
790    Ok(CompiledValue {
791        value: result,
792        is_nullable: false,
793        null_flag: None,
794        value_type: FieldType::Bool,
795    })
796}
797
798/// Loads the null bitmap bit for the given field index.
799pub(crate) fn emit_null_check(
800    builder: &mut FunctionBuilder,
801    schema: &RowSchema,
802    field_idx: usize,
803    row_ptr: Value,
804) -> Value {
805    let layout = schema.field(field_idx);
806    let null_bit = layout.null_bit;
807    let byte_idx = RowSchema::header_size() + null_bit / 8;
808    let bit_idx = null_bit % 8;
809
810    let mem_flags = MemFlags::trusted();
811    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
812    let byte_offset = byte_idx as i32;
813    let byte_val = builder
814        .ins()
815        .load(cl_types::I8, mem_flags, row_ptr, byte_offset);
816    let mask = builder.ins().iconst(cl_types::I8, 1 << bit_idx);
817    builder.ins().band(byte_val, mask)
818}
819
820/// Merges null flags from two operands (SQL null propagation).
821pub(crate) fn merge_null_flags(
822    builder: &mut FunctionBuilder,
823    lhs: &CompiledValue,
824    rhs: &CompiledValue,
825) -> (bool, Option<Value>) {
826    match (lhs.null_flag, rhs.null_flag) {
827        (Some(l), Some(r)) => {
828            let merged = builder.ins().bor(l, r);
829            (true, Some(merged))
830        }
831        (Some(f), None) | (None, Some(f)) => (true, Some(f)),
832        (None, None) => (false, None),
833    }
834}
835
836/// Infers the output [`FieldType`] of an expression without compiling it.
837pub(crate) fn infer_expr_type(schema: &RowSchema, expr: &Expr) -> Result<FieldType, CompileError> {
838    match expr {
839        Expr::Column(col) => {
840            let idx = schema
841                .arrow_schema()
842                .index_of(&col.name)
843                .map_err(|_| CompileError::ColumnNotFound(col.name.clone()))?;
844            Ok(schema.field(idx).field_type)
845        }
846        Expr::Literal(scalar, _) => match scalar {
847            ScalarValue::Boolean(_) => Ok(FieldType::Bool),
848            ScalarValue::Int8(_) => Ok(FieldType::Int8),
849            ScalarValue::Int16(_) => Ok(FieldType::Int16),
850            ScalarValue::Int32(_) => Ok(FieldType::Int32),
851            ScalarValue::Int64(_) => Ok(FieldType::Int64),
852            ScalarValue::UInt8(_) => Ok(FieldType::UInt8),
853            ScalarValue::UInt16(_) => Ok(FieldType::UInt16),
854            ScalarValue::UInt32(_) => Ok(FieldType::UInt32),
855            ScalarValue::UInt64(_) => Ok(FieldType::UInt64),
856            ScalarValue::Float32(_) => Ok(FieldType::Float32),
857            ScalarValue::Float64(_) => Ok(FieldType::Float64),
858            _ => Err(CompileError::UnsupportedLiteral),
859        },
860        Expr::BinaryExpr(binary) => match binary.op {
861            Operator::Eq
862            | Operator::NotEq
863            | Operator::Lt
864            | Operator::LtEq
865            | Operator::Gt
866            | Operator::GtEq
867            | Operator::And
868            | Operator::Or => Ok(FieldType::Bool),
869            _ => infer_expr_type(schema, &binary.left),
870        },
871        Expr::Not(_) | Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(FieldType::Bool),
872        Expr::Cast(cast) => FieldType::from_arrow(&cast.data_type)
873            .ok_or_else(|| CompileError::UnsupportedExpr(format!("CAST to {}", cast.data_type))),
874        Expr::Case(case) => {
875            if let Some((_, then_expr)) = case.when_then_expr.first() {
876                infer_expr_type(schema, then_expr)
877            } else {
878                Err(CompileError::UnsupportedExpr("empty CASE".to_string()))
879            }
880        }
881        other => Err(CompileError::UnsupportedExpr(format!("{other}"))),
882    }
883}
884
885// ---------------------------------------------------------------------------
886// Helper functions
887// ---------------------------------------------------------------------------
888
889/// Maps a [`FieldType`] to the corresponding Cranelift IR type.
890pub(crate) fn field_type_to_cranelift(ft: FieldType) -> CraneliftType {
891    match ft {
892        FieldType::Bool | FieldType::Int8 | FieldType::UInt8 => cl_types::I8,
893        FieldType::Int16 | FieldType::UInt16 => cl_types::I16,
894        FieldType::Int32 | FieldType::UInt32 => cl_types::I32,
895        FieldType::Float32 => cl_types::F32,
896        FieldType::Int64
897        | FieldType::UInt64
898        | FieldType::TimestampMicros
899        | FieldType::Utf8
900        | FieldType::Binary => cl_types::I64,
901        FieldType::Float64 => cl_types::F64,
902    }
903}
904
905/// Returns `true` for signed integer types.
906fn is_signed_type(ft: FieldType) -> bool {
907    matches!(
908        ft,
909        FieldType::Int8
910            | FieldType::Int16
911            | FieldType::Int32
912            | FieldType::Int64
913            | FieldType::TimestampMicros
914    )
915}
916
917/// Maps a comparison [`Operator`] to a Cranelift integer condition code.
918fn int_cmp_cond(op: Operator, signed: bool) -> IntCC {
919    match (op, signed) {
920        (Operator::Eq, _) => IntCC::Equal,
921        (Operator::NotEq, _) => IntCC::NotEqual,
922        (Operator::Lt, true) => IntCC::SignedLessThan,
923        (Operator::Lt, false) => IntCC::UnsignedLessThan,
924        (Operator::LtEq, true) => IntCC::SignedLessThanOrEqual,
925        (Operator::LtEq, false) => IntCC::UnsignedLessThanOrEqual,
926        (Operator::Gt, true) => IntCC::SignedGreaterThan,
927        (Operator::Gt, false) => IntCC::UnsignedGreaterThan,
928        (Operator::GtEq, true) => IntCC::SignedGreaterThanOrEqual,
929        (Operator::GtEq, false) => IntCC::UnsignedGreaterThanOrEqual,
930        _ => unreachable!("int_cmp_cond called with non-comparison operator"),
931    }
932}
933
934/// Maps a comparison [`Operator`] to a Cranelift float condition code.
935fn float_cmp_cond(op: Operator) -> FloatCC {
936    match op {
937        Operator::Eq => FloatCC::Equal,
938        Operator::NotEq => FloatCC::NotEqual,
939        Operator::Lt => FloatCC::LessThan,
940        Operator::LtEq => FloatCC::LessThanOrEqual,
941        Operator::Gt => FloatCC::GreaterThan,
942        Operator::GtEq => FloatCC::GreaterThanOrEqual,
943        _ => unreachable!("float_cmp_cond called with non-comparison operator"),
944    }
945}
946
947#[cfg(test)]
948mod tests {
949    use super::*;
950    use crate::compiler::row::{MutableEventRow, RowSchema};
951    use arrow_schema::{DataType, Field, Schema};
952    use bumpalo::Bump;
953    use datafusion_expr::{col, lit};
954    use std::sync::Arc;
955
956    fn make_schema(fields: Vec<(&str, DataType, bool)>) -> Arc<Schema> {
957        Arc::new(Schema::new(
958            fields
959                .into_iter()
960                .map(|(name, dt, nullable)| Field::new(name, dt, nullable))
961                .collect::<Vec<_>>(),
962        ))
963    }
964
965    fn make_row_bytes<'a>(
966        arena: &'a Bump,
967        schema: &'a RowSchema,
968        values: &[(usize, i64)],
969        nulls: &[usize],
970    ) -> &'a [u8] {
971        let mut row = MutableEventRow::new_in(arena, schema, 0);
972        for &(idx, val) in values {
973            row.set_i64(idx, val);
974        }
975        for &idx in nulls {
976            row.set_null(idx, true);
977        }
978        row.freeze().data()
979    }
980
981    // ---- Filter: column vs literal comparisons ----
982
983    #[test]
984    fn filter_col_gt_literal() {
985        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
986        let rs = RowSchema::from_arrow(&arrow).unwrap();
987        let mut jit = JitContext::new().unwrap();
988        let mut compiler = ExprCompiler::new(&mut jit, &rs);
989
990        let expr = col("val").gt(lit(100_i64));
991        let filter = compiler.compile_filter(&expr).unwrap();
992
993        let arena = Bump::new();
994        let bytes = make_row_bytes(&arena, &rs, &[(0, 200)], &[]);
995        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
996
997        let bytes = make_row_bytes(&arena, &rs, &[(0, 50)], &[]);
998        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
999    }
1000
1001    #[test]
1002    fn filter_col_eq_literal() {
1003        let arrow = make_schema(vec![("id", DataType::Int64, false)]);
1004        let rs = RowSchema::from_arrow(&arrow).unwrap();
1005        let mut jit = JitContext::new().unwrap();
1006        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1007
1008        let expr = col("id").eq(lit(42_i64));
1009        let filter = compiler.compile_filter(&expr).unwrap();
1010
1011        let arena = Bump::new();
1012        let bytes = make_row_bytes(&arena, &rs, &[(0, 42)], &[]);
1013        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1014
1015        let bytes = make_row_bytes(&arena, &rs, &[(0, 43)], &[]);
1016        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1017    }
1018
1019    #[test]
1020    fn filter_col_lt_literal() {
1021        let arrow = make_schema(vec![("x", DataType::Int64, false)]);
1022        let rs = RowSchema::from_arrow(&arrow).unwrap();
1023        let mut jit = JitContext::new().unwrap();
1024        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1025
1026        let expr = col("x").lt(lit(10_i64));
1027        let filter = compiler.compile_filter(&expr).unwrap();
1028
1029        let arena = Bump::new();
1030        let bytes = make_row_bytes(&arena, &rs, &[(0, 5)], &[]);
1031        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1032
1033        let bytes = make_row_bytes(&arena, &rs, &[(0, 15)], &[]);
1034        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1035    }
1036
1037    #[test]
1038    fn filter_all_comparison_ops() {
1039        let arrow = make_schema(vec![("x", DataType::Int64, false)]);
1040        let rs = RowSchema::from_arrow(&arrow).unwrap();
1041        let mut jit = JitContext::new().unwrap();
1042        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1043
1044        let arena = Bump::new();
1045
1046        // lt_eq
1047        let f = compiler
1048            .compile_filter(&col("x").lt_eq(lit(10_i64)))
1049            .unwrap();
1050        let b = make_row_bytes(&arena, &rs, &[(0, 10)], &[]);
1051        assert_eq!(unsafe { f(b.as_ptr()) }, 1);
1052        let b = make_row_bytes(&arena, &rs, &[(0, 11)], &[]);
1053        assert_eq!(unsafe { f(b.as_ptr()) }, 0);
1054
1055        // gt_eq
1056        let f = compiler
1057            .compile_filter(&col("x").gt_eq(lit(10_i64)))
1058            .unwrap();
1059        let b = make_row_bytes(&arena, &rs, &[(0, 10)], &[]);
1060        assert_eq!(unsafe { f(b.as_ptr()) }, 1);
1061        let b = make_row_bytes(&arena, &rs, &[(0, 9)], &[]);
1062        assert_eq!(unsafe { f(b.as_ptr()) }, 0);
1063
1064        // not_eq
1065        let f = compiler
1066            .compile_filter(&col("x").not_eq(lit(42_i64)))
1067            .unwrap();
1068        let b = make_row_bytes(&arena, &rs, &[(0, 43)], &[]);
1069        assert_eq!(unsafe { f(b.as_ptr()) }, 1);
1070        let b = make_row_bytes(&arena, &rs, &[(0, 42)], &[]);
1071        assert_eq!(unsafe { f(b.as_ptr()) }, 0);
1072    }
1073
1074    // ---- Filter: compound boolean ----
1075
1076    #[test]
1077    fn filter_and_compound() {
1078        let arrow = make_schema(vec![
1079            ("a", DataType::Int64, false),
1080            ("b", DataType::Int64, false),
1081        ]);
1082        let rs = RowSchema::from_arrow(&arrow).unwrap();
1083        let mut jit = JitContext::new().unwrap();
1084        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1085
1086        let expr = col("a").gt(lit(1_i64)).and(col("b").lt(lit(10_i64)));
1087        let filter = compiler.compile_filter(&expr).unwrap();
1088
1089        let arena = Bump::new();
1090        let bytes = make_row_bytes(&arena, &rs, &[(0, 5), (1, 3)], &[]);
1091        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1092
1093        let bytes = make_row_bytes(&arena, &rs, &[(0, 0), (1, 3)], &[]);
1094        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1095
1096        let bytes = make_row_bytes(&arena, &rs, &[(0, 5), (1, 20)], &[]);
1097        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1098    }
1099
1100    #[test]
1101    fn filter_or_compound() {
1102        let arrow = make_schema(vec![
1103            ("a", DataType::Int64, false),
1104            ("b", DataType::Int64, false),
1105        ]);
1106        let rs = RowSchema::from_arrow(&arrow).unwrap();
1107        let mut jit = JitContext::new().unwrap();
1108        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1109
1110        let expr = col("a").eq(lit(42_i64)).or(col("b").eq(lit(99_i64)));
1111        let filter = compiler.compile_filter(&expr).unwrap();
1112
1113        let arena = Bump::new();
1114        let bytes = make_row_bytes(&arena, &rs, &[(0, 42), (1, 0)], &[]);
1115        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1116
1117        let bytes = make_row_bytes(&arena, &rs, &[(0, 0), (1, 99)], &[]);
1118        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1119
1120        let bytes = make_row_bytes(&arena, &rs, &[(0, 0), (1, 0)], &[]);
1121        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1122    }
1123
1124    #[test]
1125    fn filter_nested_and_or() {
1126        let arrow = make_schema(vec![
1127            ("a", DataType::Int64, false),
1128            ("b", DataType::Int64, false),
1129            ("c", DataType::Int64, false),
1130        ]);
1131        let rs = RowSchema::from_arrow(&arrow).unwrap();
1132        let mut jit = JitContext::new().unwrap();
1133        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1134
1135        let expr = col("a")
1136            .gt(lit(0_i64))
1137            .and(col("b").gt(lit(0_i64)))
1138            .or(col("c").gt(lit(0_i64)));
1139        let filter = compiler.compile_filter(&expr).unwrap();
1140
1141        let arena = Bump::new();
1142        let bytes = make_row_bytes(&arena, &rs, &[(0, 1), (1, 1), (2, 0)], &[]);
1143        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1144
1145        let bytes = make_row_bytes(&arena, &rs, &[(0, 0), (1, 0), (2, 1)], &[]);
1146        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1147
1148        let bytes = make_row_bytes(&arena, &rs, &[(0, 0), (1, 0), (2, 0)], &[]);
1149        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1150    }
1151
1152    #[test]
1153    fn filter_not() {
1154        let arrow = make_schema(vec![("x", DataType::Int64, false)]);
1155        let rs = RowSchema::from_arrow(&arrow).unwrap();
1156        let mut jit = JitContext::new().unwrap();
1157        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1158
1159        let expr = Expr::Not(Box::new(col("x").gt(lit(10_i64))));
1160        let filter = compiler.compile_filter(&expr).unwrap();
1161
1162        let arena = Bump::new();
1163        let bytes = make_row_bytes(&arena, &rs, &[(0, 5)], &[]);
1164        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1165
1166        let bytes = make_row_bytes(&arena, &rs, &[(0, 20)], &[]);
1167        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1168    }
1169
1170    // ---- Scalar: arithmetic ----
1171
1172    #[test]
1173    fn scalar_add_i64() {
1174        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1175        let rs = RowSchema::from_arrow(&arrow).unwrap();
1176        let mut jit = JitContext::new().unwrap();
1177        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1178
1179        let expr = col("val") + lit(10_i64);
1180        let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();
1181
1182        let arena = Bump::new();
1183        let bytes = make_row_bytes(&arena, &rs, &[(0, 32)], &[]);
1184        let mut output = [0u8; 8];
1185        let is_null = unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) };
1186        assert_eq!(is_null, 0);
1187        assert_eq!(i64::from_le_bytes(output), 42);
1188    }
1189
1190    #[test]
1191    fn scalar_multiply_i64() {
1192        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1193        let rs = RowSchema::from_arrow(&arrow).unwrap();
1194        let mut jit = JitContext::new().unwrap();
1195        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1196
1197        let expr = col("val") * lit(7_i64);
1198        let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();
1199
1200        let arena = Bump::new();
1201        let bytes = make_row_bytes(&arena, &rs, &[(0, 6)], &[]);
1202        let mut output = [0u8; 8];
1203        let is_null = unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) };
1204        assert_eq!(is_null, 0);
1205        assert_eq!(i64::from_le_bytes(output), 42);
1206    }
1207
1208    #[test]
1209    fn scalar_sub_div_mod() {
1210        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1211        let rs = RowSchema::from_arrow(&arrow).unwrap();
1212        let mut jit = JitContext::new().unwrap();
1213        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1214        let arena = Bump::new();
1215        let mut output = [0u8; 8];
1216
1217        // sub
1218        let s = compiler
1219            .compile_scalar(&(col("val") - lit(8_i64)), &FieldType::Int64)
1220            .unwrap();
1221        let b = make_row_bytes(&arena, &rs, &[(0, 50)], &[]);
1222        assert_eq!(unsafe { s(b.as_ptr(), output.as_mut_ptr()) }, 0);
1223        assert_eq!(i64::from_le_bytes(output), 42);
1224
1225        // div
1226        let s = compiler
1227            .compile_scalar(&(col("val") / lit(2_i64)), &FieldType::Int64)
1228            .unwrap();
1229        let b = make_row_bytes(&arena, &rs, &[(0, 84)], &[]);
1230        assert_eq!(unsafe { s(b.as_ptr(), output.as_mut_ptr()) }, 0);
1231        assert_eq!(i64::from_le_bytes(output), 42);
1232
1233        // mod
1234        let s = compiler
1235            .compile_scalar(&(col("val") % lit(10_i64)), &FieldType::Int64)
1236            .unwrap();
1237        let b = make_row_bytes(&arena, &rs, &[(0, 42)], &[]);
1238        assert_eq!(unsafe { s(b.as_ptr(), output.as_mut_ptr()) }, 0);
1239        assert_eq!(i64::from_le_bytes(output), 2);
1240    }
1241
1242    #[test]
1243    fn scalar_f64_add() {
1244        let arrow = make_schema(vec![("val", DataType::Float64, false)]);
1245        let rs = RowSchema::from_arrow(&arrow).unwrap();
1246        let mut jit = JitContext::new().unwrap();
1247        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1248
1249        let expr = col("val") + lit(1.5_f64);
1250        let scalar = compiler.compile_scalar(&expr, &FieldType::Float64).unwrap();
1251
1252        let arena = Bump::new();
1253        let mut row = MutableEventRow::new_in(&arena, &rs, 0);
1254        row.set_f64(0, 2.5);
1255        let frozen = row.freeze();
1256        let mut output = [0u8; 8];
1257        let is_null = unsafe { scalar(frozen.data().as_ptr(), output.as_mut_ptr()) };
1258        assert_eq!(is_null, 0);
1259        assert!((f64::from_le_bytes(output) - 4.0).abs() < f64::EPSILON);
1260    }
1261
1262    #[test]
1263    fn scalar_column_passthrough() {
1264        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1265        let rs = RowSchema::from_arrow(&arrow).unwrap();
1266        let mut jit = JitContext::new().unwrap();
1267        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1268
1269        let scalar = compiler
1270            .compile_scalar(&col("val"), &FieldType::Int64)
1271            .unwrap();
1272
1273        let arena = Bump::new();
1274        let bytes = make_row_bytes(&arena, &rs, &[(0, 999)], &[]);
1275        let mut output = [0u8; 8];
1276        assert_eq!(unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) }, 0);
1277        assert_eq!(i64::from_le_bytes(output), 999);
1278    }
1279
1280    // ---- Null handling ----
1281
1282    #[test]
1283    fn filter_null_column_rejects() {
1284        let arrow = make_schema(vec![("val", DataType::Int64, true)]);
1285        let rs = RowSchema::from_arrow(&arrow).unwrap();
1286        let mut jit = JitContext::new().unwrap();
1287        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1288
1289        let filter = compiler
1290            .compile_filter(&col("val").gt(lit(10_i64)))
1291            .unwrap();
1292
1293        let arena = Bump::new();
1294        let bytes = make_row_bytes(&arena, &rs, &[(0, 999)], &[0]);
1295        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1296    }
1297
1298    #[test]
1299    fn scalar_null_propagation() {
1300        let arrow = make_schema(vec![("val", DataType::Int64, true)]);
1301        let rs = RowSchema::from_arrow(&arrow).unwrap();
1302        let mut jit = JitContext::new().unwrap();
1303        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1304
1305        let scalar = compiler
1306            .compile_scalar(&(col("val") + lit(10_i64)), &FieldType::Int64)
1307            .unwrap();
1308
1309        let arena = Bump::new();
1310        let bytes = make_row_bytes(&arena, &rs, &[(0, 0)], &[0]);
1311        let mut output = [0u8; 8];
1312        assert_ne!(unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) }, 0);
1313    }
1314
1315    #[test]
1316    fn filter_is_null() {
1317        let arrow = make_schema(vec![("val", DataType::Int64, true)]);
1318        let rs = RowSchema::from_arrow(&arrow).unwrap();
1319        let mut jit = JitContext::new().unwrap();
1320        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1321
1322        let filter = compiler.compile_filter(&col("val").is_null()).unwrap();
1323
1324        let arena = Bump::new();
1325        let bytes = make_row_bytes(&arena, &rs, &[(0, 0)], &[0]);
1326        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1327
1328        let bytes = make_row_bytes(&arena, &rs, &[(0, 42)], &[]);
1329        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1330    }
1331
1332    #[test]
1333    fn filter_is_not_null() {
1334        let arrow = make_schema(vec![("val", DataType::Int64, true)]);
1335        let rs = RowSchema::from_arrow(&arrow).unwrap();
1336        let mut jit = JitContext::new().unwrap();
1337        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1338
1339        let filter = compiler.compile_filter(&col("val").is_not_null()).unwrap();
1340
1341        let arena = Bump::new();
1342        let bytes = make_row_bytes(&arena, &rs, &[(0, 42)], &[]);
1343        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1344
1345        let bytes = make_row_bytes(&arena, &rs, &[(0, 0)], &[0]);
1346        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1347    }
1348
1349    #[test]
1350    fn null_literal_produces_null() {
1351        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1352        let rs = RowSchema::from_arrow(&arrow).unwrap();
1353        let mut jit = JitContext::new().unwrap();
1354        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1355
1356        let expr = Expr::BinaryExpr(BinaryExpr::new(
1357            Box::new(col("val")),
1358            Operator::Plus,
1359            Box::new(Expr::Literal(ScalarValue::Int64(None), None)),
1360        ));
1361        let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();
1362
1363        let arena = Bump::new();
1364        let bytes = make_row_bytes(&arena, &rs, &[(0, 42)], &[]);
1365        let mut output = [0u8; 8];
1366        assert_ne!(unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) }, 0);
1367    }
1368
1369    #[test]
1370    fn null_and_true_short_circuits_false() {
1371        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1372        let rs = RowSchema::from_arrow(&arrow).unwrap();
1373        let mut jit = JitContext::new().unwrap();
1374        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1375
1376        let expr = Expr::BinaryExpr(BinaryExpr::new(
1377            Box::new(Expr::Literal(ScalarValue::Boolean(None), None)),
1378            Operator::And,
1379            Box::new(Expr::Literal(ScalarValue::Boolean(Some(true)), None)),
1380        ));
1381        let filter = compiler.compile_filter(&expr).unwrap();
1382
1383        let arena = Bump::new();
1384        let bytes = make_row_bytes(&arena, &rs, &[(0, 0)], &[]);
1385        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1386    }
1387
1388    // ---- CAST ----
1389
1390    #[test]
1391    fn cast_i32_to_i64() {
1392        let arrow = make_schema(vec![("val", DataType::Int32, false)]);
1393        let rs = RowSchema::from_arrow(&arrow).unwrap();
1394        let mut jit = JitContext::new().unwrap();
1395        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1396
1397        let expr = Expr::Cast(datafusion_expr::Cast {
1398            expr: Box::new(col("val")),
1399            data_type: DataType::Int64,
1400        });
1401        let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();
1402
1403        let arena = Bump::new();
1404        let mut row = MutableEventRow::new_in(&arena, &rs, 0);
1405        row.set_i32(0, -42);
1406        let frozen = row.freeze();
1407        let mut output = [0u8; 8];
1408        assert_eq!(
1409            unsafe { scalar(frozen.data().as_ptr(), output.as_mut_ptr()) },
1410            0
1411        );
1412        assert_eq!(i64::from_le_bytes(output), -42);
1413    }
1414
1415    #[test]
1416    fn cast_i64_to_f64() {
1417        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1418        let rs = RowSchema::from_arrow(&arrow).unwrap();
1419        let mut jit = JitContext::new().unwrap();
1420        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1421
1422        let expr = Expr::Cast(datafusion_expr::Cast {
1423            expr: Box::new(col("val")),
1424            data_type: DataType::Float64,
1425        });
1426        let scalar = compiler.compile_scalar(&expr, &FieldType::Float64).unwrap();
1427
1428        let arena = Bump::new();
1429        let bytes = make_row_bytes(&arena, &rs, &[(0, 42)], &[]);
1430        let mut output = [0u8; 8];
1431        assert_eq!(unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) }, 0);
1432        assert!((f64::from_le_bytes(output) - 42.0).abs() < f64::EPSILON);
1433    }
1434
1435    #[test]
1436    fn cast_f64_to_i64() {
1437        let arrow = make_schema(vec![("val", DataType::Float64, false)]);
1438        let rs = RowSchema::from_arrow(&arrow).unwrap();
1439        let mut jit = JitContext::new().unwrap();
1440        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1441
1442        let expr = Expr::Cast(datafusion_expr::Cast {
1443            expr: Box::new(col("val")),
1444            data_type: DataType::Int64,
1445        });
1446        let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();
1447
1448        let arena = Bump::new();
1449        let mut row = MutableEventRow::new_in(&arena, &rs, 0);
1450        row.set_f64(0, 42.9);
1451        let frozen = row.freeze();
1452        let mut output = [0u8; 8];
1453        assert_eq!(
1454            unsafe { scalar(frozen.data().as_ptr(), output.as_mut_ptr()) },
1455            0
1456        );
1457        assert_eq!(i64::from_le_bytes(output), 42);
1458    }
1459
1460    #[test]
1461    fn cast_bool_to_i64() {
1462        let arrow = make_schema(vec![("flag", DataType::Boolean, false)]);
1463        let rs = RowSchema::from_arrow(&arrow).unwrap();
1464        let mut jit = JitContext::new().unwrap();
1465        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1466
1467        let expr = Expr::Cast(datafusion_expr::Cast {
1468            expr: Box::new(col("flag")),
1469            data_type: DataType::Int64,
1470        });
1471        let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();
1472
1473        let arena = Bump::new();
1474        let mut row = MutableEventRow::new_in(&arena, &rs, 0);
1475        row.set_bool(0, true);
1476        let frozen = row.freeze();
1477        let mut output = [0u8; 8];
1478        assert_eq!(
1479            unsafe { scalar(frozen.data().as_ptr(), output.as_mut_ptr()) },
1480            0
1481        );
1482        assert_eq!(i64::from_le_bytes(output), 1);
1483    }
1484
1485    // ---- CASE/WHEN ----
1486
1487    #[test]
1488    fn case_simple() {
1489        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1490        let rs = RowSchema::from_arrow(&arrow).unwrap();
1491        let mut jit = JitContext::new().unwrap();
1492        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1493
1494        let expr = Expr::Case(datafusion_expr::Case {
1495            expr: None,
1496            when_then_expr: vec![(Box::new(col("val").gt(lit(10_i64))), Box::new(lit(1_i64)))],
1497            else_expr: Some(Box::new(lit(0_i64))),
1498        });
1499        let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();
1500
1501        let arena = Bump::new();
1502        let mut output = [0u8; 8];
1503
1504        let bytes = make_row_bytes(&arena, &rs, &[(0, 20)], &[]);
1505        unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) };
1506        assert_eq!(i64::from_le_bytes(output), 1);
1507
1508        let bytes = make_row_bytes(&arena, &rs, &[(0, 5)], &[]);
1509        unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) };
1510        assert_eq!(i64::from_le_bytes(output), 0);
1511    }
1512
1513    #[test]
1514    fn case_multiple_whens() {
1515        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1516        let rs = RowSchema::from_arrow(&arrow).unwrap();
1517        let mut jit = JitContext::new().unwrap();
1518        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1519
1520        let expr = Expr::Case(datafusion_expr::Case {
1521            expr: None,
1522            when_then_expr: vec![
1523                (Box::new(col("val").gt(lit(100_i64))), Box::new(lit(3_i64))),
1524                (Box::new(col("val").gt(lit(10_i64))), Box::new(lit(2_i64))),
1525            ],
1526            else_expr: Some(Box::new(lit(1_i64))),
1527        });
1528        let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();
1529
1530        let arena = Bump::new();
1531        let mut output = [0u8; 8];
1532
1533        let b = make_row_bytes(&arena, &rs, &[(0, 200)], &[]);
1534        unsafe { scalar(b.as_ptr(), output.as_mut_ptr()) };
1535        assert_eq!(i64::from_le_bytes(output), 3);
1536
1537        let b = make_row_bytes(&arena, &rs, &[(0, 50)], &[]);
1538        unsafe { scalar(b.as_ptr(), output.as_mut_ptr()) };
1539        assert_eq!(i64::from_le_bytes(output), 2);
1540
1541        let b = make_row_bytes(&arena, &rs, &[(0, 5)], &[]);
1542        unsafe { scalar(b.as_ptr(), output.as_mut_ptr()) };
1543        assert_eq!(i64::from_le_bytes(output), 1);
1544    }
1545
1546    #[test]
1547    fn case_no_else_returns_null() {
1548        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1549        let rs = RowSchema::from_arrow(&arrow).unwrap();
1550        let mut jit = JitContext::new().unwrap();
1551        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1552
1553        let expr = Expr::Case(datafusion_expr::Case {
1554            expr: None,
1555            when_then_expr: vec![(Box::new(col("val").gt(lit(100_i64))), Box::new(lit(1_i64)))],
1556            else_expr: None,
1557        });
1558        let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();
1559
1560        let arena = Bump::new();
1561        let mut output = [0u8; 8];
1562        let bytes = make_row_bytes(&arena, &rs, &[(0, 5)], &[]);
1563        let is_null = unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) };
1564        assert_ne!(is_null, 0);
1565    }
1566
1567    // ---- Fallback / error cases ----
1568
1569    #[test]
1570    fn unsupported_expr_error() {
1571        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1572        let rs = RowSchema::from_arrow(&arrow).unwrap();
1573        let mut jit = JitContext::new().unwrap();
1574        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1575
1576        let expr = Expr::Unnest(datafusion_expr::expr::Unnest::new(col("val")));
1577        assert!(compiler.compile_filter(&expr).is_err());
1578    }
1579
1580    #[test]
1581    fn column_not_found_error() {
1582        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1583        let rs = RowSchema::from_arrow(&arrow).unwrap();
1584        let mut jit = JitContext::new().unwrap();
1585        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1586
1587        let result = compiler.compile_filter(&col("nonexistent").gt(lit(1_i64)));
1588        assert!(matches!(result, Err(CompileError::ColumnNotFound(_))));
1589    }
1590
1591    // ---- Float comparison ----
1592
1593    #[test]
1594    fn filter_f64_comparison() {
1595        let arrow = make_schema(vec![("price", DataType::Float64, false)]);
1596        let rs = RowSchema::from_arrow(&arrow).unwrap();
1597        let mut jit = JitContext::new().unwrap();
1598        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1599
1600        let filter = compiler
1601            .compile_filter(&col("price").gt(lit(100.0_f64)))
1602            .unwrap();
1603
1604        let arena = Bump::new();
1605        let mut row = MutableEventRow::new_in(&arena, &rs, 0);
1606        row.set_f64(0, 150.0);
1607        assert_eq!(unsafe { filter(row.freeze().data().as_ptr()) }, 1);
1608
1609        let mut row = MutableEventRow::new_in(&arena, &rs, 0);
1610        row.set_f64(0, 50.0);
1611        assert_eq!(unsafe { filter(row.freeze().data().as_ptr()) }, 0);
1612    }
1613
1614    // ---- Multi-column ----
1615
1616    #[test]
1617    fn scalar_two_columns_add() {
1618        let arrow = make_schema(vec![
1619            ("a", DataType::Int64, false),
1620            ("b", DataType::Int64, false),
1621        ]);
1622        let rs = RowSchema::from_arrow(&arrow).unwrap();
1623        let mut jit = JitContext::new().unwrap();
1624        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1625
1626        let scalar = compiler
1627            .compile_scalar(&(col("a") + col("b")), &FieldType::Int64)
1628            .unwrap();
1629
1630        let arena = Bump::new();
1631        let bytes = make_row_bytes(&arena, &rs, &[(0, 30), (1, 12)], &[]);
1632        let mut output = [0u8; 8];
1633        unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) };
1634        assert_eq!(i64::from_le_bytes(output), 42);
1635    }
1636
1637    // ---- Multiple compilations ----
1638
1639    #[test]
1640    fn multiple_functions_same_context() {
1641        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1642        let rs = RowSchema::from_arrow(&arrow).unwrap();
1643        let mut jit = JitContext::new().unwrap();
1644        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1645
1646        let f1 = compiler
1647            .compile_filter(&col("val").gt(lit(10_i64)))
1648            .unwrap();
1649        let f2 = compiler
1650            .compile_filter(&col("val").lt(lit(100_i64)))
1651            .unwrap();
1652
1653        let arena = Bump::new();
1654        let bytes = make_row_bytes(&arena, &rs, &[(0, 50)], &[]);
1655        assert_eq!(unsafe { f1(bytes.as_ptr()) }, 1);
1656        assert_eq!(unsafe { f2(bytes.as_ptr()) }, 1);
1657
1658        let bytes = make_row_bytes(&arena, &rs, &[(0, 5)], &[]);
1659        assert_eq!(unsafe { f1(bytes.as_ptr()) }, 0);
1660        assert_eq!(unsafe { f2(bytes.as_ptr()) }, 1);
1661    }
1662
1663    // ---- Literal-only ----
1664
1665    #[test]
1666    fn filter_literal_true_false() {
1667        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1668        let rs = RowSchema::from_arrow(&arrow).unwrap();
1669        let mut jit = JitContext::new().unwrap();
1670        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1671        let arena = Bump::new();
1672        let bytes = make_row_bytes(&arena, &rs, &[(0, 0)], &[]);
1673
1674        let ft = compiler.compile_filter(&lit(true)).unwrap();
1675        assert_eq!(unsafe { ft(bytes.as_ptr()) }, 1);
1676
1677        let ff = compiler.compile_filter(&lit(false)).unwrap();
1678        assert_eq!(unsafe { ff(bytes.as_ptr()) }, 0);
1679    }
1680
1681    // ---- Constant folding integration ----
1682
1683    #[test]
1684    fn constant_folding_in_compile() {
1685        let arrow = make_schema(vec![("val", DataType::Int64, false)]);
1686        let rs = RowSchema::from_arrow(&arrow).unwrap();
1687        let mut jit = JitContext::new().unwrap();
1688        let mut compiler = ExprCompiler::new(&mut jit, &rs);
1689
1690        let expr = col("val").gt(lit(2_i64) + lit(3_i64));
1691        let filter = compiler.compile_filter(&expr).unwrap();
1692
1693        let arena = Bump::new();
1694        let bytes = make_row_bytes(&arena, &rs, &[(0, 10)], &[]);
1695        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 1);
1696
1697        let bytes = make_row_bytes(&arena, &rs, &[(0, 3)], &[]);
1698        assert_eq!(unsafe { filter(bytes.as_ptr()) }, 0);
1699    }
1700}