
laminar_sql/datafusion/complex_type_lambda.rs

//! Lambda higher-order functions for arrays and maps (F-SCHEMA-015 Tier 3).
//!
//! Provides vectorized lambda evaluation over Arrow arrays:
//!
//! | Function | Lambda | Strategy |
//! |----------|--------|----------|
//! | `array_transform(arr, lambda)` | `x -> expr` | Flatten, eval, re-group |
//! | `array_filter(arr, lambda)` | `x -> bool` | Flatten, eval, filter+rebuild |
//! | `array_reduce(arr, init, lambda)` | `(acc, x) -> expr` | Sequential fold |
//! | `map_filter(map, lambda)` | `(k, v) -> bool` | Eval on k+v, filter entries |
//! | `map_transform_values(map, lambda)` | `(k, v) -> expr` | Eval on k+v, replace vals |
//!
//! Lambda expressions are specified as SQL expressions in string literals.
//! They are evaluated using DataFusion's SQL engine against a temporary
//! table containing the element values. Native lambda syntax is deferred.
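//!
//! # Example
//!
//! Illustrative SQL usage once the UDFs have been registered via
//! [`register_lambda_functions`] (the table and column names here are hypothetical):
//!
//! ```sql
//! SELECT array_transform(scores, 'x + 1')        FROM events;
//! SELECT array_filter(scores, 'x > 0')           FROM events;
//! SELECT array_reduce(scores, 0, '(acc + x)')    FROM events;
//! SELECT map_filter(attrs, '(k <> ''temp'')')    FROM events;
//! SELECT map_transform_values(attrs, 'v * 2')    FROM events;
//! ```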
use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::datatypes::DataType;
use arrow_array::{Array, ArrayRef, BooleanArray, ListArray, MapArray, StructArray};
use arrow_schema::{Field, Fields, Schema};
use datafusion_common::Result;
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};

use super::json_udf::expand_args;

/// Registers all lambda HOFs with the given session context.
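///
/// # Example
///
/// A minimal sketch of registering and querying (assumes a multi-threaded
/// tokio runtime, which the lambda evaluator's `block_in_place` call needs;
/// the table and column names are hypothetical):
///
/// ```rust,ignore
/// let ctx = datafusion::prelude::SessionContext::new();
/// register_lambda_functions(&ctx);
/// let df = ctx.sql("SELECT array_transform(arr, 'x + 1') FROM t").await?;
/// let batches = df.collect().await?;
/// ```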
pub fn register_lambda_functions(ctx: &datafusion::prelude::SessionContext) {
    use datafusion_expr::ScalarUDF;

    ctx.register_udf(ScalarUDF::new_from_impl(ArrayTransform::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ArrayFilter::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ArrayReduce::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(MapFilter::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(MapTransformValues::new()));
}

thread_local! {
    /// Cached `SessionContext` for lambda evaluation.
    ///
    /// Creating a `SessionContext` is expensive because it registers all
    /// built-in functions and sets up catalogs. We cache it per-thread
    /// and reuse it across lambda invocations.
    static LAMBDA_CTX: std::cell::RefCell<Option<datafusion::prelude::SessionContext>> =
        const { std::cell::RefCell::new(None) };
}

/// Evaluates a SQL expression against a `RecordBatch`, returning the result column.
///
/// The expression can reference columns by name from the batch schema.
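///
/// For example (illustrative values), evaluating `"x + 1"` against a batch
/// with a single Int64 column `x` holding `[1, 2, 3]` yields an Int64 column
/// `[2, 3, 4]` of the same length.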
fn eval_expr_on_batch(sql_expr: &str, batch: &arrow_array::RecordBatch) -> Result<ArrayRef> {
    // Reuse a thread-local SessionContext to avoid the cost of
    // registering built-in functions on every invocation.
    let ctx = LAMBDA_CTX.with(|cell| {
        let mut opt = cell.borrow_mut();
        opt.get_or_insert_with(datafusion::prelude::SessionContext::new)
            .clone()
    });

    let provider =
        datafusion::datasource::MemTable::try_new(batch.schema(), vec![vec![batch.clone()]])?;
    let rt = tokio::runtime::Handle::try_current().map_err(|e| {
        datafusion_common::DataFusionError::Internal(format!(
            "lambda eval requires tokio runtime: {e}"
        ))
    })?;
    // Use block_in_place to allow blocking inside an already-running tokio runtime.
    tokio::task::block_in_place(|| {
        rt.block_on(async {
            ctx.register_table("__lambda_data", Arc::new(provider))?;
            let df = ctx
                .sql(&format!("SELECT {sql_expr} FROM __lambda_data"))
                .await?;
            let batches = df.collect().await?;
            // Deregister the ephemeral table so the batch data is freed.
            let _ = ctx.deregister_table("__lambda_data");
            if batches.is_empty() {
                Err(datafusion_common::DataFusionError::Internal(
                    "lambda expression returned no data".into(),
                ))
            } else {
                // Concatenate all result batches and return the first column.
                let result = arrow::compute::concat_batches(&batches[0].schema(), &batches)?;
                Ok(result.column(0).clone())
            }
        })
    })
}
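/// Extracts the lambda expression string from a UDF argument, accepting either
/// a scalar `Utf8` value or the first element of a `Utf8` array.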
fn scalar_string_value(cv: &ColumnarValue) -> Result<String> {
    match cv {
        ColumnarValue::Scalar(s) => {
            let arr = s.to_array_of_size(1)?;
            let str_arr = arr
                .as_any()
                .downcast_ref::<arrow_array::StringArray>()
                .ok_or_else(|| {
                    datafusion_common::DataFusionError::Internal("expected Utf8 argument".into())
                })?;
            Ok(str_arr.value(0).to_string())
        }
        ColumnarValue::Array(arr) => {
            let str_arr = arr
                .as_any()
                .downcast_ref::<arrow_array::StringArray>()
                .ok_or_else(|| {
                    datafusion_common::DataFusionError::Internal("expected Utf8 argument".into())
                })?;
            Ok(str_arr.value(0).to_string())
        }
    }
}

// ══════════════════════════════════════════════════════════════════
// array_transform(arr, lambda_str) -> List
// ══════════════════════════════════════════════════════════════════

/// `array_transform(arr, 'x + 1')` — apply a lambda to each element.
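///
/// A worked example of the flatten, eval, re-group strategy (values are
/// illustrative): `[[1, 2], [3]]` with `'x + 1'` flattens to `[1, 2, 3]`,
/// evaluates to `[2, 3, 4]`, and reuses the original offsets to re-group
/// into `[[2, 3], [4]]`.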
#[derive(Debug)]
pub struct ArrayTransform {
    signature: Signature,
}

impl ArrayTransform {
    /// Creates a new `array_transform` UDF.
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

impl Default for ArrayTransform {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for ArrayTransform {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for ArrayTransform {}
impl Hash for ArrayTransform {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "array_transform".hash(s);
    }
}

impl ScalarUDFImpl for ArrayTransform {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "array_transform"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match &arg_types[0] {
            DataType::List(f) => Ok(DataType::List(Arc::clone(f))),
            _ => Ok(DataType::List(Arc::new(Field::new(
                "item",
                DataType::Utf8,
                true,
            )))),
        }
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let list_arr = expanded[0]
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "array_transform: first arg must be List".into(),
                )
            })?;

        let lambda_str = scalar_string_value(&args.args[1])?;
        let flat_values = list_arr.values();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "x",
            flat_values.data_type().clone(),
            true,
        )]));
        let batch = arrow_array::RecordBatch::try_new(schema, vec![Arc::clone(flat_values)])?;

        let result_arr = eval_expr_on_batch(&lambda_str, &batch)?;

        let new_field = Arc::new(Field::new("item", result_arr.data_type().clone(), true));
        let new_list = ListArray::try_new(
            new_field,
            list_arr.offsets().clone(),
            result_arr,
            list_arr.nulls().cloned(),
        )?;
        Ok(ColumnarValue::Array(Arc::new(new_list)))
    }
}

// ══════════════════════════════════════════════════════════════════
// array_filter(arr, lambda_str) -> List
// ══════════════════════════════════════════════════════════════════

/// `array_filter(arr, 'x > 0')` — filter elements by a boolean lambda.
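///
/// A worked example (values illustrative): `[[-1, 2], [-3, 4]]` with `'x > 0'`
/// evaluates the mask `[false, true, false, true]` over the flattened values,
/// keeps indices `[1, 3]`, and rebuilds the offsets to produce `[[2], [4]]`.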
#[derive(Debug)]
pub struct ArrayFilter {
    signature: Signature,
}

impl ArrayFilter {
    /// Creates a new `array_filter` UDF.
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

impl Default for ArrayFilter {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for ArrayFilter {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for ArrayFilter {}
impl Hash for ArrayFilter {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "array_filter".hash(s);
    }
}

impl ScalarUDFImpl for ArrayFilter {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "array_filter"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match &arg_types[0] {
            DataType::List(f) => Ok(DataType::List(Arc::clone(f))),
            _ => Ok(DataType::List(Arc::new(Field::new(
                "item",
                DataType::Utf8,
                true,
            )))),
        }
    }

    #[allow(
        clippy::cast_sign_loss,
        clippy::cast_possible_wrap,
        clippy::cast_possible_truncation
    )]
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let list_arr = expanded[0]
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "array_filter: first arg must be List".into(),
                )
            })?;

        let lambda_str = scalar_string_value(&args.args[1])?;
        let flat_values = list_arr.values();
        let elem_type = flat_values.data_type().clone();

        let schema = Arc::new(Schema::new(vec![Field::new("x", elem_type.clone(), true)]));
        let batch = arrow_array::RecordBatch::try_new(schema, vec![Arc::clone(flat_values)])?;

        let mask_arr = eval_expr_on_batch(&lambda_str, &batch)?;
        let mask = mask_arr
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "array_filter: lambda must return Boolean".into(),
                )
            })?;

        let mut offsets = vec![0i32];
        let mut filtered_indices: Vec<usize> = Vec::new();

        for row in 0..list_arr.len() {
            let start = list_arr.value_offsets()[row] as usize;
            let end = list_arr.value_offsets()[row + 1] as usize;

            for i in start..end {
                if !mask.is_null(i) && mask.value(i) {
                    filtered_indices.push(i);
                }
            }
            offsets.push(filtered_indices.len() as i32);
        }

        let indices = arrow_array::UInt32Array::from(
            filtered_indices
                .iter()
                .map(|&i| i as u32)
                .collect::<Vec<_>>(),
        );
        let filtered_values = arrow::compute::take(flat_values.as_ref(), &indices, None)?;

        let new_field = Arc::new(Field::new("item", elem_type, true));
        let new_offsets =
            arrow::buffer::OffsetBuffer::new(arrow::buffer::ScalarBuffer::from(offsets));
        let new_list = ListArray::try_new(
            new_field,
            new_offsets,
            filtered_values,
            list_arr.nulls().cloned(),
        )?;
        Ok(ColumnarValue::Array(Arc::new(new_list)))
    }
}

// ══════════════════════════════════════════════════════════════════
// array_reduce(arr, init, lambda_str) -> scalar
// ══════════════════════════════════════════════════════════════════

/// `array_reduce(arr, init, '(acc + x)')` — fold/reduce array elements.
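///
/// A worked example (values illustrative): `array_reduce([1, 2, 3], 0, '(acc + x)')`
/// folds left to right as `((0 + 1) + 2) + 3 = 6`, evaluating the lambda once per
/// element with the running accumulator bound to `acc`.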
#[derive(Debug)]
pub struct ArrayReduce {
    signature: Signature,
}

impl ArrayReduce {
    /// Creates a new `array_reduce` UDF.
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
        }
    }
}

impl Default for ArrayReduce {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for ArrayReduce {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for ArrayReduce {}
impl Hash for ArrayReduce {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "array_reduce".hash(s);
    }
}

impl ScalarUDFImpl for ArrayReduce {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "array_reduce"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types.get(1).cloned().unwrap_or(DataType::Int64))
    }

    #[allow(clippy::cast_sign_loss)]
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let list_arr = expanded[0]
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "array_reduce: first arg must be List".into(),
                )
            })?;

        let init_arr = &expanded[1];
        let lambda_str = scalar_string_value(&args.args[2])?;

        let elem_type = list_arr.values().data_type().clone();
        let acc_type = init_arr.data_type().clone();

        let schema = Arc::new(Schema::new(vec![
            Field::new("acc", acc_type, true),
            Field::new("x", elem_type, true),
        ]));

        let mut result_builder: Vec<ArrayRef> = Vec::new();

        for row in 0..list_arr.len() {
            let start = list_arr.value_offsets()[row] as usize;
            let end = list_arr.value_offsets()[row + 1] as usize;

            let mut acc: ArrayRef = init_arr.slice(row, 1);

            for i in start..end {
                let x = list_arr.values().slice(i, 1);
                let batch = arrow_array::RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![Arc::clone(&acc), x],
                )?;
                let result_col = eval_expr_on_batch(&lambda_str, &batch)?;
                acc = result_col;
            }

            result_builder.push(acc);
        }

        if result_builder.is_empty() {
            return Ok(ColumnarValue::Array(Arc::clone(init_arr)));
        }

        let refs: Vec<&dyn Array> = result_builder
            .iter()
            .map(std::convert::AsRef::as_ref)
            .collect();
        let result = arrow::compute::concat(&refs)?;
        Ok(ColumnarValue::Array(result))
    }
}

// ══════════════════════════════════════════════════════════════════
// map_filter(map, lambda_str) -> Map
// ══════════════════════════════════════════════════════════════════

/// `map_filter(map, '(k <> ''temp'')')` — filter map entries by key+value.
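///
/// A worked example (values illustrative): `{'temp': 1, 'id': 7}` with
/// `'(k <> ''temp'')'` evaluates the predicate over each `(k, v)` entry and
/// keeps only `{'id': 7}`.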
#[derive(Debug)]
pub struct MapFilter {
    signature: Signature,
}

impl MapFilter {
    /// Creates a new `map_filter` UDF.
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

impl Default for MapFilter {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for MapFilter {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for MapFilter {}
impl Hash for MapFilter {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "map_filter".hash(s);
    }
}

impl ScalarUDFImpl for MapFilter {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "map_filter"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    #[allow(
        clippy::cast_sign_loss,
        clippy::cast_possible_wrap,
        clippy::cast_possible_truncation
    )]
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let map_arr = expanded[0]
            .as_any()
            .downcast_ref::<MapArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "map_filter: first arg must be Map".into(),
                )
            })?;

        let lambda_str = scalar_string_value(&args.args[1])?;

        let entries = map_arr.entries();
        let key_col = entries.column(0);
        let val_col = entries.column(1);

        let key_type = key_col.data_type().clone();
        let val_type = val_col.data_type().clone();

        let schema = Arc::new(Schema::new(vec![
            Field::new("k", key_type.clone(), true),
            Field::new("v", val_type.clone(), true),
        ]));
        let batch = arrow_array::RecordBatch::try_new(
            schema,
            vec![Arc::clone(key_col), Arc::clone(val_col)],
        )?;

        let mask_arr = eval_expr_on_batch(&lambda_str, &batch)?;
        let mask_bool = mask_arr
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "map_filter: lambda must return Boolean".into(),
                )
            })?;

        let mut offsets = vec![0i32];
        let mut keep_indices: Vec<usize> = Vec::new();

        for row in 0..map_arr.len() {
            let start = map_arr.value_offsets()[row] as usize;
            let end = map_arr.value_offsets()[row + 1] as usize;

            for i in start..end {
                if !mask_bool.is_null(i) && mask_bool.value(i) {
                    keep_indices.push(i);
                }
            }
            offsets.push(keep_indices.len() as i32);
        }

        let indices = arrow_array::UInt32Array::from(
            keep_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let new_keys = arrow::compute::take(key_col.as_ref(), &indices, None)?;
        let new_vals = arrow::compute::take(val_col.as_ref(), &indices, None)?;

        let struct_fields = Fields::from(vec![
            Field::new("key", key_type, false),
            Field::new("value", val_type, true),
        ]);
        let new_entries = StructArray::try_new(struct_fields, vec![new_keys, new_vals], None)?;

        let entries_field = Field::new("entries", new_entries.data_type().clone(), false);
        let new_offsets =
            arrow::buffer::OffsetBuffer::new(arrow::buffer::ScalarBuffer::from(offsets));
        let new_map = MapArray::try_new(
            Arc::new(entries_field),
            new_offsets,
            new_entries,
            map_arr.nulls().cloned(),
            false,
        )?;
        Ok(ColumnarValue::Array(Arc::new(new_map)))
    }
}

// ══════════════════════════════════════════════════════════════════
// map_transform_values(map, lambda_str) -> Map
// ══════════════════════════════════════════════════════════════════

/// `map_transform_values(map, 'v * 2')` — transform map values.
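///
/// A worked example (values illustrative): `{'a': 2, 'b': 3}` with `'v * 2'`
/// replaces each value with the lambda result, yielding `{'a': 4, 'b': 6}`
/// while keys and the map's row structure are unchanged.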
#[derive(Debug)]
pub struct MapTransformValues {
    signature: Signature,
}

impl MapTransformValues {
    /// Creates a new `map_transform_values` UDF.
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

impl Default for MapTransformValues {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for MapTransformValues {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for MapTransformValues {}
impl Hash for MapTransformValues {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "map_transform_values".hash(s);
    }
}

impl ScalarUDFImpl for MapTransformValues {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "map_transform_values"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let map_arr = expanded[0]
            .as_any()
            .downcast_ref::<MapArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "map_transform_values: first arg must be Map".into(),
                )
            })?;

        let lambda_str = scalar_string_value(&args.args[1])?;

        let entries = map_arr.entries();
        let key_col = entries.column(0);
        let val_col = entries.column(1);

        let key_type = key_col.data_type().clone();

        let schema = Arc::new(Schema::new(vec![
            Field::new("k", key_type.clone(), true),
            Field::new("v", val_col.data_type().clone(), true),
        ]));
        let batch = arrow_array::RecordBatch::try_new(
            schema,
            vec![Arc::clone(key_col), Arc::clone(val_col)],
        )?;

        let new_vals = eval_expr_on_batch(&lambda_str, &batch)?;

        let struct_fields = Fields::from(vec![
            Field::new("key", key_type, false),
            Field::new("value", new_vals.data_type().clone(), true),
        ]);
        let new_entries =
            StructArray::try_new(struct_fields, vec![Arc::clone(key_col), new_vals], None)?;

        let entries_field = Field::new("entries", new_entries.data_type().clone(), false);
        let new_map = MapArray::try_new(
            Arc::new(entries_field),
            map_arr.offsets().clone(),
            new_entries,
            map_arr.nulls().cloned(),
            false,
        )?;
        Ok(ColumnarValue::Array(Arc::new(new_map)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use arrow_array::*;
    use datafusion_common::config::ConfigOptions;

    // ── array_transform ─────────────────────────────────────────

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_array_transform_add_one() {
        let values = Int64Array::from(vec![1, 2, 3, 4, 5, 6]);
        let offsets =
            arrow::buffer::OffsetBuffer::new(arrow::buffer::ScalarBuffer::from(vec![0i32, 3, 6]));
        let list = ListArray::try_new(
            Arc::new(Field::new("item", DataType::Int64, true)),
            offsets,
            Arc::new(values),
            None,
        )
        .unwrap();

        let udf = ArrayTransform::new();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(list)),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some(
                        "x + 1".into(),
                    ))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new(
                    "output",
                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
                    true,
                )),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        if let ColumnarValue::Array(arr) = result {
            let la = arr.as_any().downcast_ref::<ListArray>().unwrap();
            assert_eq!(la.len(), 2);
            let row0 = la.value(0);
            let r0 = row0.as_any().downcast_ref::<Int64Array>().unwrap();
            assert_eq!(r0.value(0), 2);
            assert_eq!(r0.value(1), 3);
            assert_eq!(r0.value(2), 4);
        } else {
            panic!("expected Array");
        }
    }

    // ── array_filter ────────────────────────────────────────────

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_array_filter_positive() {
        let values = Int64Array::from(vec![-1, 2, -3, 4]);
        let offsets =
            arrow::buffer::OffsetBuffer::new(arrow::buffer::ScalarBuffer::from(vec![0i32, 4]));
        let list = ListArray::try_new(
            Arc::new(Field::new("item", DataType::Int64, true)),
            offsets,
            Arc::new(values),
            None,
        )
        .unwrap();

        let udf = ArrayFilter::new();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(list)),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some(
                        "x > 0".into(),
                    ))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new(
                    "output",
                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
                    true,
                )),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        if let ColumnarValue::Array(arr) = result {
            let la = arr.as_any().downcast_ref::<ListArray>().unwrap();
            let row0 = la.value(0);
            let r0 = row0.as_any().downcast_ref::<Int64Array>().unwrap();
            assert_eq!(r0.len(), 2);
            assert_eq!(r0.value(0), 2);
            assert_eq!(r0.value(1), 4);
        }
    }

    // ── Registration ────────────────────────────────────────────

    #[test]
    fn test_register_lambda_functions() {
        use datafusion::execution::FunctionRegistry;

        let ctx = create_session_context();
        register_lambda_functions(&ctx);
        assert!(ctx.udf("array_transform").is_ok());
        assert!(ctx.udf("array_filter").is_ok());
        assert!(ctx.udf("array_reduce").is_ok());
        assert!(ctx.udf("map_filter").is_ok());
        assert!(ctx.udf("map_transform_values").is_ok());
    }
}