1use std::any::Any;
20use std::hash::{Hash, Hasher};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use arrow_array::{builder::BooleanBuilder, Array, ArrayRef, MapArray, StructArray};
25use arrow_schema::{Field, Fields};
26use datafusion_common::Result;
27use datafusion_expr::{
28 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
29};
30
31use super::json_udf::expand_args;
32
33pub fn register_complex_type_functions(ctx: &datafusion::prelude::SessionContext) {
35 use datafusion_expr::ScalarUDF;
36
37 ctx.register_udf(ScalarUDF::new_from_impl(StructExtract::new()));
38 ctx.register_udf(ScalarUDF::new_from_impl(StructSet::new()));
39 ctx.register_udf(ScalarUDF::new_from_impl(StructDrop::new()));
40 ctx.register_udf(ScalarUDF::new_from_impl(StructRename::new()));
41 ctx.register_udf(ScalarUDF::new_from_impl(StructMerge::new()));
42 ctx.register_udf(ScalarUDF::new_from_impl(MapKeys::new()));
43 ctx.register_udf(ScalarUDF::new_from_impl(MapValues::new()));
44 ctx.register_udf(ScalarUDF::new_from_impl(MapContainsKey::new()));
45 ctx.register_udf(ScalarUDF::new_from_impl(MapFromArrays::new()));
46}
47
48fn scalar_string_value(cv: &ColumnarValue) -> Result<String> {
51 match cv {
52 ColumnarValue::Scalar(s) => {
53 let arr = s.to_array_of_size(1)?;
54 let str_arr = arr
55 .as_any()
56 .downcast_ref::<arrow_array::StringArray>()
57 .ok_or_else(|| {
58 datafusion_common::DataFusionError::Internal("expected Utf8 argument".into())
59 })?;
60 Ok(str_arr.value(0).to_string())
61 }
62 ColumnarValue::Array(arr) => {
63 let str_arr = arr
64 .as_any()
65 .downcast_ref::<arrow_array::StringArray>()
66 .ok_or_else(|| {
67 datafusion_common::DataFusionError::Internal("expected Utf8 argument".into())
68 })?;
69 Ok(str_arr.value(0).to_string())
70 }
71 }
72}
73
74#[derive(Debug)]
80pub struct StructExtract {
81 signature: Signature,
82}
83
84impl StructExtract {
85 #[must_use]
87 pub fn new() -> Self {
88 Self {
89 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
90 }
91 }
92}
93
94impl Default for StructExtract {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl PartialEq for StructExtract {
101 fn eq(&self, _other: &Self) -> bool {
102 true
103 }
104}
105impl Eq for StructExtract {}
106impl Hash for StructExtract {
107 fn hash<H: Hasher>(&self, state: &mut H) {
108 "struct_extract".hash(state);
109 }
110}
111
112impl ScalarUDFImpl for StructExtract {
113 fn as_any(&self) -> &dyn Any {
114 self
115 }
116 fn name(&self) -> &'static str {
117 "struct_extract"
118 }
119 fn signature(&self) -> &Signature {
120 &self.signature
121 }
122
123 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
124 Ok(DataType::Utf8)
127 }
128
129 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
130 let expanded = expand_args(&args.args)?;
131 let struct_arr = expanded[0]
132 .as_any()
133 .downcast_ref::<StructArray>()
134 .ok_or_else(|| {
135 datafusion_common::DataFusionError::Internal(
136 "struct_extract: first arg must be Struct".into(),
137 )
138 })?;
139
140 let field_name = scalar_string_value(&args.args[1])?;
141
142 let (idx, _) = struct_arr.fields().find(&field_name).ok_or_else(|| {
143 datafusion_common::DataFusionError::Internal(format!(
144 "struct_extract: field '{field_name}' not found in struct"
145 ))
146 })?;
147
148 Ok(ColumnarValue::Array(struct_arr.column(idx).clone()))
149 }
150}
151
152#[derive(Debug)]
158pub struct StructSet {
159 signature: Signature,
160}
161
162impl StructSet {
163 #[must_use]
165 pub fn new() -> Self {
166 Self {
167 signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
168 }
169 }
170}
171
172impl Default for StructSet {
173 fn default() -> Self {
174 Self::new()
175 }
176}
177impl PartialEq for StructSet {
178 fn eq(&self, _other: &Self) -> bool {
179 true
180 }
181}
182impl Eq for StructSet {}
183impl Hash for StructSet {
184 fn hash<H: Hasher>(&self, state: &mut H) {
185 "struct_set".hash(state);
186 }
187}
188
189impl ScalarUDFImpl for StructSet {
190 fn as_any(&self) -> &dyn Any {
191 self
192 }
193 fn name(&self) -> &'static str {
194 "struct_set"
195 }
196 fn signature(&self) -> &Signature {
197 &self.signature
198 }
199
200 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
201 Ok(DataType::Utf8)
203 }
204
205 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
206 let expanded = expand_args(&args.args)?;
207 let struct_arr = expanded[0]
208 .as_any()
209 .downcast_ref::<StructArray>()
210 .ok_or_else(|| {
211 datafusion_common::DataFusionError::Internal(
212 "struct_set: first arg must be Struct".into(),
213 )
214 })?;
215
216 let field_name = scalar_string_value(&args.args[1])?;
217 let new_value = Arc::clone(&expanded[2]);
218
219 let mut new_fields: Vec<Arc<Field>> = Vec::new();
220 let mut new_columns: Vec<ArrayRef> = Vec::new();
221 let mut replaced = false;
222
223 for (i, field) in struct_arr.fields().iter().enumerate() {
224 if field.name() == &field_name {
225 new_fields.push(Arc::new(Field::new(
226 &field_name,
227 new_value.data_type().clone(),
228 new_value.null_count() > 0,
229 )));
230 new_columns.push(Arc::clone(&new_value));
231 replaced = true;
232 } else {
233 new_fields.push(Arc::clone(field));
234 new_columns.push(struct_arr.column(i).clone());
235 }
236 }
237
238 if !replaced {
239 new_fields.push(Arc::new(Field::new(
240 &field_name,
241 new_value.data_type().clone(),
242 new_value.null_count() > 0,
243 )));
244 new_columns.push(new_value);
245 }
246
247 let result = StructArray::try_new(
248 Fields::from(new_fields),
249 new_columns,
250 struct_arr.nulls().cloned(),
251 )?;
252 Ok(ColumnarValue::Array(Arc::new(result)))
253 }
254}
255
256#[derive(Debug)]
262pub struct StructDrop {
263 signature: Signature,
264}
265
266impl StructDrop {
267 #[must_use]
269 pub fn new() -> Self {
270 Self {
271 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
272 }
273 }
274}
275
276impl Default for StructDrop {
277 fn default() -> Self {
278 Self::new()
279 }
280}
281impl PartialEq for StructDrop {
282 fn eq(&self, _other: &Self) -> bool {
283 true
284 }
285}
286impl Eq for StructDrop {}
287impl Hash for StructDrop {
288 fn hash<H: Hasher>(&self, state: &mut H) {
289 "struct_drop".hash(state);
290 }
291}
292
293impl ScalarUDFImpl for StructDrop {
294 fn as_any(&self) -> &dyn Any {
295 self
296 }
297 fn name(&self) -> &'static str {
298 "struct_drop"
299 }
300 fn signature(&self) -> &Signature {
301 &self.signature
302 }
303
304 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
305 Ok(DataType::Utf8)
306 }
307
308 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
309 let expanded = expand_args(&args.args)?;
310 let struct_arr = expanded[0]
311 .as_any()
312 .downcast_ref::<StructArray>()
313 .ok_or_else(|| {
314 datafusion_common::DataFusionError::Internal(
315 "struct_drop: first arg must be Struct".into(),
316 )
317 })?;
318
319 let field_name = scalar_string_value(&args.args[1])?;
320
321 let mut new_fields: Vec<Arc<Field>> = Vec::new();
322 let mut new_columns: Vec<ArrayRef> = Vec::new();
323
324 for (i, field) in struct_arr.fields().iter().enumerate() {
325 if field.name() != &field_name {
326 new_fields.push(Arc::clone(field));
327 new_columns.push(struct_arr.column(i).clone());
328 }
329 }
330
331 if new_fields.is_empty() {
332 return Err(datafusion_common::DataFusionError::Internal(
333 "struct_drop: cannot drop all fields from struct".into(),
334 ));
335 }
336
337 let result = StructArray::try_new(
338 Fields::from(new_fields),
339 new_columns,
340 struct_arr.nulls().cloned(),
341 )?;
342 Ok(ColumnarValue::Array(Arc::new(result)))
343 }
344}
345
346#[derive(Debug)]
352pub struct StructRename {
353 signature: Signature,
354}
355
356impl StructRename {
357 #[must_use]
359 pub fn new() -> Self {
360 Self {
361 signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
362 }
363 }
364}
365
366impl Default for StructRename {
367 fn default() -> Self {
368 Self::new()
369 }
370}
371impl PartialEq for StructRename {
372 fn eq(&self, _other: &Self) -> bool {
373 true
374 }
375}
376impl Eq for StructRename {}
377impl Hash for StructRename {
378 fn hash<H: Hasher>(&self, state: &mut H) {
379 "struct_rename".hash(state);
380 }
381}
382
383impl ScalarUDFImpl for StructRename {
384 fn as_any(&self) -> &dyn Any {
385 self
386 }
387 fn name(&self) -> &'static str {
388 "struct_rename"
389 }
390 fn signature(&self) -> &Signature {
391 &self.signature
392 }
393
394 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
395 Ok(DataType::Utf8)
396 }
397
398 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
399 let expanded = expand_args(&args.args)?;
400 let struct_arr = expanded[0]
401 .as_any()
402 .downcast_ref::<StructArray>()
403 .ok_or_else(|| {
404 datafusion_common::DataFusionError::Internal(
405 "struct_rename: first arg must be Struct".into(),
406 )
407 })?;
408
409 let old_name = scalar_string_value(&args.args[1])?;
410 let new_name = scalar_string_value(&args.args[2])?;
411
412 let mut new_fields: Vec<Arc<Field>> = Vec::new();
413 let mut new_columns: Vec<ArrayRef> = Vec::new();
414
415 for (i, field) in struct_arr.fields().iter().enumerate() {
416 if field.name() == &old_name {
417 new_fields.push(Arc::new(Field::new(
418 &new_name,
419 field.data_type().clone(),
420 field.is_nullable(),
421 )));
422 } else {
423 new_fields.push(Arc::clone(field));
424 }
425 new_columns.push(struct_arr.column(i).clone());
426 }
427
428 let result = StructArray::try_new(
429 Fields::from(new_fields),
430 new_columns,
431 struct_arr.nulls().cloned(),
432 )?;
433 Ok(ColumnarValue::Array(Arc::new(result)))
434 }
435}
436
437#[derive(Debug)]
443pub struct StructMerge {
444 signature: Signature,
445}
446
447impl StructMerge {
448 #[must_use]
450 pub fn new() -> Self {
451 Self {
452 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
453 }
454 }
455}
456
457impl Default for StructMerge {
458 fn default() -> Self {
459 Self::new()
460 }
461}
462impl PartialEq for StructMerge {
463 fn eq(&self, _other: &Self) -> bool {
464 true
465 }
466}
467impl Eq for StructMerge {}
468impl Hash for StructMerge {
469 fn hash<H: Hasher>(&self, state: &mut H) {
470 "struct_merge".hash(state);
471 }
472}
473
474impl ScalarUDFImpl for StructMerge {
475 fn as_any(&self) -> &dyn Any {
476 self
477 }
478 fn name(&self) -> &'static str {
479 "struct_merge"
480 }
481 fn signature(&self) -> &Signature {
482 &self.signature
483 }
484
485 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
486 Ok(DataType::Utf8)
487 }
488
489 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
490 let expanded = expand_args(&args.args)?;
491 let s1 = expanded[0]
492 .as_any()
493 .downcast_ref::<StructArray>()
494 .ok_or_else(|| {
495 datafusion_common::DataFusionError::Internal(
496 "struct_merge: first arg must be Struct".into(),
497 )
498 })?;
499 let s2 = expanded[1]
500 .as_any()
501 .downcast_ref::<StructArray>()
502 .ok_or_else(|| {
503 datafusion_common::DataFusionError::Internal(
504 "struct_merge: second arg must be Struct".into(),
505 )
506 })?;
507
508 let mut new_fields: Vec<Arc<Field>> = Vec::new();
509 let mut new_columns: Vec<ArrayRef> = Vec::new();
510
511 let s2_names: Vec<&str> = s2.fields().iter().map(|f| f.name().as_str()).collect();
513
514 for (i, field) in s1.fields().iter().enumerate() {
516 if !s2_names.contains(&field.name().as_str()) {
517 new_fields.push(Arc::clone(field));
518 new_columns.push(s1.column(i).clone());
519 }
520 }
521
522 for (i, field) in s2.fields().iter().enumerate() {
524 new_fields.push(Arc::clone(field));
525 new_columns.push(s2.column(i).clone());
526 }
527
528 let result = StructArray::try_new(Fields::from(new_fields), new_columns, None)?;
529 Ok(ColumnarValue::Array(Arc::new(result)))
530 }
531}
532
533#[derive(Debug)]
539pub struct MapKeys {
540 signature: Signature,
541}
542
543impl MapKeys {
544 #[must_use]
546 pub fn new() -> Self {
547 Self {
548 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
549 }
550 }
551}
552
553impl Default for MapKeys {
554 fn default() -> Self {
555 Self::new()
556 }
557}
558impl PartialEq for MapKeys {
559 fn eq(&self, _other: &Self) -> bool {
560 true
561 }
562}
563impl Eq for MapKeys {}
564impl Hash for MapKeys {
565 fn hash<H: Hasher>(&self, state: &mut H) {
566 "map_keys".hash(state);
567 }
568}
569
570impl ScalarUDFImpl for MapKeys {
571 fn as_any(&self) -> &dyn Any {
572 self
573 }
574 fn name(&self) -> &'static str {
575 "map_keys"
576 }
577 fn signature(&self) -> &Signature {
578 &self.signature
579 }
580
581 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
582 match &arg_types[0] {
583 DataType::Map(field, _) => {
584 if let DataType::Struct(fields) = field.data_type() {
585 if let Some(key_field) = fields.first() {
586 return Ok(DataType::List(Arc::new(Field::new(
587 "item",
588 key_field.data_type().clone(),
589 key_field.is_nullable(),
590 ))));
591 }
592 }
593 Ok(DataType::List(Arc::new(Field::new(
594 "item",
595 DataType::Utf8,
596 true,
597 ))))
598 }
599 _ => Ok(DataType::List(Arc::new(Field::new(
600 "item",
601 DataType::Utf8,
602 true,
603 )))),
604 }
605 }
606
607 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
608 let expanded = expand_args(&args.args)?;
609 let map_arr = expanded[0]
610 .as_any()
611 .downcast_ref::<MapArray>()
612 .ok_or_else(|| {
613 datafusion_common::DataFusionError::Internal("map_keys: arg must be Map".into())
614 })?;
615
616 Ok(ColumnarValue::Array(Arc::new(map_arr.keys().clone())))
617 }
618}
619
620#[derive(Debug)]
626pub struct MapValues {
627 signature: Signature,
628}
629
630impl MapValues {
631 #[must_use]
633 pub fn new() -> Self {
634 Self {
635 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
636 }
637 }
638}
639
640impl Default for MapValues {
641 fn default() -> Self {
642 Self::new()
643 }
644}
645impl PartialEq for MapValues {
646 fn eq(&self, _other: &Self) -> bool {
647 true
648 }
649}
650impl Eq for MapValues {}
651impl Hash for MapValues {
652 fn hash<H: Hasher>(&self, state: &mut H) {
653 "map_values".hash(state);
654 }
655}
656
657impl ScalarUDFImpl for MapValues {
658 fn as_any(&self) -> &dyn Any {
659 self
660 }
661 fn name(&self) -> &'static str {
662 "map_values"
663 }
664 fn signature(&self) -> &Signature {
665 &self.signature
666 }
667
668 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
669 match &arg_types[0] {
670 DataType::Map(field, _) => {
671 if let DataType::Struct(fields) = field.data_type() {
672 if fields.len() >= 2 {
673 let val_field = &fields[1];
674 return Ok(DataType::List(Arc::new(Field::new(
675 "item",
676 val_field.data_type().clone(),
677 val_field.is_nullable(),
678 ))));
679 }
680 }
681 Ok(DataType::List(Arc::new(Field::new(
682 "item",
683 DataType::Utf8,
684 true,
685 ))))
686 }
687 _ => Ok(DataType::List(Arc::new(Field::new(
688 "item",
689 DataType::Utf8,
690 true,
691 )))),
692 }
693 }
694
695 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
696 let expanded = expand_args(&args.args)?;
697 let map_arr = expanded[0]
698 .as_any()
699 .downcast_ref::<MapArray>()
700 .ok_or_else(|| {
701 datafusion_common::DataFusionError::Internal("map_values: arg must be Map".into())
702 })?;
703
704 Ok(ColumnarValue::Array(Arc::new(map_arr.values().clone())))
705 }
706}
707
708#[derive(Debug)]
714pub struct MapContainsKey {
715 signature: Signature,
716}
717
718impl MapContainsKey {
719 #[must_use]
721 pub fn new() -> Self {
722 Self {
723 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
724 }
725 }
726}
727
728impl Default for MapContainsKey {
729 fn default() -> Self {
730 Self::new()
731 }
732}
733impl PartialEq for MapContainsKey {
734 fn eq(&self, _other: &Self) -> bool {
735 true
736 }
737}
738impl Eq for MapContainsKey {}
739impl Hash for MapContainsKey {
740 fn hash<H: Hasher>(&self, state: &mut H) {
741 "map_contains_key".hash(state);
742 }
743}
744
745impl ScalarUDFImpl for MapContainsKey {
746 fn as_any(&self) -> &dyn Any {
747 self
748 }
749 fn name(&self) -> &'static str {
750 "map_contains_key"
751 }
752 fn signature(&self) -> &Signature {
753 &self.signature
754 }
755
756 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
757 Ok(DataType::Boolean)
758 }
759
760 #[allow(clippy::cast_sign_loss)]
761 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
762 let expanded = expand_args(&args.args)?;
763 let map_arr = expanded[0]
764 .as_any()
765 .downcast_ref::<MapArray>()
766 .ok_or_else(|| {
767 datafusion_common::DataFusionError::Internal(
768 "map_contains_key: first arg must be Map".into(),
769 )
770 })?;
771
772 let search_key = expanded[1]
773 .as_any()
774 .downcast_ref::<arrow_array::StringArray>()
775 .ok_or_else(|| {
776 datafusion_common::DataFusionError::Internal(
777 "map_contains_key: second arg must be Utf8".into(),
778 )
779 })?;
780
781 let keys_col = map_arr.keys();
782 let keys_str = keys_col.as_any().downcast_ref::<arrow_array::StringArray>();
783
784 let mut builder = BooleanBuilder::with_capacity(map_arr.len());
785
786 for row in 0..map_arr.len() {
787 if map_arr.is_null(row) || search_key.is_null(row) {
788 builder.append_null();
789 continue;
790 }
791
792 let target = search_key.value(row);
793 let start = map_arr.value_offsets()[row] as usize;
794 let end = map_arr.value_offsets()[row + 1] as usize;
795
796 let found = if let Some(ks) = keys_str {
797 (start..end).any(|i| !ks.is_null(i) && ks.value(i) == target)
798 } else {
799 false
800 };
801
802 builder.append_value(found);
803 }
804
805 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
806 }
807}
808
809#[derive(Debug)]
815pub struct MapFromArrays {
816 signature: Signature,
817}
818
819impl MapFromArrays {
820 #[must_use]
822 pub fn new() -> Self {
823 Self {
824 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
825 }
826 }
827}
828
829impl Default for MapFromArrays {
830 fn default() -> Self {
831 Self::new()
832 }
833}
834impl PartialEq for MapFromArrays {
835 fn eq(&self, _other: &Self) -> bool {
836 true
837 }
838}
839impl Eq for MapFromArrays {}
840impl Hash for MapFromArrays {
841 fn hash<H: Hasher>(&self, state: &mut H) {
842 "map_from_arrays".hash(state);
843 }
844}
845
846impl ScalarUDFImpl for MapFromArrays {
847 fn as_any(&self) -> &dyn Any {
848 self
849 }
850 fn name(&self) -> &'static str {
851 "map_from_arrays"
852 }
853 fn signature(&self) -> &Signature {
854 &self.signature
855 }
856
857 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
858 let key_type = match &arg_types[0] {
859 DataType::List(f) => f.data_type().clone(),
860 _ => DataType::Utf8,
861 };
862 let val_type = match &arg_types[1] {
863 DataType::List(f) => f.data_type().clone(),
864 _ => DataType::Utf8,
865 };
866
867 let entries_field = Field::new(
868 "entries",
869 DataType::Struct(Fields::from(vec![
870 Field::new("key", key_type, false),
871 Field::new("value", val_type, true),
872 ])),
873 false,
874 );
875 Ok(DataType::Map(Arc::new(entries_field), false))
876 }
877
878 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
879 let expanded = expand_args(&args.args)?;
880 let keys_list = expanded[0]
881 .as_any()
882 .downcast_ref::<arrow_array::ListArray>()
883 .ok_or_else(|| {
884 datafusion_common::DataFusionError::Internal(
885 "map_from_arrays: first arg must be List".into(),
886 )
887 })?;
888 let values_list = expanded[1]
889 .as_any()
890 .downcast_ref::<arrow_array::ListArray>()
891 .ok_or_else(|| {
892 datafusion_common::DataFusionError::Internal(
893 "map_from_arrays: second arg must be List".into(),
894 )
895 })?;
896
897 let offsets = keys_list.offsets().clone();
899 let key_values = keys_list.values().clone();
900 let val_values = values_list.values().clone();
901
902 let key_field = Field::new("key", key_values.data_type().clone(), false);
903 let val_field = Field::new("value", val_values.data_type().clone(), true);
904 let struct_fields = Fields::from(vec![key_field.clone(), val_field.clone()]);
905 let entries = StructArray::try_new(struct_fields, vec![key_values, val_values], None)?;
906
907 let entries_field = Field::new("entries", entries.data_type().clone(), false);
908 let map = MapArray::try_new(Arc::new(entries_field), offsets, entries, None, false)?;
909 Ok(ColumnarValue::Array(Arc::new(map)))
910 }
911}
912
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use arrow_array::builder::*;
    use arrow_array::*;
    use arrow_schema::{DataType, Field, Fields};
    use datafusion_common::config::ConfigOptions;

    /// Smoke test: the built-in `array_length` is callable through the
    /// project's session context and yields a single row.
    #[tokio::test]
    async fn test_builtin_array_length() {
        let ctx = create_session_context();
        let df = ctx
            .sql("SELECT array_length(make_array(1, 2, 3))")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(batches[0].num_rows(), 1);
    }

    /// Smoke test: the built-in `array_sort` is callable and yields a single row.
    #[tokio::test]
    async fn test_builtin_array_sort() {
        let ctx = create_session_context();
        let df = ctx
            .sql("SELECT array_sort(make_array(3, 1, 2))")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(batches[0].num_rows(), 1);
    }

    /// Smoke test: the built-in `array_distinct` is callable and yields a
    /// single row.
    #[tokio::test]
    async fn test_builtin_array_distinct() {
        let ctx = create_session_context();
        let df = ctx
            .sql("SELECT array_distinct(make_array(1, 2, 2, 3))")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(batches[0].num_rows(), 1);
    }

    /// Extracting field `b` returns that field's string column unchanged.
    #[test]
    fn test_struct_extract() {
        let fields = Fields::from(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        let struct_arr = StructArray::try_new(
            fields,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["x", "y", "z"])),
            ],
            None,
        )
        .unwrap();

        let udf = StructExtract::new();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(struct_arr)),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some("b".into()))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new("output", DataType::Utf8, true)),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        if let ColumnarValue::Array(arr) = result {
            let str_arr = arr.as_any().downcast_ref::<StringArray>().unwrap();
            assert_eq!(str_arr.value(0), "x");
            assert_eq!(str_arr.value(1), "y");
            assert_eq!(str_arr.value(2), "z");
        } else {
            panic!("expected Array");
        }
    }

    /// Dropping field `b` leaves a struct with the single field `a`.
    #[test]
    fn test_struct_drop() {
        let fields = Fields::from(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        let struct_arr = StructArray::try_new(
            fields,
            vec![
                Arc::new(Int64Array::from(vec![1])),
                Arc::new(StringArray::from(vec!["x"])),
            ],
            None,
        )
        .unwrap();

        let udf = StructDrop::new();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(struct_arr)),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some("b".into()))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new("output", DataType::Utf8, true)),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        if let ColumnarValue::Array(arr) = result {
            let s = arr.as_any().downcast_ref::<StructArray>().unwrap();
            assert_eq!(s.num_columns(), 1);
            assert_eq!(s.fields()[0].name(), "a");
        } else {
            panic!("expected Array");
        }
    }

    /// Renaming `old_name` to `new_name` changes the field name in place.
    #[test]
    fn test_struct_rename() {
        let fields = Fields::from(vec![Field::new("old_name", DataType::Int64, false)]);
        let struct_arr =
            StructArray::try_new(fields, vec![Arc::new(Int64Array::from(vec![42]))], None).unwrap();

        let udf = StructRename::new();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(struct_arr)),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some(
                        "old_name".into(),
                    ))),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some(
                        "new_name".into(),
                    ))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new("output", DataType::Utf8, true)),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        if let ColumnarValue::Array(arr) = result {
            let s = arr.as_any().downcast_ref::<StructArray>().unwrap();
            assert_eq!(s.fields()[0].name(), "new_name");
        } else {
            panic!("expected Array");
        }
    }

    /// Merging `{a, b}` with `{b, c}` yields the fields `a`, `b`, `c`.
    #[test]
    fn test_struct_merge() {
        let s1 = StructArray::try_new(
            Fields::from(vec![
                Field::new("a", DataType::Int64, false),
                Field::new("b", DataType::Utf8, true),
            ]),
            vec![
                Arc::new(Int64Array::from(vec![1])),
                Arc::new(StringArray::from(vec!["old"])),
            ],
            None,
        )
        .unwrap();

        let s2 = StructArray::try_new(
            Fields::from(vec![
                Field::new("b", DataType::Utf8, true),
                Field::new("c", DataType::Float64, false),
            ]),
            vec![
                Arc::new(StringArray::from(vec!["new"])),
                Arc::new(Float64Array::from(vec![3.125])),
            ],
            None,
        )
        .unwrap();

        let udf = StructMerge::new();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(s1)),
                    ColumnarValue::Array(Arc::new(s2)),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new("output", DataType::Utf8, true)),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        if let ColumnarValue::Array(arr) = result {
            let s = arr.as_any().downcast_ref::<StructArray>().unwrap();
            assert_eq!(s.num_columns(), 3);
            let names: Vec<&str> = s.fields().iter().map(|f| f.name().as_str()).collect();
            assert_eq!(names, vec!["a", "b", "c"]);
        } else {
            panic!("expected Array");
        }
    }

    /// A present key reports `true`, an absent key reports `false`.
    #[test]
    fn test_map_contains_key() {
        let key_builder = StringBuilder::new();
        let val_builder = Int64Builder::new();
        let mut builder = MapBuilder::new(None, key_builder, val_builder);

        builder.keys().append_value("x");
        builder.values().append_value(1);
        builder.keys().append_value("y");
        builder.values().append_value(2);
        builder.append(true).unwrap();

        let map_arr = builder.finish();

        let udf = MapContainsKey::new();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(map_arr.clone())),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some("x".into()))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new("output", DataType::Boolean, true)),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        // Fail loudly on an unexpected variant instead of skipping the assertion.
        let ColumnarValue::Array(arr) = result else {
            panic!("expected Array");
        };
        let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(bool_arr.value(0));

        let result2 = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(map_arr)),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some("z".into()))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new("output", DataType::Boolean, true)),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        let ColumnarValue::Array(arr) = result2 else {
            panic!("expected Array");
        };
        let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(!bool_arr.value(0));
    }

    /// Every UDF of this module can be looked up by name after registration.
    #[test]
    fn test_register_complex_type_functions() {
        use datafusion::execution::FunctionRegistry;

        let ctx = create_session_context();
        register_complex_type_functions(&ctx);
        assert!(ctx.udf("struct_extract").is_ok());
        assert!(ctx.udf("struct_set").is_ok());
        assert!(ctx.udf("struct_drop").is_ok());
        assert!(ctx.udf("struct_rename").is_ok());
        assert!(ctx.udf("struct_merge").is_ok());
        assert!(ctx.udf("map_keys").is_ok());
        assert!(ctx.udf("map_values").is_ok());
        assert!(ctx.udf("map_contains_key").is_ok());
        assert!(ctx.udf("map_from_arrays").is_ok());
    }
}