1use cranelift_codegen::ir::condcodes::{FloatCC, IntCC};
25use cranelift_codegen::ir::types::{self as cl_types};
26use cranelift_codegen::ir::{
27 AbiParam, BlockArg, Function, InstBuilder, MemFlags, Type as CraneliftType, UserFuncName, Value,
28};
29use cranelift_codegen::Context;
30use cranelift_frontend::FunctionBuilder;
31use cranelift_module::Module;
32use datafusion_common::ScalarValue;
33use datafusion_expr::{BinaryExpr, Expr, Operator};
34
35use super::error::{CompileError, FilterFn, ScalarFn};
36use super::fold::fold_constants;
37use super::jit::JitContext;
38use super::row::{FieldType, RowSchema};
39
/// Cranelift type used for the pointer parameters (row pointer and output pointer) of the
/// generated functions.
// NOTE(review): hard-codes a 64-bit pointer width — confirm the JIT only targets 64-bit hosts.
const PTR_TYPE: CraneliftType = cl_types::I64;
42
/// Result of compiling one sub-expression: the SSA value holding the result plus the
/// bookkeeping needed to track SQL NULL-ness alongside it.
pub(crate) struct CompiledValue {
    /// Cranelift SSA value carrying the computed result.
    pub(crate) value: Value,
    /// Whether the expression may evaluate to NULL.
    pub(crate) is_nullable: bool,
    /// Optional `i8` SSA value describing NULL-ness at runtime. Consumers in this file treat
    /// any non-zero value as "NULL"; `None` means no null flag was produced.
    pub(crate) null_flag: Option<Value>,
    /// Logical field type of `value`; it decides the Cranelift type and signedness used by
    /// later operations.
    pub(crate) value_type: FieldType,
}
54
/// Compiles DataFusion expressions into native functions that read their inputs from a row
/// laid out according to a [`RowSchema`].
pub struct ExprCompiler<'a> {
    /// JIT context that names, declares, defines and finalizes the generated functions.
    jit: &'a mut JitContext,
    /// Row layout used to resolve column names to offsets and null bits.
    schema: &'a RowSchema,
}
64
65impl<'a> ExprCompiler<'a> {
66 pub fn new(jit: &'a mut JitContext, schema: &'a RowSchema) -> Self {
68 Self { jit, schema }
69 }
70
71 pub fn compile_filter(&mut self, expr: &Expr) -> Result<FilterFn, CompileError> {
86 let expr = fold_constants(expr);
87 let func_name = self.jit.next_func_name("filter");
88
89 let mut sig = self.jit.module().make_signature();
90 sig.params.push(AbiParam::new(PTR_TYPE));
91 sig.returns.push(AbiParam::new(cl_types::I8));
92
93 let func_id = self.jit.module().declare_function(
94 &func_name,
95 cranelift_module::Linkage::Local,
96 &sig,
97 )?;
98
99 let mut func = Function::with_name_signature(UserFuncName::testcase(&func_name), sig);
100
101 {
102 let builder_ctx = self.jit.builder_ctx();
103 let mut builder = FunctionBuilder::new(&mut func, builder_ctx);
104 let entry = builder.create_block();
105 builder.append_block_params_for_function_params(entry);
106 builder.switch_to_block(entry);
107 builder.seal_block(entry);
108
109 let row_ptr = builder.block_params(entry)[0];
110 let compiled = compile_expr_inner(&mut builder, self.schema, &expr, row_ptr)?;
111
112 let result = if compiled.is_nullable {
114 if let Some(null_flag) = compiled.null_flag {
115 let zero = builder.ins().iconst(cl_types::I8, 0);
116 let is_null = builder.ins().icmp_imm(IntCC::NotEqual, null_flag, 0);
117 builder.ins().select(is_null, zero, compiled.value)
118 } else {
119 compiled.value
120 }
121 } else {
122 compiled.value
123 };
124
125 builder.ins().return_(&[result]);
126 builder.finalize();
127 }
128
129 let mut ctx = Context::for_function(func);
130 self.jit
131 .module()
132 .define_function(func_id, &mut ctx)
133 .map_err(|e| CompileError::Cranelift(Box::new(e)))?;
134 self.jit.module().finalize_definitions().unwrap();
135
136 let code_ptr = self.jit.module().get_finalized_function(func_id);
137 Ok(unsafe { std::mem::transmute::<*const u8, FilterFn>(code_ptr) })
142 }
143
144 pub fn compile_scalar(
159 &mut self,
160 expr: &Expr,
161 output_type: &FieldType,
162 ) -> Result<ScalarFn, CompileError> {
163 let expr = fold_constants(expr);
164 let func_name = self.jit.next_func_name("scalar");
165
166 let mut sig = self.jit.module().make_signature();
167 sig.params.push(AbiParam::new(PTR_TYPE));
168 sig.params.push(AbiParam::new(PTR_TYPE));
169 sig.returns.push(AbiParam::new(cl_types::I8));
170
171 let func_id = self.jit.module().declare_function(
172 &func_name,
173 cranelift_module::Linkage::Local,
174 &sig,
175 )?;
176
177 let mut func = Function::with_name_signature(UserFuncName::testcase(&func_name), sig);
178
179 {
180 let builder_ctx = self.jit.builder_ctx();
181 let mut builder = FunctionBuilder::new(&mut func, builder_ctx);
182 let entry = builder.create_block();
183 builder.append_block_params_for_function_params(entry);
184 builder.switch_to_block(entry);
185 builder.seal_block(entry);
186
187 let row_ptr = builder.block_params(entry)[0];
188 let out_ptr = builder.block_params(entry)[1];
189
190 let compiled = compile_expr_inner(&mut builder, self.schema, &expr, row_ptr)?;
191
192 let mem_flags = MemFlags::trusted();
194 builder.ins().store(mem_flags, compiled.value, out_ptr, 0);
195
196 let is_null = if compiled.is_nullable {
198 compiled
199 .null_flag
200 .unwrap_or_else(|| builder.ins().iconst(cl_types::I8, 0))
201 } else {
202 builder.ins().iconst(cl_types::I8, 0)
203 };
204
205 let is_null_i8 = if builder.func.dfg.value_type(is_null) == cl_types::I8 {
206 is_null
207 } else {
208 builder.ins().ireduce(cl_types::I8, is_null)
209 };
210
211 _ = output_type;
212 builder.ins().return_(&[is_null_i8]);
213 builder.finalize();
214 }
215
216 let mut ctx = Context::for_function(func);
217 self.jit
218 .module()
219 .define_function(func_id, &mut ctx)
220 .map_err(|e| CompileError::Cranelift(Box::new(e)))?;
221 self.jit.module().finalize_definitions().unwrap();
222
223 let code_ptr = self.jit.module().get_finalized_function(func_id);
224 Ok(unsafe { std::mem::transmute::<*const u8, ScalarFn>(code_ptr) })
229 }
230}
231
232pub(crate) fn compile_expr_inner(
238 builder: &mut FunctionBuilder,
239 schema: &RowSchema,
240 expr: &Expr,
241 row_ptr: Value,
242) -> Result<CompiledValue, CompileError> {
243 match expr {
244 Expr::Column(col) => compile_column(builder, schema, &col.name, row_ptr),
245 Expr::Literal(scalar, _) => compile_literal(builder, scalar),
246 Expr::BinaryExpr(binary) => compile_binary(builder, schema, binary, row_ptr),
247 Expr::Not(inner) => compile_not(builder, schema, inner, row_ptr),
248 Expr::IsNull(inner) => compile_is_null(builder, schema, inner, row_ptr, false),
249 Expr::IsNotNull(inner) => compile_is_null(builder, schema, inner, row_ptr, true),
250 Expr::Cast(cast) => compile_cast(builder, schema, &cast.expr, &cast.data_type, row_ptr),
251 Expr::Case(case) => compile_case(builder, schema, case, row_ptr),
252 other => Err(CompileError::UnsupportedExpr(format!("{other}"))),
253 }
254}
255
256fn compile_column(
258 builder: &mut FunctionBuilder,
259 schema: &RowSchema,
260 name: &str,
261 row_ptr: Value,
262) -> Result<CompiledValue, CompileError> {
263 let field_idx = schema
264 .arrow_schema()
265 .index_of(name)
266 .map_err(|_| CompileError::ColumnNotFound(name.to_string()))?;
267
268 let layout = schema.field(field_idx);
269 let cl_type = field_type_to_cranelift(layout.field_type);
270 let mem_flags = MemFlags::trusted();
271
272 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
273 let offset = layout.offset as i32;
274 let value = builder.ins().load(cl_type, mem_flags, row_ptr, offset);
275
276 let arrow_field = &schema.arrow_schema().fields()[field_idx];
277 let is_nullable = arrow_field.is_nullable();
278
279 let null_flag = if is_nullable {
280 Some(emit_null_check(builder, schema, field_idx, row_ptr))
281 } else {
282 None
283 };
284
285 Ok(CompiledValue {
286 value,
287 is_nullable,
288 null_flag,
289 value_type: layout.field_type,
290 })
291}
292
293fn compile_literal(
295 builder: &mut FunctionBuilder,
296 scalar: &ScalarValue,
297) -> Result<CompiledValue, CompileError> {
298 match scalar {
299 ScalarValue::Boolean(Some(b)) => Ok(CompiledValue {
300 value: builder.ins().iconst(cl_types::I8, i64::from(*b)),
301 is_nullable: false,
302 null_flag: None,
303 value_type: FieldType::Bool,
304 }),
305 ScalarValue::Int8(Some(v)) => Ok(CompiledValue {
306 value: builder.ins().iconst(cl_types::I8, i64::from(*v)),
307 is_nullable: false,
308 null_flag: None,
309 value_type: FieldType::Int8,
310 }),
311 ScalarValue::Int16(Some(v)) => Ok(CompiledValue {
312 value: builder.ins().iconst(cl_types::I16, i64::from(*v)),
313 is_nullable: false,
314 null_flag: None,
315 value_type: FieldType::Int16,
316 }),
317 ScalarValue::Int32(Some(v)) => Ok(CompiledValue {
318 value: builder.ins().iconst(cl_types::I32, i64::from(*v)),
319 is_nullable: false,
320 null_flag: None,
321 value_type: FieldType::Int32,
322 }),
323 ScalarValue::Int64(Some(v)) => Ok(CompiledValue {
324 value: builder.ins().iconst(cl_types::I64, *v),
325 is_nullable: false,
326 null_flag: None,
327 value_type: FieldType::Int64,
328 }),
329 ScalarValue::UInt8(Some(v)) => Ok(CompiledValue {
330 value: builder.ins().iconst(cl_types::I8, i64::from(*v)),
331 is_nullable: false,
332 null_flag: None,
333 value_type: FieldType::UInt8,
334 }),
335 ScalarValue::UInt16(Some(v)) => Ok(CompiledValue {
336 value: builder.ins().iconst(cl_types::I16, i64::from(*v)),
337 is_nullable: false,
338 null_flag: None,
339 value_type: FieldType::UInt16,
340 }),
341 ScalarValue::UInt32(Some(v)) => Ok(CompiledValue {
342 value: builder.ins().iconst(cl_types::I32, i64::from(*v)),
343 is_nullable: false,
344 null_flag: None,
345 value_type: FieldType::UInt32,
346 }),
347 ScalarValue::UInt64(Some(v)) => Ok(CompiledValue {
348 #[allow(clippy::cast_possible_wrap)]
349 value: builder.ins().iconst(cl_types::I64, (*v).cast_signed()),
350 is_nullable: false,
351 null_flag: None,
352 value_type: FieldType::UInt64,
353 }),
354 ScalarValue::Float32(Some(v)) => Ok(CompiledValue {
355 value: builder.ins().f32const(*v),
356 is_nullable: false,
357 null_flag: None,
358 value_type: FieldType::Float32,
359 }),
360 ScalarValue::Float64(Some(v)) => Ok(CompiledValue {
361 value: builder.ins().f64const(*v),
362 is_nullable: false,
363 null_flag: None,
364 value_type: FieldType::Float64,
365 }),
366 ScalarValue::Boolean(None) => Ok(null_value(builder, FieldType::Bool)),
368 ScalarValue::Int64(None) => Ok(null_value(builder, FieldType::Int64)),
369 ScalarValue::Float64(None) => Ok(null_value(builder, FieldType::Float64)),
370 ScalarValue::Int32(None) => Ok(null_value(builder, FieldType::Int32)),
371 _ => Err(CompileError::UnsupportedLiteral),
372 }
373}
374
375pub(crate) fn null_value(builder: &mut FunctionBuilder, field_type: FieldType) -> CompiledValue {
377 let cl_type = field_type_to_cranelift(field_type);
378 let value = if cl_type.is_float() {
379 if cl_type == cl_types::F32 {
380 builder.ins().f32const(0.0)
381 } else {
382 builder.ins().f64const(0.0)
383 }
384 } else {
385 builder.ins().iconst(cl_type, 0)
386 };
387 let null_flag = builder.ins().iconst(cl_types::I8, 1);
388 CompiledValue {
389 value,
390 is_nullable: true,
391 null_flag: Some(null_flag),
392 value_type: field_type,
393 }
394}
395
396fn compile_binary(
398 builder: &mut FunctionBuilder,
399 schema: &RowSchema,
400 binary: &BinaryExpr,
401 row_ptr: Value,
402) -> Result<CompiledValue, CompileError> {
403 match binary.op {
404 Operator::And => return compile_short_circuit_and(builder, schema, binary, row_ptr),
405 Operator::Or => return compile_short_circuit_or(builder, schema, binary, row_ptr),
406 _ => {}
407 }
408
409 let lhs = compile_expr_inner(builder, schema, &binary.left, row_ptr)?;
410 let rhs = compile_expr_inner(builder, schema, &binary.right, row_ptr)?;
411
412 let (is_nullable, merged_null) = merge_null_flags(builder, &lhs, &rhs);
413 let result_value = emit_binary_op(builder, &lhs, &rhs, binary.op)?;
414
415 Ok(CompiledValue {
416 value: result_value.value,
417 is_nullable,
418 null_flag: merged_null,
419 value_type: result_value.value_type,
420 })
421}
422
423fn emit_binary_op(
425 builder: &mut FunctionBuilder,
426 lhs: &CompiledValue,
427 rhs: &CompiledValue,
428 op: Operator,
429) -> Result<CompiledValue, CompileError> {
430 let lhs_type = lhs.value_type;
431 let cl_type = field_type_to_cranelift(lhs_type);
432
433 if cl_type.is_int() {
434 return emit_int_binary(builder, lhs, rhs, op);
435 }
436 if cl_type.is_float() {
437 return emit_float_binary(builder, lhs, rhs, op);
438 }
439
440 Err(CompileError::UnsupportedBinaryOp(lhs_type, op))
441}
442
443fn emit_int_binary(
445 builder: &mut FunctionBuilder,
446 lhs: &CompiledValue,
447 rhs: &CompiledValue,
448 op: Operator,
449) -> Result<CompiledValue, CompileError> {
450 let lhs_type = lhs.value_type;
451 let is_signed = is_signed_type(lhs_type);
452
453 let value = match op {
454 Operator::Plus => builder.ins().iadd(lhs.value, rhs.value),
455 Operator::Minus => builder.ins().isub(lhs.value, rhs.value),
456 Operator::Multiply => builder.ins().imul(lhs.value, rhs.value),
457 Operator::Divide => {
458 if is_signed {
459 builder.ins().sdiv(lhs.value, rhs.value)
460 } else {
461 builder.ins().udiv(lhs.value, rhs.value)
462 }
463 }
464 Operator::Modulo => {
465 if is_signed {
466 builder.ins().srem(lhs.value, rhs.value)
467 } else {
468 builder.ins().urem(lhs.value, rhs.value)
469 }
470 }
471 Operator::Eq
472 | Operator::NotEq
473 | Operator::Lt
474 | Operator::LtEq
475 | Operator::Gt
476 | Operator::GtEq => {
477 let cc = int_cmp_cond(op, is_signed);
478 let cmp_val = builder.ins().icmp(cc, lhs.value, rhs.value);
479 return Ok(CompiledValue {
480 value: cmp_val,
481 is_nullable: false,
482 null_flag: None,
483 value_type: FieldType::Bool,
484 });
485 }
486 _ => return Err(CompileError::UnsupportedBinaryOp(lhs_type, op)),
487 };
488
489 Ok(CompiledValue {
490 value,
491 is_nullable: false,
492 null_flag: None,
493 value_type: lhs_type,
494 })
495}
496
497fn emit_float_binary(
499 builder: &mut FunctionBuilder,
500 lhs: &CompiledValue,
501 rhs: &CompiledValue,
502 op: Operator,
503) -> Result<CompiledValue, CompileError> {
504 let lhs_type = lhs.value_type;
505
506 let value = match op {
507 Operator::Plus => builder.ins().fadd(lhs.value, rhs.value),
508 Operator::Minus => builder.ins().fsub(lhs.value, rhs.value),
509 Operator::Multiply => builder.ins().fmul(lhs.value, rhs.value),
510 Operator::Divide => builder.ins().fdiv(lhs.value, rhs.value),
511 Operator::Eq
512 | Operator::NotEq
513 | Operator::Lt
514 | Operator::LtEq
515 | Operator::Gt
516 | Operator::GtEq => {
517 let cc = float_cmp_cond(op);
518 let cmp_val = builder.ins().fcmp(cc, lhs.value, rhs.value);
519 return Ok(CompiledValue {
520 value: cmp_val,
521 is_nullable: false,
522 null_flag: None,
523 value_type: FieldType::Bool,
524 });
525 }
526 _ => return Err(CompileError::UnsupportedBinaryOp(lhs_type, op)),
527 };
528
529 Ok(CompiledValue {
530 value,
531 is_nullable: false,
532 null_flag: None,
533 value_type: lhs_type,
534 })
535}
536
537fn compile_not(
539 builder: &mut FunctionBuilder,
540 schema: &RowSchema,
541 inner: &Expr,
542 row_ptr: Value,
543) -> Result<CompiledValue, CompileError> {
544 let compiled = compile_expr_inner(builder, schema, inner, row_ptr)?;
545 let one = builder.ins().iconst(cl_types::I8, 1);
546 let flipped = builder.ins().bxor(compiled.value, one);
547 Ok(CompiledValue {
548 value: flipped,
549 is_nullable: compiled.is_nullable,
550 null_flag: compiled.null_flag,
551 value_type: FieldType::Bool,
552 })
553}
554
555fn compile_is_null(
557 builder: &mut FunctionBuilder,
558 schema: &RowSchema,
559 inner: &Expr,
560 row_ptr: Value,
561 invert: bool,
562) -> Result<CompiledValue, CompileError> {
563 let compiled = compile_expr_inner(builder, schema, inner, row_ptr)?;
564
565 let result = if let Some(nf) = compiled.null_flag {
566 if invert {
567 builder.ins().icmp_imm(IntCC::Equal, nf, 0)
568 } else {
569 builder.ins().icmp_imm(IntCC::NotEqual, nf, 0)
570 }
571 } else {
572 builder.ins().iconst(cl_types::I8, i64::from(invert))
573 };
574
575 Ok(CompiledValue {
576 value: result,
577 is_nullable: false,
578 null_flag: None,
579 value_type: FieldType::Bool,
580 })
581}
582
583fn compile_cast(
585 builder: &mut FunctionBuilder,
586 schema: &RowSchema,
587 inner: &Expr,
588 target_dt: &arrow_schema::DataType,
589 row_ptr: Value,
590) -> Result<CompiledValue, CompileError> {
591 let compiled = compile_expr_inner(builder, schema, inner, row_ptr)?;
592 let out_type = FieldType::from_arrow(target_dt)
593 .ok_or_else(|| CompileError::UnsupportedExpr(format!("CAST to {target_dt}")))?;
594 let target_cl = field_type_to_cranelift(out_type);
595 let source_cl = field_type_to_cranelift(compiled.value_type);
596
597 if source_cl == target_cl {
598 return Ok(CompiledValue {
599 value_type: out_type,
600 ..compiled
601 });
602 }
603
604 let value = if source_cl.is_int() && target_cl.is_int() {
605 if target_cl.bits() > source_cl.bits() {
606 if is_signed_type(compiled.value_type) {
607 builder.ins().sextend(target_cl, compiled.value)
608 } else {
609 builder.ins().uextend(target_cl, compiled.value)
610 }
611 } else {
612 builder.ins().ireduce(target_cl, compiled.value)
613 }
614 } else if source_cl.is_int() && target_cl.is_float() {
615 if is_signed_type(compiled.value_type) {
616 builder.ins().fcvt_from_sint(target_cl, compiled.value)
617 } else {
618 builder.ins().fcvt_from_uint(target_cl, compiled.value)
619 }
620 } else if source_cl.is_float() && target_cl.is_int() {
621 if is_signed_type(out_type) {
622 builder.ins().fcvt_to_sint(target_cl, compiled.value)
623 } else {
624 builder.ins().fcvt_to_uint(target_cl, compiled.value)
625 }
626 } else if source_cl.is_float() && target_cl.is_float() {
627 if target_cl.bits() > source_cl.bits() {
628 builder.ins().fpromote(target_cl, compiled.value)
629 } else {
630 builder.ins().fdemote(target_cl, compiled.value)
631 }
632 } else {
633 return Err(CompileError::UnsupportedExpr(format!(
634 "CAST from {source_cl} to {target_cl}"
635 )));
636 };
637
638 Ok(CompiledValue {
639 value,
640 is_nullable: compiled.is_nullable,
641 null_flag: compiled.null_flag,
642 value_type: out_type,
643 })
644}
645
646fn compile_case(
648 builder: &mut FunctionBuilder,
649 schema: &RowSchema,
650 case: &datafusion_expr::Case,
651 row_ptr: Value,
652) -> Result<CompiledValue, CompileError> {
653 let merge_block = builder.create_block();
654
655 if case.when_then_expr.is_empty() {
656 return Err(CompileError::UnsupportedExpr("empty CASE".to_string()));
657 }
658
659 let first_then = &case.when_then_expr[0].1;
660 let result_type = infer_expr_type(schema, first_then)?;
661 let cl_type = field_type_to_cranelift(result_type);
662 builder.append_block_param(merge_block, cl_type);
663 builder.append_block_param(merge_block, cl_types::I8);
664
665 for (when_expr, then_expr) in &case.when_then_expr {
666 let then_block = builder.create_block();
667 let else_block = builder.create_block();
668
669 let cond = compile_expr_inner(builder, schema, when_expr, row_ptr)?;
670 builder
671 .ins()
672 .brif(cond.value, then_block, &[], else_block, &[]);
673
674 builder.switch_to_block(then_block);
675 builder.seal_block(then_block);
676 let then_val = compile_expr_inner(builder, schema, then_expr, row_ptr)?;
677 let then_null = then_val
678 .null_flag
679 .unwrap_or_else(|| builder.ins().iconst(cl_types::I8, 0));
680 builder.ins().jump(
681 merge_block,
682 &[BlockArg::Value(then_val.value), BlockArg::Value(then_null)],
683 );
684
685 builder.switch_to_block(else_block);
686 builder.seal_block(else_block);
687 }
688
689 let else_val = if let Some(else_expr) = &case.else_expr {
690 compile_expr_inner(builder, schema, else_expr, row_ptr)?
691 } else {
692 null_value(builder, result_type)
693 };
694 let else_null = else_val
695 .null_flag
696 .unwrap_or_else(|| builder.ins().iconst(cl_types::I8, 0));
697 builder.ins().jump(
698 merge_block,
699 &[BlockArg::Value(else_val.value), BlockArg::Value(else_null)],
700 );
701
702 builder.switch_to_block(merge_block);
703 builder.seal_block(merge_block);
704
705 let result_value = builder.block_params(merge_block)[0];
706 let result_null = builder.block_params(merge_block)[1];
707
708 Ok(CompiledValue {
709 value: result_value,
710 is_nullable: true,
711 null_flag: Some(result_null),
712 value_type: result_type,
713 })
714}
715
716fn compile_short_circuit_and(
718 builder: &mut FunctionBuilder,
719 schema: &RowSchema,
720 binary: &BinaryExpr,
721 row_ptr: Value,
722) -> Result<CompiledValue, CompileError> {
723 let merge_block = builder.create_block();
724 builder.append_block_param(merge_block, cl_types::I8);
725 let rhs_block = builder.create_block();
726
727 let lhs = compile_expr_inner(builder, schema, &binary.left, row_ptr)?;
728
729 let false_val = builder.ins().iconst(cl_types::I8, 0);
730 builder.ins().brif(
731 lhs.value,
732 rhs_block,
733 &[],
734 merge_block,
735 &[BlockArg::Value(false_val)],
736 );
737
738 builder.switch_to_block(rhs_block);
739 builder.seal_block(rhs_block);
740 let rhs = compile_expr_inner(builder, schema, &binary.right, row_ptr)?;
741 builder
742 .ins()
743 .jump(merge_block, &[BlockArg::Value(rhs.value)]);
744
745 builder.switch_to_block(merge_block);
746 builder.seal_block(merge_block);
747
748 let result = builder.block_params(merge_block)[0];
749 Ok(CompiledValue {
750 value: result,
751 is_nullable: false,
752 null_flag: None,
753 value_type: FieldType::Bool,
754 })
755}
756
757fn compile_short_circuit_or(
759 builder: &mut FunctionBuilder,
760 schema: &RowSchema,
761 binary: &BinaryExpr,
762 row_ptr: Value,
763) -> Result<CompiledValue, CompileError> {
764 let merge_block = builder.create_block();
765 builder.append_block_param(merge_block, cl_types::I8);
766 let rhs_block = builder.create_block();
767
768 let lhs = compile_expr_inner(builder, schema, &binary.left, row_ptr)?;
769
770 let true_val = builder.ins().iconst(cl_types::I8, 1);
771 builder.ins().brif(
772 lhs.value,
773 merge_block,
774 &[BlockArg::Value(true_val)],
775 rhs_block,
776 &[],
777 );
778
779 builder.switch_to_block(rhs_block);
780 builder.seal_block(rhs_block);
781 let rhs = compile_expr_inner(builder, schema, &binary.right, row_ptr)?;
782 builder
783 .ins()
784 .jump(merge_block, &[BlockArg::Value(rhs.value)]);
785
786 builder.switch_to_block(merge_block);
787 builder.seal_block(merge_block);
788
789 let result = builder.block_params(merge_block)[0];
790 Ok(CompiledValue {
791 value: result,
792 is_nullable: false,
793 null_flag: None,
794 value_type: FieldType::Bool,
795 })
796}
797
798pub(crate) fn emit_null_check(
800 builder: &mut FunctionBuilder,
801 schema: &RowSchema,
802 field_idx: usize,
803 row_ptr: Value,
804) -> Value {
805 let layout = schema.field(field_idx);
806 let null_bit = layout.null_bit;
807 let byte_idx = RowSchema::header_size() + null_bit / 8;
808 let bit_idx = null_bit % 8;
809
810 let mem_flags = MemFlags::trusted();
811 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
812 let byte_offset = byte_idx as i32;
813 let byte_val = builder
814 .ins()
815 .load(cl_types::I8, mem_flags, row_ptr, byte_offset);
816 let mask = builder.ins().iconst(cl_types::I8, 1 << bit_idx);
817 builder.ins().band(byte_val, mask)
818}
819
820pub(crate) fn merge_null_flags(
822 builder: &mut FunctionBuilder,
823 lhs: &CompiledValue,
824 rhs: &CompiledValue,
825) -> (bool, Option<Value>) {
826 match (lhs.null_flag, rhs.null_flag) {
827 (Some(l), Some(r)) => {
828 let merged = builder.ins().bor(l, r);
829 (true, Some(merged))
830 }
831 (Some(f), None) | (None, Some(f)) => (true, Some(f)),
832 (None, None) => (false, None),
833 }
834}
835
836pub(crate) fn infer_expr_type(schema: &RowSchema, expr: &Expr) -> Result<FieldType, CompileError> {
838 match expr {
839 Expr::Column(col) => {
840 let idx = schema
841 .arrow_schema()
842 .index_of(&col.name)
843 .map_err(|_| CompileError::ColumnNotFound(col.name.clone()))?;
844 Ok(schema.field(idx).field_type)
845 }
846 Expr::Literal(scalar, _) => match scalar {
847 ScalarValue::Boolean(_) => Ok(FieldType::Bool),
848 ScalarValue::Int8(_) => Ok(FieldType::Int8),
849 ScalarValue::Int16(_) => Ok(FieldType::Int16),
850 ScalarValue::Int32(_) => Ok(FieldType::Int32),
851 ScalarValue::Int64(_) => Ok(FieldType::Int64),
852 ScalarValue::UInt8(_) => Ok(FieldType::UInt8),
853 ScalarValue::UInt16(_) => Ok(FieldType::UInt16),
854 ScalarValue::UInt32(_) => Ok(FieldType::UInt32),
855 ScalarValue::UInt64(_) => Ok(FieldType::UInt64),
856 ScalarValue::Float32(_) => Ok(FieldType::Float32),
857 ScalarValue::Float64(_) => Ok(FieldType::Float64),
858 _ => Err(CompileError::UnsupportedLiteral),
859 },
860 Expr::BinaryExpr(binary) => match binary.op {
861 Operator::Eq
862 | Operator::NotEq
863 | Operator::Lt
864 | Operator::LtEq
865 | Operator::Gt
866 | Operator::GtEq
867 | Operator::And
868 | Operator::Or => Ok(FieldType::Bool),
869 _ => infer_expr_type(schema, &binary.left),
870 },
871 Expr::Not(_) | Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(FieldType::Bool),
872 Expr::Cast(cast) => FieldType::from_arrow(&cast.data_type)
873 .ok_or_else(|| CompileError::UnsupportedExpr(format!("CAST to {}", cast.data_type))),
874 Expr::Case(case) => {
875 if let Some((_, then_expr)) = case.when_then_expr.first() {
876 infer_expr_type(schema, then_expr)
877 } else {
878 Err(CompileError::UnsupportedExpr("empty CASE".to_string()))
879 }
880 }
881 other => Err(CompileError::UnsupportedExpr(format!("{other}"))),
882 }
883}
884
885pub(crate) fn field_type_to_cranelift(ft: FieldType) -> CraneliftType {
891 match ft {
892 FieldType::Bool | FieldType::Int8 | FieldType::UInt8 => cl_types::I8,
893 FieldType::Int16 | FieldType::UInt16 => cl_types::I16,
894 FieldType::Int32 | FieldType::UInt32 => cl_types::I32,
895 FieldType::Float32 => cl_types::F32,
896 FieldType::Int64
897 | FieldType::UInt64
898 | FieldType::TimestampMicros
899 | FieldType::Utf8
900 | FieldType::Binary => cl_types::I64,
901 FieldType::Float64 => cl_types::F64,
902 }
903}
904
905fn is_signed_type(ft: FieldType) -> bool {
907 matches!(
908 ft,
909 FieldType::Int8
910 | FieldType::Int16
911 | FieldType::Int32
912 | FieldType::Int64
913 | FieldType::TimestampMicros
914 )
915}
916
917fn int_cmp_cond(op: Operator, signed: bool) -> IntCC {
919 match (op, signed) {
920 (Operator::Eq, _) => IntCC::Equal,
921 (Operator::NotEq, _) => IntCC::NotEqual,
922 (Operator::Lt, true) => IntCC::SignedLessThan,
923 (Operator::Lt, false) => IntCC::UnsignedLessThan,
924 (Operator::LtEq, true) => IntCC::SignedLessThanOrEqual,
925 (Operator::LtEq, false) => IntCC::UnsignedLessThanOrEqual,
926 (Operator::Gt, true) => IntCC::SignedGreaterThan,
927 (Operator::Gt, false) => IntCC::UnsignedGreaterThan,
928 (Operator::GtEq, true) => IntCC::SignedGreaterThanOrEqual,
929 (Operator::GtEq, false) => IntCC::UnsignedGreaterThanOrEqual,
930 _ => unreachable!("int_cmp_cond called with non-comparison operator"),
931 }
932}
933
934fn float_cmp_cond(op: Operator) -> FloatCC {
936 match op {
937 Operator::Eq => FloatCC::Equal,
938 Operator::NotEq => FloatCC::NotEqual,
939 Operator::Lt => FloatCC::LessThan,
940 Operator::LtEq => FloatCC::LessThanOrEqual,
941 Operator::Gt => FloatCC::GreaterThan,
942 Operator::GtEq => FloatCC::GreaterThanOrEqual,
943 _ => unreachable!("float_cmp_cond called with non-comparison operator"),
944 }
945}
946
947#[cfg(test)]
948mod tests {
949 use super::*;
950 use crate::compiler::row::{MutableEventRow, RowSchema};
951 use arrow_schema::{DataType, Field, Schema};
952 use bumpalo::Bump;
953 use datafusion_expr::{col, lit};
954 use std::sync::Arc;
955
956 fn make_schema(fields: Vec<(&str, DataType, bool)>) -> Arc<Schema> {
957 Arc::new(Schema::new(
958 fields
959 .into_iter()
960 .map(|(name, dt, nullable)| Field::new(name, dt, nullable))
961 .collect::<Vec<_>>(),
962 ))
963 }
964
965 fn make_row_bytes<'a>(
966 arena: &'a Bump,
967 schema: &'a RowSchema,
968 values: &[(usize, i64)],
969 nulls: &[usize],
970 ) -> &'a [u8] {
971 let mut row = MutableEventRow::new_in(arena, schema, 0);
972 for &(idx, val) in values {
973 row.set_i64(idx, val);
974 }
975 for &idx in nulls {
976 row.set_null(idx, true);
977 }
978 row.freeze().data()
979 }
980
/// `val > 100` accepts a larger value and rejects a smaller one.
#[test]
fn filter_col_gt_literal() {
    let arrow = make_schema(vec![("val", DataType::Int64, false)]);
    let rs = RowSchema::from_arrow(&arrow).unwrap();
    let mut jit = JitContext::new().unwrap();
    let mut compiler = ExprCompiler::new(&mut jit, &rs);

    let filter = compiler
        .compile_filter(&col("val").gt(lit(100_i64)))
        .unwrap();

    let arena = Bump::new();
    for (input, expected) in [(200_i64, 1), (50_i64, 0)] {
        let bytes = make_row_bytes(&arena, &rs, &[(0, input)], &[]);
        // SAFETY: `bytes` is a row built for `rs`, the schema the filter was compiled
        // against, and the JIT context that produced `filter` is still alive.
        assert_eq!(unsafe { filter(bytes.as_ptr()) }, expected);
    }
}
1000
/// `id = 42` accepts exactly 42.
#[test]
fn filter_col_eq_literal() {
    let arrow = make_schema(vec![("id", DataType::Int64, false)]);
    let rs = RowSchema::from_arrow(&arrow).unwrap();
    let mut jit = JitContext::new().unwrap();
    let mut compiler = ExprCompiler::new(&mut jit, &rs);

    let filter = compiler
        .compile_filter(&col("id").eq(lit(42_i64)))
        .unwrap();

    let arena = Bump::new();
    for (input, expected) in [(42_i64, 1), (43_i64, 0)] {
        let bytes = make_row_bytes(&arena, &rs, &[(0, input)], &[]);
        // SAFETY: `bytes` is a row built for `rs`, the schema the filter was compiled
        // against, and the JIT context that produced `filter` is still alive.
        assert_eq!(unsafe { filter(bytes.as_ptr()) }, expected);
    }
}
1018
/// `x < 10` accepts a smaller value and rejects a larger one.
#[test]
fn filter_col_lt_literal() {
    let arrow = make_schema(vec![("x", DataType::Int64, false)]);
    let rs = RowSchema::from_arrow(&arrow).unwrap();
    let mut jit = JitContext::new().unwrap();
    let mut compiler = ExprCompiler::new(&mut jit, &rs);

    let filter = compiler
        .compile_filter(&col("x").lt(lit(10_i64)))
        .unwrap();

    let arena = Bump::new();
    for (input, expected) in [(5_i64, 1), (15_i64, 0)] {
        let bytes = make_row_bytes(&arena, &rs, &[(0, input)], &[]);
        // SAFETY: `bytes` is a row built for `rs`, the schema the filter was compiled
        // against, and the JIT context that produced `filter` is still alive.
        assert_eq!(unsafe { filter(bytes.as_ptr()) }, expected);
    }
}
1036
/// Exercises `<=`, `>=` and `!=`, each with one accepted and one rejected input.
#[test]
fn filter_all_comparison_ops() {
    let arrow = make_schema(vec![("x", DataType::Int64, false)]);
    let rs = RowSchema::from_arrow(&arrow).unwrap();
    let mut jit = JitContext::new().unwrap();
    let mut compiler = ExprCompiler::new(&mut jit, &rs);

    let arena = Bump::new();

    // (predicate, input that passes, input that is rejected)
    let cases = [
        (col("x").lt_eq(lit(10_i64)), 10_i64, 11_i64),
        (col("x").gt_eq(lit(10_i64)), 10_i64, 9_i64),
        (col("x").not_eq(lit(42_i64)), 43_i64, 42_i64),
    ];

    for (predicate, accepted, rejected) in cases {
        let f = compiler.compile_filter(&predicate).unwrap();

        let b = make_row_bytes(&arena, &rs, &[(0, accepted)], &[]);
        // SAFETY: `b` is a row built for `rs`, the schema the filter was compiled against,
        // and the JIT context that produced `f` is still alive.
        assert_eq!(unsafe { f(b.as_ptr()) }, 1);

        let b = make_row_bytes(&arena, &rs, &[(0, rejected)], &[]);
        // SAFETY: same as above.
        assert_eq!(unsafe { f(b.as_ptr()) }, 0);
    }
}
1073
/// `a > 1 AND b < 10` holds only when both sides hold.
#[test]
fn filter_and_compound() {
    let arrow = make_schema(vec![
        ("a", DataType::Int64, false),
        ("b", DataType::Int64, false),
    ]);
    let rs = RowSchema::from_arrow(&arrow).unwrap();
    let mut jit = JitContext::new().unwrap();
    let mut compiler = ExprCompiler::new(&mut jit, &rs);

    let predicate = col("a").gt(lit(1_i64)).and(col("b").lt(lit(10_i64)));
    let filter = compiler.compile_filter(&predicate).unwrap();

    let arena = Bump::new();
    let eval = |a: i64, b: i64| {
        let bytes = make_row_bytes(&arena, &rs, &[(0, a), (1, b)], &[]);
        // SAFETY: `bytes` is a row built for `rs`, the schema the filter was compiled
        // against, and the JIT context that produced `filter` is still alive.
        unsafe { filter(bytes.as_ptr()) }
    };

    assert_eq!(eval(5, 3), 1);
    assert_eq!(eval(0, 3), 0);
    assert_eq!(eval(5, 20), 0);
}
1099
/// `a = 42 OR b = 99` holds when either side holds.
#[test]
fn filter_or_compound() {
    let arrow = make_schema(vec![
        ("a", DataType::Int64, false),
        ("b", DataType::Int64, false),
    ]);
    let rs = RowSchema::from_arrow(&arrow).unwrap();
    let mut jit = JitContext::new().unwrap();
    let mut compiler = ExprCompiler::new(&mut jit, &rs);

    let predicate = col("a").eq(lit(42_i64)).or(col("b").eq(lit(99_i64)));
    let filter = compiler.compile_filter(&predicate).unwrap();

    let arena = Bump::new();
    let eval = |a: i64, b: i64| {
        let bytes = make_row_bytes(&arena, &rs, &[(0, a), (1, b)], &[]);
        // SAFETY: `bytes` is a row built for `rs`, the schema the filter was compiled
        // against, and the JIT context that produced `filter` is still alive.
        unsafe { filter(bytes.as_ptr()) }
    };

    assert_eq!(eval(42, 0), 1);
    assert_eq!(eval(0, 99), 1);
    assert_eq!(eval(0, 0), 0);
}
1123
/// Checks `(a > 0 AND b > 0) OR c > 0` against rows that hit the AND arm,
/// the OR arm, and neither.
#[test]
fn filter_nested_and_or() {
    let arrow_schema = make_schema(vec![
        ("a", DataType::Int64, false),
        ("b", DataType::Int64, false),
        ("c", DataType::Int64, false),
    ]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let a_and_b = col("a").gt(lit(0_i64)).and(col("b").gt(lit(0_i64)));
    let nested = a_and_b.or(col("c").gt(lit(0_i64)));
    let predicate = expr_compiler.compile_filter(&nested).unwrap();

    let arena = Bump::new();
    // (a, b, c, expected filter result)
    let cases = [(1, 1, 0, 1), (0, 0, 1, 1), (0, 0, 0, 0)];
    for (a, b, c, expected) in cases {
        let row = make_row_bytes(&arena, &layout, &[(0, a), (1, b), (2, c)], &[]);
        // SAFETY: the row was built against `layout`, the schema the predicate
        // was compiled for, and it outlives the call.
        assert_eq!(unsafe { predicate(row.as_ptr()) }, expected);
    }
}
1151
/// Checks that `NOT (x > 10)` inverts the inner comparison.
#[test]
fn filter_not() {
    let arrow_schema = make_schema(vec![("x", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let inner = col("x").gt(lit(10_i64));
    let negated = Expr::Not(Box::new(inner));
    let predicate = expr_compiler.compile_filter(&negated).unwrap();

    let arena = Bump::new();

    // 5 is not greater than 10, so the negation passes.
    let small = make_row_bytes(&arena, &layout, &[(0, 5)], &[]);
    // SAFETY: the row was built against `layout`, the schema the predicate
    // was compiled for, and it outlives the call.
    assert_eq!(unsafe { predicate(small.as_ptr()) }, 1);

    // 20 is greater than 10, so the negation rejects.
    let large = make_row_bytes(&arena, &layout, &[(0, 20)], &[]);
    // SAFETY: same layout and lifetime argument as above.
    assert_eq!(unsafe { predicate(large.as_ptr()) }, 0);
}
1169
/// Compiles `val + 10` as an Int64 scalar and checks it turns 32 into 42
/// with a zero null flag.
#[test]
fn scalar_add_i64() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let add_ten = expr_compiler
        .compile_scalar(&(col("val") + lit(10_i64)), &FieldType::Int64)
        .unwrap();

    let arena = Bump::new();
    let mut out = [0u8; 8];
    let row = make_row_bytes(&arena, &layout, &[(0, 32)], &[]);
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Int64 result.
    let null_flag = unsafe { add_ten(row.as_ptr(), out.as_mut_ptr()) };

    assert_eq!(null_flag, 0);
    assert_eq!(i64::from_le_bytes(out), 42);
}
1189
/// Compiles `val * 7` as an Int64 scalar and checks it turns 6 into 42
/// with a zero null flag.
#[test]
fn scalar_multiply_i64() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let product = col("val") * lit(7_i64);
    let times_seven = expr_compiler
        .compile_scalar(&product, &FieldType::Int64)
        .unwrap();

    let arena = Bump::new();
    let mut out = [0u8; 8];
    let row = make_row_bytes(&arena, &layout, &[(0, 6)], &[]);
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Int64 result.
    let null_flag = unsafe { times_seven(row.as_ptr(), out.as_mut_ptr()) };

    assert_eq!(null_flag, 0);
    let value = i64::from_le_bytes(out);
    assert_eq!(value, 42);
}
1207
/// Exercises subtraction, division and modulo on an Int64 column, compiling
/// each expression through the same compiler and reusing one output buffer.
#[test]
fn scalar_sub_div_mod() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);
    let arena = Bump::new();
    let mut out = [0u8; 8];

    // (expression, input value for `val`, expected result)
    let cases = [
        (col("val") - lit(8_i64), 50, 42),
        (col("val") / lit(2_i64), 84, 42),
        (col("val") % lit(10_i64), 42, 2),
    ];

    for (expr, input, expected) in cases {
        let compiled = expr_compiler
            .compile_scalar(&expr, &FieldType::Int64)
            .unwrap();
        let row = make_row_bytes(&arena, &layout, &[(0, input)], &[]);
        // SAFETY: the row was built against `layout`, the schema the function
        // was compiled for, and `out` is a live 8-byte buffer for the result.
        assert_eq!(unsafe { compiled(row.as_ptr(), out.as_mut_ptr()) }, 0);
        assert_eq!(i64::from_le_bytes(out), expected);
    }
}
1241
/// Compiles `val + 1.5` as a Float64 scalar and checks that 2.5 becomes 4.0.
#[test]
fn scalar_f64_add() {
    let arrow_schema = make_schema(vec![("val", DataType::Float64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let add_expr = col("val") + lit(1.5_f64);
    let compiled = expr_compiler
        .compile_scalar(&add_expr, &FieldType::Float64)
        .unwrap();

    // The float column is written through the row builder directly.
    let arena = Bump::new();
    let mut builder = MutableEventRow::new_in(&arena, &layout, 0);
    builder.set_f64(0, 2.5);
    let event_row = builder.freeze();

    let mut out = [0u8; 8];
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Float64 result.
    let null_flag = unsafe { compiled(event_row.data().as_ptr(), out.as_mut_ptr()) };
    assert_eq!(null_flag, 0);

    let sum = f64::from_le_bytes(out);
    assert!((sum - 4.0).abs() < f64::EPSILON);
}
1261
/// A bare column reference compiled as a scalar copies the column value
/// to the output unchanged.
#[test]
fn scalar_column_passthrough() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let column_only = col("val");
    let compiled = expr_compiler
        .compile_scalar(&column_only, &FieldType::Int64)
        .unwrap();

    let arena = Bump::new();
    let mut out = [0u8; 8];
    let row = make_row_bytes(&arena, &layout, &[(0, 999)], &[]);
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Int64 result.
    let null_flag = unsafe { compiled(row.as_ptr(), out.as_mut_ptr()) };

    assert_eq!(null_flag, 0);
    assert_eq!(i64::from_le_bytes(out), 999);
}
1279
/// A comparison on a nullable column must reject the row when the column is
/// null, even though the stored payload (999) would satisfy `val > 10`.
#[test]
fn filter_null_column_rejects() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, true)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let above_ten = col("val").gt(lit(10_i64));
    let predicate = expr_compiler.compile_filter(&above_ten).unwrap();

    let arena = Bump::new();
    // The trailing `&[0]` presumably marks field 0 as null (helper is defined
    // outside this view).
    let null_row = make_row_bytes(&arena, &layout, &[(0, 999)], &[0]);
    // SAFETY: the row was built against `layout`, the schema the predicate
    // was compiled for, and it outlives the call.
    let verdict = unsafe { predicate(null_row.as_ptr()) };
    assert_eq!(verdict, 0);
}
1297
/// Arithmetic on a null input must report null: `val + 10` returns a
/// non-zero null flag when `val` is null.
#[test]
fn scalar_null_propagation() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, true)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let add_expr = col("val") + lit(10_i64);
    let compiled = expr_compiler
        .compile_scalar(&add_expr, &FieldType::Int64)
        .unwrap();

    let arena = Bump::new();
    let mut out = [0u8; 8];
    // The trailing `&[0]` presumably marks field 0 as null (helper is defined
    // outside this view).
    let null_row = make_row_bytes(&arena, &layout, &[(0, 0)], &[0]);
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Int64 result.
    let null_flag = unsafe { compiled(null_row.as_ptr(), out.as_mut_ptr()) };
    assert_ne!(null_flag, 0);
}
1314
/// `val IS NULL` passes for a null field and rejects a populated one.
#[test]
fn filter_is_null() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, true)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let is_null_expr = col("val").is_null();
    let predicate = expr_compiler.compile_filter(&is_null_expr).unwrap();

    let arena = Bump::new();

    // Field 0 flagged through the helper's last argument (presumably the null list).
    let null_row = make_row_bytes(&arena, &layout, &[(0, 0)], &[0]);
    // SAFETY: the row was built against `layout`, the schema the predicate
    // was compiled for, and it outlives the call.
    assert_eq!(unsafe { predicate(null_row.as_ptr()) }, 1);

    // Same field populated with 42 and nothing flagged.
    let set_row = make_row_bytes(&arena, &layout, &[(0, 42)], &[]);
    // SAFETY: same layout and lifetime argument as above.
    assert_eq!(unsafe { predicate(set_row.as_ptr()) }, 0);
}
1331
/// `val IS NOT NULL` passes for a populated field and rejects a null one.
#[test]
fn filter_is_not_null() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, true)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let not_null_expr = col("val").is_not_null();
    let predicate = expr_compiler.compile_filter(&not_null_expr).unwrap();

    let arena = Bump::new();

    // Field populated with 42 and nothing flagged.
    let set_row = make_row_bytes(&arena, &layout, &[(0, 42)], &[]);
    // SAFETY: the row was built against `layout`, the schema the predicate
    // was compiled for, and it outlives the call.
    assert_eq!(unsafe { predicate(set_row.as_ptr()) }, 1);

    // Field 0 flagged through the helper's last argument (presumably the null list).
    let null_row = make_row_bytes(&arena, &layout, &[(0, 0)], &[0]);
    // SAFETY: same layout and lifetime argument as above.
    assert_eq!(unsafe { predicate(null_row.as_ptr()) }, 0);
}
1348
/// Adding a typed NULL literal to a non-null column yields a null result:
/// the compiled scalar returns a non-zero null flag.
#[test]
fn null_literal_produces_null() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    // `val + NULL`, built by hand so the right operand is an Int64 literal
    // with no value.
    let null_i64 = Expr::Literal(ScalarValue::Int64(None), None);
    let sum = BinaryExpr::new(Box::new(col("val")), Operator::Plus, Box::new(null_i64));
    let compiled = expr_compiler
        .compile_scalar(&Expr::BinaryExpr(sum), &FieldType::Int64)
        .unwrap();

    let arena = Bump::new();
    let mut out = [0u8; 8];
    let row = make_row_bytes(&arena, &layout, &[(0, 42)], &[]);
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Int64 result.
    let null_flag = unsafe { compiled(row.as_ptr(), out.as_mut_ptr()) };
    assert_ne!(null_flag, 0);
}
1368
/// `NULL AND TRUE` used as a filter must reject the row (filter returns 0).
#[test]
fn null_and_true_short_circuits_false() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    // Both operands are literals; the left one is a Boolean with no value.
    let null_bool = Expr::Literal(ScalarValue::Boolean(None), None);
    let true_bool = Expr::Literal(ScalarValue::Boolean(Some(true)), None);
    let conjunction = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(null_bool),
        Operator::And,
        Box::new(true_bool),
    ));
    let predicate = expr_compiler.compile_filter(&conjunction).unwrap();

    let arena = Bump::new();
    let row = make_row_bytes(&arena, &layout, &[(0, 0)], &[]);
    // SAFETY: the row was built against `layout`, the schema the predicate
    // was compiled for, and it outlives the call.
    let verdict = unsafe { predicate(row.as_ptr()) };
    assert_eq!(verdict, 0);
}
1387
/// Casting an Int32 column to Int64 preserves the sign: -42 stays -42.
#[test]
fn cast_i32_to_i64() {
    let arrow_schema = make_schema(vec![("val", DataType::Int32, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let widen = Expr::Cast(datafusion_expr::Cast {
        data_type: DataType::Int64,
        expr: Box::new(col("val")),
    });
    let compiled = expr_compiler
        .compile_scalar(&widen, &FieldType::Int64)
        .unwrap();

    // Write the 32-bit value through the row builder.
    let arena = Bump::new();
    let mut builder = MutableEventRow::new_in(&arena, &layout, 0);
    builder.set_i32(0, -42);
    let event_row = builder.freeze();

    let mut out = [0u8; 8];
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Int64 result.
    let null_flag = unsafe { compiled(event_row.data().as_ptr(), out.as_mut_ptr()) };
    assert_eq!(null_flag, 0);
    assert_eq!(i64::from_le_bytes(out), -42);
}
1414
/// Casting an Int64 column to Float64 converts 42 to 42.0.
#[test]
fn cast_i64_to_f64() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let to_float = Expr::Cast(datafusion_expr::Cast {
        data_type: DataType::Float64,
        expr: Box::new(col("val")),
    });
    let compiled = expr_compiler
        .compile_scalar(&to_float, &FieldType::Float64)
        .unwrap();

    let arena = Bump::new();
    let mut out = [0u8; 8];
    let row = make_row_bytes(&arena, &layout, &[(0, 42)], &[]);
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Float64 result.
    let null_flag = unsafe { compiled(row.as_ptr(), out.as_mut_ptr()) };
    assert_eq!(null_flag, 0);

    let converted = f64::from_le_bytes(out);
    assert!((converted - 42.0).abs() < f64::EPSILON);
}
1434
/// Casting a Float64 column to Int64 drops the fractional part: 42.9 becomes 42.
#[test]
fn cast_f64_to_i64() {
    let arrow_schema = make_schema(vec![("val", DataType::Float64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let to_int = Expr::Cast(datafusion_expr::Cast {
        data_type: DataType::Int64,
        expr: Box::new(col("val")),
    });
    let compiled = expr_compiler
        .compile_scalar(&to_int, &FieldType::Int64)
        .unwrap();

    // Write the float value through the row builder.
    let arena = Bump::new();
    let mut builder = MutableEventRow::new_in(&arena, &layout, 0);
    builder.set_f64(0, 42.9);
    let event_row = builder.freeze();

    let mut out = [0u8; 8];
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Int64 result.
    let null_flag = unsafe { compiled(event_row.data().as_ptr(), out.as_mut_ptr()) };
    assert_eq!(null_flag, 0);
    assert_eq!(i64::from_le_bytes(out), 42);
}
1459
/// Casting a Boolean column to Int64 maps `true` to 1.
#[test]
fn cast_bool_to_i64() {
    let arrow_schema = make_schema(vec![("flag", DataType::Boolean, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let to_int = Expr::Cast(datafusion_expr::Cast {
        data_type: DataType::Int64,
        expr: Box::new(col("flag")),
    });
    let compiled = expr_compiler
        .compile_scalar(&to_int, &FieldType::Int64)
        .unwrap();

    // Write the boolean through the row builder.
    let arena = Bump::new();
    let mut builder = MutableEventRow::new_in(&arena, &layout, 0);
    builder.set_bool(0, true);
    let event_row = builder.freeze();

    let mut out = [0u8; 8];
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Int64 result.
    let null_flag = unsafe { compiled(event_row.data().as_ptr(), out.as_mut_ptr()) };
    assert_eq!(null_flag, 0);
    assert_eq!(i64::from_le_bytes(out), 1);
}
1484
/// `CASE WHEN val > 10 THEN 1 ELSE 0 END` selects the matching branch.
///
/// Both branches are non-null literals, so besides the written value the
/// returned null flag is asserted to be 0. Without that check a CASE that
/// wrongly reports NULL while still writing the value would go unnoticed.
#[test]
fn case_simple() {
    let arrow = make_schema(vec![("val", DataType::Int64, false)]);
    let rs = RowSchema::from_arrow(&arrow).unwrap();
    let mut jit = JitContext::new().unwrap();
    let mut compiler = ExprCompiler::new(&mut jit, &rs);

    let expr = Expr::Case(datafusion_expr::Case {
        expr: None,
        when_then_expr: vec![(Box::new(col("val").gt(lit(10_i64))), Box::new(lit(1_i64)))],
        else_expr: Some(Box::new(lit(0_i64))),
    });
    let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();

    let arena = Bump::new();
    let mut output = [0u8; 8];

    // WHEN branch taken: 20 > 10.
    let bytes = make_row_bytes(&arena, &rs, &[(0, 20)], &[]);
    // SAFETY: the row was built against `rs`, the schema the function was
    // compiled for, and `output` is a live 8-byte buffer for the Int64 result.
    let is_null = unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) };
    assert_eq!(is_null, 0);
    assert_eq!(i64::from_le_bytes(output), 1);

    // ELSE branch taken: 5 is not > 10.
    let bytes = make_row_bytes(&arena, &rs, &[(0, 5)], &[]);
    // SAFETY: same layout and buffer argument as above.
    let is_null = unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) };
    assert_eq!(is_null, 0);
    assert_eq!(i64::from_le_bytes(output), 0);
}
1512
/// A searched CASE with two WHEN arms and an ELSE picks the first arm whose
/// condition holds, falling through to ELSE otherwise.
///
/// Every arm is a non-null literal, so the returned null flag is asserted to
/// be 0 on each call in addition to the written value.
#[test]
fn case_multiple_whens() {
    let arrow = make_schema(vec![("val", DataType::Int64, false)]);
    let rs = RowSchema::from_arrow(&arrow).unwrap();
    let mut jit = JitContext::new().unwrap();
    let mut compiler = ExprCompiler::new(&mut jit, &rs);

    let expr = Expr::Case(datafusion_expr::Case {
        expr: None,
        when_then_expr: vec![
            (Box::new(col("val").gt(lit(100_i64))), Box::new(lit(3_i64))),
            (Box::new(col("val").gt(lit(10_i64))), Box::new(lit(2_i64))),
        ],
        else_expr: Some(Box::new(lit(1_i64))),
    });
    let scalar = compiler.compile_scalar(&expr, &FieldType::Int64).unwrap();

    let arena = Bump::new();
    let mut output = [0u8; 8];

    // First arm: 200 > 100.
    let b = make_row_bytes(&arena, &rs, &[(0, 200)], &[]);
    // SAFETY: the row was built against `rs`, the schema the function was
    // compiled for, and `output` is a live 8-byte buffer for the Int64 result.
    let is_null = unsafe { scalar(b.as_ptr(), output.as_mut_ptr()) };
    assert_eq!(is_null, 0);
    assert_eq!(i64::from_le_bytes(output), 3);

    // Second arm: 50 is not > 100 but is > 10.
    let b = make_row_bytes(&arena, &rs, &[(0, 50)], &[]);
    // SAFETY: same layout and buffer argument as above.
    let is_null = unsafe { scalar(b.as_ptr(), output.as_mut_ptr()) };
    assert_eq!(is_null, 0);
    assert_eq!(i64::from_le_bytes(output), 2);

    // ELSE: 5 satisfies neither condition.
    let b = make_row_bytes(&arena, &rs, &[(0, 5)], &[]);
    // SAFETY: same layout and buffer argument as above.
    let is_null = unsafe { scalar(b.as_ptr(), output.as_mut_ptr()) };
    assert_eq!(is_null, 0);
    assert_eq!(i64::from_le_bytes(output), 1);
}
1545
/// A CASE without an ELSE arm reports null (non-zero flag) when no WHEN
/// condition matches.
#[test]
fn case_no_else_returns_null() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let only_arm = (Box::new(col("val").gt(lit(100_i64))), Box::new(lit(1_i64)));
    let case_expr = Expr::Case(datafusion_expr::Case {
        expr: None,
        when_then_expr: vec![only_arm],
        else_expr: None,
    });
    let compiled = expr_compiler
        .compile_scalar(&case_expr, &FieldType::Int64)
        .unwrap();

    let arena = Bump::new();
    let mut out = [0u8; 8];
    // 5 is not > 100, so no arm applies.
    let row = make_row_bytes(&arena, &layout, &[(0, 5)], &[]);
    // SAFETY: the row was built against `layout`, the schema the function was
    // compiled for, and `out` is a live 8-byte buffer for the Int64 result.
    let null_flag = unsafe { compiled(row.as_ptr(), out.as_mut_ptr()) };
    assert_ne!(null_flag, 0);
}
1566
/// Compiling an `Unnest` expression as a filter must return an error.
#[test]
fn unsupported_expr_error() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let unnest = datafusion_expr::expr::Unnest::new(col("val"));
    let outcome = expr_compiler.compile_filter(&Expr::Unnest(unnest));
    assert!(outcome.is_err());
}
1579
/// Referencing a column missing from the schema yields
/// `CompileError::ColumnNotFound`.
#[test]
fn column_not_found_error() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let missing_column = col("nonexistent").gt(lit(1_i64));
    let outcome = expr_compiler.compile_filter(&missing_column);
    assert!(matches!(outcome, Err(CompileError::ColumnNotFound(_))));
}
1590
/// `price > 100.0` on a Float64 column accepts 150.0 and rejects 50.0.
#[test]
fn filter_f64_comparison() {
    let arrow_schema = make_schema(vec![("price", DataType::Float64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let above_hundred = col("price").gt(lit(100.0_f64));
    let predicate = expr_compiler.compile_filter(&above_hundred).unwrap();

    let arena = Bump::new();
    // Builds a one-column row holding `price` and runs the compiled predicate on it.
    let eval = |price| {
        let mut builder = MutableEventRow::new_in(&arena, &layout, 0);
        builder.set_f64(0, price);
        let event_row = builder.freeze();
        // SAFETY: the row was built against `layout`, the schema the predicate
        // was compiled for, and `event_row` is alive for the whole call.
        unsafe { predicate(event_row.data().as_ptr()) }
    };

    assert_eq!(eval(150.0), 1);
    assert_eq!(eval(50.0), 0);
}
1613
/// `a + b` over two non-nullable Int64 columns yields their sum.
///
/// The returned null flag is asserted to be 0 as well, matching the other
/// non-null scalar tests; it was previously discarded.
#[test]
fn scalar_two_columns_add() {
    let arrow = make_schema(vec![
        ("a", DataType::Int64, false),
        ("b", DataType::Int64, false),
    ]);
    let rs = RowSchema::from_arrow(&arrow).unwrap();
    let mut jit = JitContext::new().unwrap();
    let mut compiler = ExprCompiler::new(&mut jit, &rs);

    let scalar = compiler
        .compile_scalar(&(col("a") + col("b")), &FieldType::Int64)
        .unwrap();

    let arena = Bump::new();
    let bytes = make_row_bytes(&arena, &rs, &[(0, 30), (1, 12)], &[]);
    let mut output = [0u8; 8];
    // SAFETY: the row was built against `rs`, the schema the function was
    // compiled for, and `output` is a live 8-byte buffer for the Int64 result.
    let is_null = unsafe { scalar(bytes.as_ptr(), output.as_mut_ptr()) };
    assert_eq!(is_null, 0);
    assert_eq!(i64::from_le_bytes(output), 42);
}
1636
/// Two filters compiled through the same compiler and JIT context both stay
/// callable and keep their own behaviour.
#[test]
fn multiple_functions_same_context() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let above_ten_expr = col("val").gt(lit(10_i64));
    let above_ten = expr_compiler.compile_filter(&above_ten_expr).unwrap();
    let below_hundred_expr = col("val").lt(lit(100_i64));
    let below_hundred = expr_compiler.compile_filter(&below_hundred_expr).unwrap();

    let arena = Bump::new();

    // 50 satisfies both predicates.
    let mid_row = make_row_bytes(&arena, &layout, &[(0, 50)], &[]);
    // SAFETY: the row was built against `layout`, the schema both predicates
    // were compiled for, and it outlives the calls.
    assert_eq!(unsafe { above_ten(mid_row.as_ptr()) }, 1);
    // SAFETY: same row as the previous call.
    assert_eq!(unsafe { below_hundred(mid_row.as_ptr()) }, 1);

    // 5 fails `> 10` but still satisfies `< 100`.
    let low_row = make_row_bytes(&arena, &layout, &[(0, 5)], &[]);
    // SAFETY: same layout and lifetime argument as above.
    assert_eq!(unsafe { above_ten(low_row.as_ptr()) }, 0);
    // SAFETY: same row as the previous call.
    assert_eq!(unsafe { below_hundred(low_row.as_ptr()) }, 1);
}
1662
/// A bare boolean literal compiles to a filter that always returns that
/// constant: 1 for `true`, 0 for `false`.
#[test]
fn filter_literal_true_false() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);
    let arena = Bump::new();
    let row = make_row_bytes(&arena, &layout, &[(0, 0)], &[]);

    // (literal, expected filter result); compiled in the same order as listed.
    for (literal, expected) in [(true, 1), (false, 0)] {
        let constant_filter = expr_compiler.compile_filter(&lit(literal)).unwrap();
        // SAFETY: the row was built against `layout`, the schema the filter
        // was compiled for, and it outlives the call.
        assert_eq!(unsafe { constant_filter(row.as_ptr()) }, expected);
    }
}
1680
/// `val > (2 + 3)` has a constant right-hand side; the compiled filter must
/// behave like `val > 5`.
#[test]
fn constant_folding_in_compile() {
    let arrow_schema = make_schema(vec![("val", DataType::Int64, false)]);
    let layout = RowSchema::from_arrow(&arrow_schema).unwrap();
    let mut jit_ctx = JitContext::new().unwrap();
    let mut expr_compiler = ExprCompiler::new(&mut jit_ctx, &layout);

    let constant_sum = lit(2_i64) + lit(3_i64);
    let comparison = col("val").gt(constant_sum);
    let predicate = expr_compiler.compile_filter(&comparison).unwrap();

    let arena = Bump::new();

    // 10 > 5
    let above = make_row_bytes(&arena, &layout, &[(0, 10)], &[]);
    // SAFETY: the row was built against `layout`, the schema the predicate
    // was compiled for, and it outlives the call.
    assert_eq!(unsafe { predicate(above.as_ptr()) }, 1);

    // 3 is not > 5
    let below = make_row_bytes(&arena, &layout, &[(0, 3)], &[]);
    // SAFETY: same layout and lifetime argument as above.
    assert_eq!(unsafe { predicate(below.as_ptr()) }, 0);
}
1700}