Skip to main content

laminar_sql/datafusion/
complex_type_udf.rs

1//! Array, Struct, and Map scalar UDFs (F-SCHEMA-015).
2//!
3//! **Tier 1** — Built-in DataFusion array functions are verified via tests.
4//!
5//! **Tier 2** — Custom scalar UDFs for struct/map operations:
6//!
7//! | Function | Signature | Return |
8//! |----------|-----------|--------|
9//! | `struct_extract(struct, field)` | Struct, Utf8 | field type |
10//! | `struct_set(struct, field, value)` | Struct, Utf8, Any | Struct |
11//! | `struct_drop(struct, field)` | Struct, Utf8 | Struct |
12//! | `struct_rename(struct, old, new)` | Struct, Utf8, Utf8 | Struct |
13//! | `struct_merge(s1, s2)` | Struct, Struct | Struct |
14//! | `map_keys(map)` | Map | List\<K\> |
15//! | `map_values(map)` | Map | List\<V\> |
16//! | `map_contains_key(map, key)` | Map, K | Boolean |
17//! | `map_from_arrays(keys, vals)` | List, List | Map |
18
19use std::any::Any;
20use std::hash::{Hash, Hasher};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use arrow_array::{builder::BooleanBuilder, Array, ArrayRef, MapArray, StructArray};
25use arrow_schema::{Field, Fields};
26use datafusion_common::Result;
27use datafusion_expr::{
28    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
29};
30
31use super::json_udf::expand_args;
32
33/// Registers all complex type UDFs with the given session context.
34pub fn register_complex_type_functions(ctx: &datafusion::prelude::SessionContext) {
35    use datafusion_expr::ScalarUDF;
36
37    ctx.register_udf(ScalarUDF::new_from_impl(StructExtract::new()));
38    ctx.register_udf(ScalarUDF::new_from_impl(StructSet::new()));
39    ctx.register_udf(ScalarUDF::new_from_impl(StructDrop::new()));
40    ctx.register_udf(ScalarUDF::new_from_impl(StructRename::new()));
41    ctx.register_udf(ScalarUDF::new_from_impl(StructMerge::new()));
42    ctx.register_udf(ScalarUDF::new_from_impl(MapKeys::new()));
43    ctx.register_udf(ScalarUDF::new_from_impl(MapValues::new()));
44    ctx.register_udf(ScalarUDF::new_from_impl(MapContainsKey::new()));
45    ctx.register_udf(ScalarUDF::new_from_impl(MapFromArrays::new()));
46}
47
48// ── Helpers ──────────────────────────────────────────────────────
49
50fn scalar_string_value(cv: &ColumnarValue) -> Result<String> {
51    match cv {
52        ColumnarValue::Scalar(s) => {
53            let arr = s.to_array_of_size(1)?;
54            let str_arr = arr
55                .as_any()
56                .downcast_ref::<arrow_array::StringArray>()
57                .ok_or_else(|| {
58                    datafusion_common::DataFusionError::Internal("expected Utf8 argument".into())
59                })?;
60            Ok(str_arr.value(0).to_string())
61        }
62        ColumnarValue::Array(arr) => {
63            let str_arr = arr
64                .as_any()
65                .downcast_ref::<arrow_array::StringArray>()
66                .ok_or_else(|| {
67                    datafusion_common::DataFusionError::Internal("expected Utf8 argument".into())
68                })?;
69            Ok(str_arr.value(0).to_string())
70        }
71    }
72}
73
74// ══════════════════════════════════════════════════════════════════
75// struct_extract(struct, field_name) -> field value
76// ══════════════════════════════════════════════════════════════════
77
78/// `struct_extract(struct, field_name)` — extract a field from a struct column.
79#[derive(Debug)]
80pub struct StructExtract {
81    signature: Signature,
82}
83
84impl StructExtract {
85    /// Creates a new `struct_extract` UDF.
86    #[must_use]
87    pub fn new() -> Self {
88        Self {
89            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
90        }
91    }
92}
93
94impl Default for StructExtract {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl PartialEq for StructExtract {
101    fn eq(&self, _other: &Self) -> bool {
102        true
103    }
104}
105impl Eq for StructExtract {}
106impl Hash for StructExtract {
107    fn hash<H: Hasher>(&self, state: &mut H) {
108        "struct_extract".hash(state);
109    }
110}
111
112impl ScalarUDFImpl for StructExtract {
113    fn as_any(&self) -> &dyn Any {
114        self
115    }
116    fn name(&self) -> &'static str {
117        "struct_extract"
118    }
119    fn signature(&self) -> &Signature {
120        &self.signature
121    }
122
123    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
124        // Dynamic return type based on the struct field.
125        // Fallback to Utf8 — actual type resolution happens at plan time.
126        Ok(DataType::Utf8)
127    }
128
129    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
130        let expanded = expand_args(&args.args)?;
131        let struct_arr = expanded[0]
132            .as_any()
133            .downcast_ref::<StructArray>()
134            .ok_or_else(|| {
135                datafusion_common::DataFusionError::Internal(
136                    "struct_extract: first arg must be Struct".into(),
137                )
138            })?;
139
140        let field_name = scalar_string_value(&args.args[1])?;
141
142        let (idx, _) = struct_arr.fields().find(&field_name).ok_or_else(|| {
143            datafusion_common::DataFusionError::Internal(format!(
144                "struct_extract: field '{field_name}' not found in struct"
145            ))
146        })?;
147
148        Ok(ColumnarValue::Array(struct_arr.column(idx).clone()))
149    }
150}
151
152// ══════════════════════════════════════════════════════════════════
153// struct_set(struct, field_name, value) -> struct
154// ══════════════════════════════════════════════════════════════════
155
156/// `struct_set(struct, field_name, value)` — set/add a field in a struct.
157#[derive(Debug)]
158pub struct StructSet {
159    signature: Signature,
160}
161
162impl StructSet {
163    /// Creates a new `struct_set` UDF.
164    #[must_use]
165    pub fn new() -> Self {
166        Self {
167            signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
168        }
169    }
170}
171
172impl Default for StructSet {
173    fn default() -> Self {
174        Self::new()
175    }
176}
177impl PartialEq for StructSet {
178    fn eq(&self, _other: &Self) -> bool {
179        true
180    }
181}
182impl Eq for StructSet {}
183impl Hash for StructSet {
184    fn hash<H: Hasher>(&self, state: &mut H) {
185        "struct_set".hash(state);
186    }
187}
188
189impl ScalarUDFImpl for StructSet {
190    fn as_any(&self) -> &dyn Any {
191        self
192    }
193    fn name(&self) -> &'static str {
194        "struct_set"
195    }
196    fn signature(&self) -> &Signature {
197        &self.signature
198    }
199
200    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
201        // Return type is a struct; exact fields depend on input.
202        Ok(DataType::Utf8)
203    }
204
205    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
206        let expanded = expand_args(&args.args)?;
207        let struct_arr = expanded[0]
208            .as_any()
209            .downcast_ref::<StructArray>()
210            .ok_or_else(|| {
211                datafusion_common::DataFusionError::Internal(
212                    "struct_set: first arg must be Struct".into(),
213                )
214            })?;
215
216        let field_name = scalar_string_value(&args.args[1])?;
217        let new_value = Arc::clone(&expanded[2]);
218
219        let mut new_fields: Vec<Arc<Field>> = Vec::new();
220        let mut new_columns: Vec<ArrayRef> = Vec::new();
221        let mut replaced = false;
222
223        for (i, field) in struct_arr.fields().iter().enumerate() {
224            if field.name() == &field_name {
225                new_fields.push(Arc::new(Field::new(
226                    &field_name,
227                    new_value.data_type().clone(),
228                    new_value.null_count() > 0,
229                )));
230                new_columns.push(Arc::clone(&new_value));
231                replaced = true;
232            } else {
233                new_fields.push(Arc::clone(field));
234                new_columns.push(struct_arr.column(i).clone());
235            }
236        }
237
238        if !replaced {
239            new_fields.push(Arc::new(Field::new(
240                &field_name,
241                new_value.data_type().clone(),
242                new_value.null_count() > 0,
243            )));
244            new_columns.push(new_value);
245        }
246
247        let result = StructArray::try_new(
248            Fields::from(new_fields),
249            new_columns,
250            struct_arr.nulls().cloned(),
251        )?;
252        Ok(ColumnarValue::Array(Arc::new(result)))
253    }
254}
255
256// ══════════════════════════════════════════════════════════════════
257// struct_drop(struct, field_name) -> struct
258// ══════════════════════════════════════════════════════════════════
259
260/// `struct_drop(struct, field_name)` — remove a field from a struct.
261#[derive(Debug)]
262pub struct StructDrop {
263    signature: Signature,
264}
265
266impl StructDrop {
267    /// Creates a new `struct_drop` UDF.
268    #[must_use]
269    pub fn new() -> Self {
270        Self {
271            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
272        }
273    }
274}
275
276impl Default for StructDrop {
277    fn default() -> Self {
278        Self::new()
279    }
280}
281impl PartialEq for StructDrop {
282    fn eq(&self, _other: &Self) -> bool {
283        true
284    }
285}
286impl Eq for StructDrop {}
287impl Hash for StructDrop {
288    fn hash<H: Hasher>(&self, state: &mut H) {
289        "struct_drop".hash(state);
290    }
291}
292
293impl ScalarUDFImpl for StructDrop {
294    fn as_any(&self) -> &dyn Any {
295        self
296    }
297    fn name(&self) -> &'static str {
298        "struct_drop"
299    }
300    fn signature(&self) -> &Signature {
301        &self.signature
302    }
303
304    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
305        Ok(DataType::Utf8)
306    }
307
308    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
309        let expanded = expand_args(&args.args)?;
310        let struct_arr = expanded[0]
311            .as_any()
312            .downcast_ref::<StructArray>()
313            .ok_or_else(|| {
314                datafusion_common::DataFusionError::Internal(
315                    "struct_drop: first arg must be Struct".into(),
316                )
317            })?;
318
319        let field_name = scalar_string_value(&args.args[1])?;
320
321        let mut new_fields: Vec<Arc<Field>> = Vec::new();
322        let mut new_columns: Vec<ArrayRef> = Vec::new();
323
324        for (i, field) in struct_arr.fields().iter().enumerate() {
325            if field.name() != &field_name {
326                new_fields.push(Arc::clone(field));
327                new_columns.push(struct_arr.column(i).clone());
328            }
329        }
330
331        if new_fields.is_empty() {
332            return Err(datafusion_common::DataFusionError::Internal(
333                "struct_drop: cannot drop all fields from struct".into(),
334            ));
335        }
336
337        let result = StructArray::try_new(
338            Fields::from(new_fields),
339            new_columns,
340            struct_arr.nulls().cloned(),
341        )?;
342        Ok(ColumnarValue::Array(Arc::new(result)))
343    }
344}
345
346// ══════════════════════════════════════════════════════════════════
347// struct_rename(struct, old_name, new_name) -> struct
348// ══════════════════════════════════════════════════════════════════
349
350/// `struct_rename(struct, old_name, new_name)` — rename a struct field.
351#[derive(Debug)]
352pub struct StructRename {
353    signature: Signature,
354}
355
356impl StructRename {
357    /// Creates a new `struct_rename` UDF.
358    #[must_use]
359    pub fn new() -> Self {
360        Self {
361            signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
362        }
363    }
364}
365
366impl Default for StructRename {
367    fn default() -> Self {
368        Self::new()
369    }
370}
371impl PartialEq for StructRename {
372    fn eq(&self, _other: &Self) -> bool {
373        true
374    }
375}
376impl Eq for StructRename {}
377impl Hash for StructRename {
378    fn hash<H: Hasher>(&self, state: &mut H) {
379        "struct_rename".hash(state);
380    }
381}
382
383impl ScalarUDFImpl for StructRename {
384    fn as_any(&self) -> &dyn Any {
385        self
386    }
387    fn name(&self) -> &'static str {
388        "struct_rename"
389    }
390    fn signature(&self) -> &Signature {
391        &self.signature
392    }
393
394    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
395        Ok(DataType::Utf8)
396    }
397
398    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
399        let expanded = expand_args(&args.args)?;
400        let struct_arr = expanded[0]
401            .as_any()
402            .downcast_ref::<StructArray>()
403            .ok_or_else(|| {
404                datafusion_common::DataFusionError::Internal(
405                    "struct_rename: first arg must be Struct".into(),
406                )
407            })?;
408
409        let old_name = scalar_string_value(&args.args[1])?;
410        let new_name = scalar_string_value(&args.args[2])?;
411
412        let mut new_fields: Vec<Arc<Field>> = Vec::new();
413        let mut new_columns: Vec<ArrayRef> = Vec::new();
414
415        for (i, field) in struct_arr.fields().iter().enumerate() {
416            if field.name() == &old_name {
417                new_fields.push(Arc::new(Field::new(
418                    &new_name,
419                    field.data_type().clone(),
420                    field.is_nullable(),
421                )));
422            } else {
423                new_fields.push(Arc::clone(field));
424            }
425            new_columns.push(struct_arr.column(i).clone());
426        }
427
428        let result = StructArray::try_new(
429            Fields::from(new_fields),
430            new_columns,
431            struct_arr.nulls().cloned(),
432        )?;
433        Ok(ColumnarValue::Array(Arc::new(result)))
434    }
435}
436
437// ══════════════════════════════════════════════════════════════════
438// struct_merge(s1, s2) -> struct
439// ══════════════════════════════════════════════════════════════════
440
441/// `struct_merge(s1, s2)` — merge two structs (s2 fields override s1).
442#[derive(Debug)]
443pub struct StructMerge {
444    signature: Signature,
445}
446
447impl StructMerge {
448    /// Creates a new `struct_merge` UDF.
449    #[must_use]
450    pub fn new() -> Self {
451        Self {
452            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
453        }
454    }
455}
456
457impl Default for StructMerge {
458    fn default() -> Self {
459        Self::new()
460    }
461}
462impl PartialEq for StructMerge {
463    fn eq(&self, _other: &Self) -> bool {
464        true
465    }
466}
467impl Eq for StructMerge {}
468impl Hash for StructMerge {
469    fn hash<H: Hasher>(&self, state: &mut H) {
470        "struct_merge".hash(state);
471    }
472}
473
474impl ScalarUDFImpl for StructMerge {
475    fn as_any(&self) -> &dyn Any {
476        self
477    }
478    fn name(&self) -> &'static str {
479        "struct_merge"
480    }
481    fn signature(&self) -> &Signature {
482        &self.signature
483    }
484
485    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
486        Ok(DataType::Utf8)
487    }
488
489    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
490        let expanded = expand_args(&args.args)?;
491        let s1 = expanded[0]
492            .as_any()
493            .downcast_ref::<StructArray>()
494            .ok_or_else(|| {
495                datafusion_common::DataFusionError::Internal(
496                    "struct_merge: first arg must be Struct".into(),
497                )
498            })?;
499        let s2 = expanded[1]
500            .as_any()
501            .downcast_ref::<StructArray>()
502            .ok_or_else(|| {
503                datafusion_common::DataFusionError::Internal(
504                    "struct_merge: second arg must be Struct".into(),
505                )
506            })?;
507
508        let mut new_fields: Vec<Arc<Field>> = Vec::new();
509        let mut new_columns: Vec<ArrayRef> = Vec::new();
510
511        // Collect s2 field names for override detection.
512        let s2_names: Vec<&str> = s2.fields().iter().map(|f| f.name().as_str()).collect();
513
514        // Add s1 fields that aren't overridden by s2.
515        for (i, field) in s1.fields().iter().enumerate() {
516            if !s2_names.contains(&field.name().as_str()) {
517                new_fields.push(Arc::clone(field));
518                new_columns.push(s1.column(i).clone());
519            }
520        }
521
522        // Add all s2 fields.
523        for (i, field) in s2.fields().iter().enumerate() {
524            new_fields.push(Arc::clone(field));
525            new_columns.push(s2.column(i).clone());
526        }
527
528        let result = StructArray::try_new(Fields::from(new_fields), new_columns, None)?;
529        Ok(ColumnarValue::Array(Arc::new(result)))
530    }
531}
532
533// ══════════════════════════════════════════════════════════════════
534// map_keys(map) -> List<K>
535// ══════════════════════════════════════════════════════════════════
536
537/// `map_keys(map)` — extract keys from a map as a list.
538#[derive(Debug)]
539pub struct MapKeys {
540    signature: Signature,
541}
542
543impl MapKeys {
544    /// Creates a new `map_keys` UDF.
545    #[must_use]
546    pub fn new() -> Self {
547        Self {
548            signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
549        }
550    }
551}
552
553impl Default for MapKeys {
554    fn default() -> Self {
555        Self::new()
556    }
557}
558impl PartialEq for MapKeys {
559    fn eq(&self, _other: &Self) -> bool {
560        true
561    }
562}
563impl Eq for MapKeys {}
564impl Hash for MapKeys {
565    fn hash<H: Hasher>(&self, state: &mut H) {
566        "map_keys".hash(state);
567    }
568}
569
570impl ScalarUDFImpl for MapKeys {
571    fn as_any(&self) -> &dyn Any {
572        self
573    }
574    fn name(&self) -> &'static str {
575        "map_keys"
576    }
577    fn signature(&self) -> &Signature {
578        &self.signature
579    }
580
581    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
582        match &arg_types[0] {
583            DataType::Map(field, _) => {
584                if let DataType::Struct(fields) = field.data_type() {
585                    if let Some(key_field) = fields.first() {
586                        return Ok(DataType::List(Arc::new(Field::new(
587                            "item",
588                            key_field.data_type().clone(),
589                            key_field.is_nullable(),
590                        ))));
591                    }
592                }
593                Ok(DataType::List(Arc::new(Field::new(
594                    "item",
595                    DataType::Utf8,
596                    true,
597                ))))
598            }
599            _ => Ok(DataType::List(Arc::new(Field::new(
600                "item",
601                DataType::Utf8,
602                true,
603            )))),
604        }
605    }
606
607    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
608        let expanded = expand_args(&args.args)?;
609        let map_arr = expanded[0]
610            .as_any()
611            .downcast_ref::<MapArray>()
612            .ok_or_else(|| {
613                datafusion_common::DataFusionError::Internal("map_keys: arg must be Map".into())
614            })?;
615
616        Ok(ColumnarValue::Array(Arc::new(map_arr.keys().clone())))
617    }
618}
619
620// ══════════════════════════════════════════════════════════════════
621// map_values(map) -> List<V>
622// ══════════════════════════════════════════════════════════════════
623
624/// `map_values(map)` — extract values from a map as a list.
625#[derive(Debug)]
626pub struct MapValues {
627    signature: Signature,
628}
629
630impl MapValues {
631    /// Creates a new `map_values` UDF.
632    #[must_use]
633    pub fn new() -> Self {
634        Self {
635            signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
636        }
637    }
638}
639
640impl Default for MapValues {
641    fn default() -> Self {
642        Self::new()
643    }
644}
645impl PartialEq for MapValues {
646    fn eq(&self, _other: &Self) -> bool {
647        true
648    }
649}
650impl Eq for MapValues {}
651impl Hash for MapValues {
652    fn hash<H: Hasher>(&self, state: &mut H) {
653        "map_values".hash(state);
654    }
655}
656
657impl ScalarUDFImpl for MapValues {
658    fn as_any(&self) -> &dyn Any {
659        self
660    }
661    fn name(&self) -> &'static str {
662        "map_values"
663    }
664    fn signature(&self) -> &Signature {
665        &self.signature
666    }
667
668    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
669        match &arg_types[0] {
670            DataType::Map(field, _) => {
671                if let DataType::Struct(fields) = field.data_type() {
672                    if fields.len() >= 2 {
673                        let val_field = &fields[1];
674                        return Ok(DataType::List(Arc::new(Field::new(
675                            "item",
676                            val_field.data_type().clone(),
677                            val_field.is_nullable(),
678                        ))));
679                    }
680                }
681                Ok(DataType::List(Arc::new(Field::new(
682                    "item",
683                    DataType::Utf8,
684                    true,
685                ))))
686            }
687            _ => Ok(DataType::List(Arc::new(Field::new(
688                "item",
689                DataType::Utf8,
690                true,
691            )))),
692        }
693    }
694
695    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
696        let expanded = expand_args(&args.args)?;
697        let map_arr = expanded[0]
698            .as_any()
699            .downcast_ref::<MapArray>()
700            .ok_or_else(|| {
701                datafusion_common::DataFusionError::Internal("map_values: arg must be Map".into())
702            })?;
703
704        Ok(ColumnarValue::Array(Arc::new(map_arr.values().clone())))
705    }
706}
707
708// ══════════════════════════════════════════════════════════════════
709// map_contains_key(map, key) -> Boolean
710// ══════════════════════════════════════════════════════════════════
711
712/// `map_contains_key(map, key)` — check if a map contains a given key.
713#[derive(Debug)]
714pub struct MapContainsKey {
715    signature: Signature,
716}
717
718impl MapContainsKey {
719    /// Creates a new `map_contains_key` UDF.
720    #[must_use]
721    pub fn new() -> Self {
722        Self {
723            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
724        }
725    }
726}
727
728impl Default for MapContainsKey {
729    fn default() -> Self {
730        Self::new()
731    }
732}
733impl PartialEq for MapContainsKey {
734    fn eq(&self, _other: &Self) -> bool {
735        true
736    }
737}
738impl Eq for MapContainsKey {}
739impl Hash for MapContainsKey {
740    fn hash<H: Hasher>(&self, state: &mut H) {
741        "map_contains_key".hash(state);
742    }
743}
744
745impl ScalarUDFImpl for MapContainsKey {
746    fn as_any(&self) -> &dyn Any {
747        self
748    }
749    fn name(&self) -> &'static str {
750        "map_contains_key"
751    }
752    fn signature(&self) -> &Signature {
753        &self.signature
754    }
755
756    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
757        Ok(DataType::Boolean)
758    }
759
760    #[allow(clippy::cast_sign_loss)]
761    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
762        let expanded = expand_args(&args.args)?;
763        let map_arr = expanded[0]
764            .as_any()
765            .downcast_ref::<MapArray>()
766            .ok_or_else(|| {
767                datafusion_common::DataFusionError::Internal(
768                    "map_contains_key: first arg must be Map".into(),
769                )
770            })?;
771
772        let search_key = expanded[1]
773            .as_any()
774            .downcast_ref::<arrow_array::StringArray>()
775            .ok_or_else(|| {
776                datafusion_common::DataFusionError::Internal(
777                    "map_contains_key: second arg must be Utf8".into(),
778                )
779            })?;
780
781        let keys_col = map_arr.keys();
782        let keys_str = keys_col.as_any().downcast_ref::<arrow_array::StringArray>();
783
784        let mut builder = BooleanBuilder::with_capacity(map_arr.len());
785
786        for row in 0..map_arr.len() {
787            if map_arr.is_null(row) || search_key.is_null(row) {
788                builder.append_null();
789                continue;
790            }
791
792            let target = search_key.value(row);
793            let start = map_arr.value_offsets()[row] as usize;
794            let end = map_arr.value_offsets()[row + 1] as usize;
795
796            let found = if let Some(ks) = keys_str {
797                (start..end).any(|i| !ks.is_null(i) && ks.value(i) == target)
798            } else {
799                false
800            };
801
802            builder.append_value(found);
803        }
804
805        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
806    }
807}
808
809// ══════════════════════════════════════════════════════════════════
810// map_from_arrays(keys, values) -> Map
811// ══════════════════════════════════════════════════════════════════
812
813/// `map_from_arrays(keys, values)` — construct a map from key/value arrays.
814#[derive(Debug)]
815pub struct MapFromArrays {
816    signature: Signature,
817}
818
819impl MapFromArrays {
820    /// Creates a new `map_from_arrays` UDF.
821    #[must_use]
822    pub fn new() -> Self {
823        Self {
824            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
825        }
826    }
827}
828
829impl Default for MapFromArrays {
830    fn default() -> Self {
831        Self::new()
832    }
833}
834impl PartialEq for MapFromArrays {
835    fn eq(&self, _other: &Self) -> bool {
836        true
837    }
838}
839impl Eq for MapFromArrays {}
840impl Hash for MapFromArrays {
841    fn hash<H: Hasher>(&self, state: &mut H) {
842        "map_from_arrays".hash(state);
843    }
844}
845
846impl ScalarUDFImpl for MapFromArrays {
847    fn as_any(&self) -> &dyn Any {
848        self
849    }
850    fn name(&self) -> &'static str {
851        "map_from_arrays"
852    }
853    fn signature(&self) -> &Signature {
854        &self.signature
855    }
856
857    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
858        let key_type = match &arg_types[0] {
859            DataType::List(f) => f.data_type().clone(),
860            _ => DataType::Utf8,
861        };
862        let val_type = match &arg_types[1] {
863            DataType::List(f) => f.data_type().clone(),
864            _ => DataType::Utf8,
865        };
866
867        let entries_field = Field::new(
868            "entries",
869            DataType::Struct(Fields::from(vec![
870                Field::new("key", key_type, false),
871                Field::new("value", val_type, true),
872            ])),
873            false,
874        );
875        Ok(DataType::Map(Arc::new(entries_field), false))
876    }
877
878    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
879        let expanded = expand_args(&args.args)?;
880        let keys_list = expanded[0]
881            .as_any()
882            .downcast_ref::<arrow_array::ListArray>()
883            .ok_or_else(|| {
884                datafusion_common::DataFusionError::Internal(
885                    "map_from_arrays: first arg must be List".into(),
886                )
887            })?;
888        let values_list = expanded[1]
889            .as_any()
890            .downcast_ref::<arrow_array::ListArray>()
891            .ok_or_else(|| {
892                datafusion_common::DataFusionError::Internal(
893                    "map_from_arrays: second arg must be List".into(),
894                )
895            })?;
896
897        // Build the MapArray from key/value ListArrays.
898        let offsets = keys_list.offsets().clone();
899        let key_values = keys_list.values().clone();
900        let val_values = values_list.values().clone();
901
902        let key_field = Field::new("key", key_values.data_type().clone(), false);
903        let val_field = Field::new("value", val_values.data_type().clone(), true);
904        let struct_fields = Fields::from(vec![key_field.clone(), val_field.clone()]);
905        let entries = StructArray::try_new(struct_fields, vec![key_values, val_values], None)?;
906
907        let entries_field = Field::new("entries", entries.data_type().clone(), false);
908        let map = MapArray::try_new(Arc::new(entries_field), offsets, entries, None, false)?;
909        Ok(ColumnarValue::Array(Arc::new(map)))
910    }
911}
912
913#[cfg(test)]
914mod tests {
915    use super::*;
916    use crate::datafusion::create_session_context;
917    use arrow_array::builder::*;
918    use arrow_array::*;
919    use arrow_schema::{DataType, Field, Fields};
920    use datafusion_common::config::ConfigOptions;
921
922    // ── Tier 1: Verify DataFusion built-in array functions ──────
923
924    #[tokio::test]
925    async fn test_builtin_array_length() {
926        let ctx = create_session_context();
927        let df = ctx
928            .sql("SELECT array_length(make_array(1, 2, 3))")
929            .await
930            .unwrap();
931        let batches = df.collect().await.unwrap();
932        assert_eq!(batches[0].num_rows(), 1);
933    }
934
935    #[tokio::test]
936    async fn test_builtin_array_sort() {
937        let ctx = create_session_context();
938        let df = ctx
939            .sql("SELECT array_sort(make_array(3, 1, 2))")
940            .await
941            .unwrap();
942        let batches = df.collect().await.unwrap();
943        assert_eq!(batches[0].num_rows(), 1);
944    }
945
946    #[tokio::test]
947    async fn test_builtin_array_distinct() {
948        let ctx = create_session_context();
949        let df = ctx
950            .sql("SELECT array_distinct(make_array(1, 2, 2, 3))")
951            .await
952            .unwrap();
953        let batches = df.collect().await.unwrap();
954        assert_eq!(batches[0].num_rows(), 1);
955    }
956
957    // ── Tier 2: struct_extract ──────────────────────────────────
958
959    #[test]
960    fn test_struct_extract() {
961        let fields = Fields::from(vec![
962            Field::new("a", DataType::Int64, false),
963            Field::new("b", DataType::Utf8, true),
964        ]);
965        let struct_arr = StructArray::try_new(
966            fields,
967            vec![
968                Arc::new(Int64Array::from(vec![1, 2, 3])),
969                Arc::new(StringArray::from(vec!["x", "y", "z"])),
970            ],
971            None,
972        )
973        .unwrap();
974
975        let udf = StructExtract::new();
976        let result = udf
977            .invoke_with_args(ScalarFunctionArgs {
978                args: vec![
979                    ColumnarValue::Array(Arc::new(struct_arr)),
980                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some("b".into()))),
981                ],
982                number_rows: 0,
983                arg_fields: vec![],
984                return_field: Arc::new(Field::new("output", DataType::Utf8, true)),
985                config_options: Arc::new(ConfigOptions::default()),
986            })
987            .unwrap();
988
989        if let ColumnarValue::Array(arr) = result {
990            let str_arr = arr.as_any().downcast_ref::<StringArray>().unwrap();
991            assert_eq!(str_arr.value(0), "x");
992            assert_eq!(str_arr.value(1), "y");
993            assert_eq!(str_arr.value(2), "z");
994        } else {
995            panic!("expected Array");
996        }
997    }
998
999    // ── Tier 2: struct_drop ─────────────────────────────────────
1000
1001    #[test]
1002    fn test_struct_drop() {
1003        let fields = Fields::from(vec![
1004            Field::new("a", DataType::Int64, false),
1005            Field::new("b", DataType::Utf8, true),
1006        ]);
1007        let struct_arr = StructArray::try_new(
1008            fields,
1009            vec![
1010                Arc::new(Int64Array::from(vec![1])),
1011                Arc::new(StringArray::from(vec!["x"])),
1012            ],
1013            None,
1014        )
1015        .unwrap();
1016
1017        let udf = StructDrop::new();
1018        let result = udf
1019            .invoke_with_args(ScalarFunctionArgs {
1020                args: vec![
1021                    ColumnarValue::Array(Arc::new(struct_arr)),
1022                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some("b".into()))),
1023                ],
1024                number_rows: 0,
1025                arg_fields: vec![],
1026                return_field: Arc::new(Field::new("output", DataType::Utf8, true)),
1027                config_options: Arc::new(ConfigOptions::default()),
1028            })
1029            .unwrap();
1030
1031        if let ColumnarValue::Array(arr) = result {
1032            let s = arr.as_any().downcast_ref::<StructArray>().unwrap();
1033            assert_eq!(s.num_columns(), 1);
1034            assert_eq!(s.fields()[0].name(), "a");
1035        } else {
1036            panic!("expected Array");
1037        }
1038    }
1039
1040    // ── Tier 2: struct_rename ───────────────────────────────────
1041
1042    #[test]
1043    fn test_struct_rename() {
1044        let fields = Fields::from(vec![Field::new("old_name", DataType::Int64, false)]);
1045        let struct_arr =
1046            StructArray::try_new(fields, vec![Arc::new(Int64Array::from(vec![42]))], None).unwrap();
1047
1048        let udf = StructRename::new();
1049        let result = udf
1050            .invoke_with_args(ScalarFunctionArgs {
1051                args: vec![
1052                    ColumnarValue::Array(Arc::new(struct_arr)),
1053                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some(
1054                        "old_name".into(),
1055                    ))),
1056                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some(
1057                        "new_name".into(),
1058                    ))),
1059                ],
1060                number_rows: 0,
1061                arg_fields: vec![],
1062                return_field: Arc::new(Field::new("output", DataType::Utf8, true)),
1063                config_options: Arc::new(ConfigOptions::default()),
1064            })
1065            .unwrap();
1066
1067        if let ColumnarValue::Array(arr) = result {
1068            let s = arr.as_any().downcast_ref::<StructArray>().unwrap();
1069            assert_eq!(s.fields()[0].name(), "new_name");
1070        } else {
1071            panic!("expected Array");
1072        }
1073    }
1074
1075    // ── Tier 2: struct_merge ────────────────────────────────────
1076
1077    #[test]
1078    fn test_struct_merge() {
1079        let s1 = StructArray::try_new(
1080            Fields::from(vec![
1081                Field::new("a", DataType::Int64, false),
1082                Field::new("b", DataType::Utf8, true),
1083            ]),
1084            vec![
1085                Arc::new(Int64Array::from(vec![1])),
1086                Arc::new(StringArray::from(vec!["old"])),
1087            ],
1088            None,
1089        )
1090        .unwrap();
1091
1092        let s2 = StructArray::try_new(
1093            Fields::from(vec![
1094                Field::new("b", DataType::Utf8, true),
1095                Field::new("c", DataType::Float64, false),
1096            ]),
1097            vec![
1098                Arc::new(StringArray::from(vec!["new"])),
1099                Arc::new(Float64Array::from(vec![3.125])),
1100            ],
1101            None,
1102        )
1103        .unwrap();
1104
1105        let udf = StructMerge::new();
1106        let result = udf
1107            .invoke_with_args(ScalarFunctionArgs {
1108                args: vec![
1109                    ColumnarValue::Array(Arc::new(s1)),
1110                    ColumnarValue::Array(Arc::new(s2)),
1111                ],
1112                number_rows: 0,
1113                arg_fields: vec![],
1114                return_field: Arc::new(Field::new("output", DataType::Utf8, true)),
1115                config_options: Arc::new(ConfigOptions::default()),
1116            })
1117            .unwrap();
1118
1119        if let ColumnarValue::Array(arr) = result {
1120            let s = arr.as_any().downcast_ref::<StructArray>().unwrap();
1121            // "a" from s1, "b" from s2 (override), "c" from s2
1122            assert_eq!(s.num_columns(), 3);
1123            let names: Vec<&str> = s.fields().iter().map(|f| f.name().as_str()).collect();
1124            assert_eq!(names, vec!["a", "b", "c"]);
1125        } else {
1126            panic!("expected Array");
1127        }
1128    }
1129
1130    // ── Tier 2: map_contains_key ────────────────────────────────
1131
1132    #[test]
1133    fn test_map_contains_key() {
1134        // Build a MapArray with one row: {"x": 1, "y": 2}
1135        let key_builder = StringBuilder::new();
1136        let val_builder = Int64Builder::new();
1137        let mut builder = MapBuilder::new(None, key_builder, val_builder);
1138
1139        builder.keys().append_value("x");
1140        builder.values().append_value(1);
1141        builder.keys().append_value("y");
1142        builder.values().append_value(2);
1143        builder.append(true).unwrap();
1144
1145        let map_arr = builder.finish();
1146
1147        let udf = MapContainsKey::new();
1148
1149        // Check for "x" — should be true.
1150        let result = udf
1151            .invoke_with_args(ScalarFunctionArgs {
1152                args: vec![
1153                    ColumnarValue::Array(Arc::new(map_arr.clone())),
1154                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some("x".into()))),
1155                ],
1156                number_rows: 0,
1157                arg_fields: vec![],
1158                return_field: Arc::new(Field::new("output", DataType::Boolean, true)),
1159                config_options: Arc::new(ConfigOptions::default()),
1160            })
1161            .unwrap();
1162
1163        if let ColumnarValue::Array(arr) = result {
1164            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().unwrap();
1165            assert!(bool_arr.value(0));
1166        }
1167
1168        // Check for "z" — should be false.
1169        let result2 = udf
1170            .invoke_with_args(ScalarFunctionArgs {
1171                args: vec![
1172                    ColumnarValue::Array(Arc::new(map_arr)),
1173                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some("z".into()))),
1174                ],
1175                number_rows: 0,
1176                arg_fields: vec![],
1177                return_field: Arc::new(Field::new("output", DataType::Boolean, true)),
1178                config_options: Arc::new(ConfigOptions::default()),
1179            })
1180            .unwrap();
1181
1182        if let ColumnarValue::Array(arr) = result2 {
1183            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().unwrap();
1184            assert!(!bool_arr.value(0));
1185        }
1186    }
1187
1188    // ── Registration ────────────────────────────────────────────
1189
1190    #[test]
1191    fn test_register_complex_type_functions() {
1192        use datafusion::execution::FunctionRegistry;
1193
1194        let ctx = create_session_context();
1195        register_complex_type_functions(&ctx);
1196        assert!(ctx.udf("struct_extract").is_ok());
1197        assert!(ctx.udf("struct_set").is_ok());
1198        assert!(ctx.udf("struct_drop").is_ok());
1199        assert!(ctx.udf("struct_rename").is_ok());
1200        assert!(ctx.udf("struct_merge").is_ok());
1201        assert!(ctx.udf("map_keys").is_ok());
1202        assert!(ctx.udf("map_values").is_ok());
1203        assert!(ctx.udf("map_contains_key").is_ok());
1204        assert!(ctx.udf("map_from_arrays").is_ok());
1205    }
1206}