use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::datatypes::DataType;
use arrow_array::{Array, ArrayRef, BooleanArray, ListArray, MapArray, StructArray};
use arrow_schema::{Field, Fields, Schema};
use datafusion_common::Result;
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};

use super::json_udf::expand_args;

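/// Registers the lambda-style higher-order UDFs on the given context. Once
/// registered they are callable from SQL with the lambda passed as a string,
/// e.g. `SELECT array_transform(xs, 'x + 1') FROM t` (hypothetical table `t`
/// with a list column `xs`).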
pub fn register_lambda_functions(ctx: &datafusion::prelude::SessionContext) {
    use datafusion_expr::ScalarUDF;

    ctx.register_udf(ScalarUDF::new_from_impl(ArrayTransform::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ArrayFilter::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ArrayReduce::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(MapFilter::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(MapTransformValues::new()));
}

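// Each thread lazily creates and caches its own SessionContext for running
// lambda expressions, so no locking is needed across concurrent invocations.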
thread_local! {
    static LAMBDA_CTX: std::cell::RefCell<Option<datafusion::prelude::SessionContext>> =
        const { std::cell::RefCell::new(None) };
}
fn eval_expr_on_batch(sql_expr: &str, batch: &arrow_array::RecordBatch) -> Result<ArrayRef> {
    // Reuse (or lazily create) this thread's cached SessionContext.
    let ctx = LAMBDA_CTX.with(|cell| {
        let mut opt = cell.borrow_mut();
        opt.get_or_insert_with(datafusion::prelude::SessionContext::new)
            .clone()
    });

    let provider =
        datafusion::datasource::MemTable::try_new(batch.schema(), vec![vec![batch.clone()]])?;
    let rt = tokio::runtime::Handle::try_current().map_err(|e| {
        datafusion_common::DataFusionError::Internal(format!(
            "lambda eval requires tokio runtime: {e}"
        ))
    })?;
    // Block this worker thread on the async query without starving the runtime.
    tokio::task::block_in_place(|| {
        rt.block_on(async {
            ctx.register_table("__lambda_data", Arc::new(provider))?;
            let df = ctx
                .sql(&format!("SELECT {sql_expr} FROM __lambda_data"))
                .await?;
            let batches = df.collect().await?;
            let _ = ctx.deregister_table("__lambda_data");
            if batches.is_empty() {
                Err(datafusion_common::DataFusionError::Internal(
                    "lambda expression returned no data".into(),
                ))
            } else {
                // Concatenate all result batches and return the projected column.
                let result = arrow::compute::concat_batches(&batches[0].schema(), &batches)?;
                Ok(result.column(0).clone())
            }
        })
    })
}

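/// Extracts the lambda expression string from a UDF argument; accepts either
/// a Utf8 scalar or a Utf8 array (only the first element is read).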
fn scalar_string_value(cv: &ColumnarValue) -> Result<String> {
    match cv {
        ColumnarValue::Scalar(s) => {
            let arr = s.to_array_of_size(1)?;
            let str_arr = arr
                .as_any()
                .downcast_ref::<arrow_array::StringArray>()
                .ok_or_else(|| {
                    datafusion_common::DataFusionError::Internal("expected Utf8 argument".into())
                })?;
            Ok(str_arr.value(0).to_string())
        }
        ColumnarValue::Array(arr) => {
            let str_arr = arr
                .as_any()
                .downcast_ref::<arrow_array::StringArray>()
                .ok_or_else(|| {
                    datafusion_common::DataFusionError::Internal("expected Utf8 argument".into())
                })?;
            Ok(str_arr.value(0).to_string())
        }
    }
}

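/// `array_transform(list, 'expr')`: applies a SQL expression to each list
/// element, binding the element to the column name `x`.
///
/// E.g. `array_transform(xs, 'x + 1')` increments every element.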
#[derive(Debug)]
pub struct ArrayTransform {
    signature: Signature,
}

impl ArrayTransform {
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

impl Default for ArrayTransform {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for ArrayTransform {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for ArrayTransform {}
impl Hash for ArrayTransform {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "array_transform".hash(s);
    }
}

impl ScalarUDFImpl for ArrayTransform {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "array_transform"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match &arg_types[0] {
            DataType::List(f) => Ok(DataType::List(Arc::clone(f))),
            _ => Ok(DataType::List(Arc::new(Field::new(
                "item",
                DataType::Utf8,
                true,
            )))),
        }
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let list_arr = expanded[0]
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "array_transform: first arg must be List".into(),
                )
            })?;

        let lambda_str = scalar_string_value(&args.args[1])?;
        let flat_values = list_arr.values();

        // Evaluate the lambda once over the flattened child values as column `x`.
        let schema = Arc::new(Schema::new(vec![Field::new(
            "x",
            flat_values.data_type().clone(),
            true,
        )]));
        let batch = arrow_array::RecordBatch::try_new(schema, vec![Arc::clone(flat_values)])?;

        let result_arr = eval_expr_on_batch(&lambda_str, &batch)?;

        // Reassemble a list array with the original offsets and null mask.
        let new_field = Arc::new(Field::new("item", result_arr.data_type().clone(), true));
        let new_list = ListArray::try_new(
            new_field,
            list_arr.offsets().clone(),
            result_arr,
            list_arr.nulls().cloned(),
        )?;
        Ok(ColumnarValue::Array(Arc::new(new_list)))
    }
}

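/// `array_filter(list, 'pred')`: keeps the list elements for which the SQL
/// predicate, with each element bound to `x`, evaluates to `true`.
///
/// E.g. `array_filter(xs, 'x > 0')` keeps only the positive elements.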
#[derive(Debug)]
pub struct ArrayFilter {
    signature: Signature,
}

impl ArrayFilter {
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

impl Default for ArrayFilter {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for ArrayFilter {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for ArrayFilter {}
impl Hash for ArrayFilter {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "array_filter".hash(s);
    }
}

impl ScalarUDFImpl for ArrayFilter {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "array_filter"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        match &arg_types[0] {
            DataType::List(f) => Ok(DataType::List(Arc::clone(f))),
            _ => Ok(DataType::List(Arc::new(Field::new(
                "item",
                DataType::Utf8,
                true,
            )))),
        }
    }

    #[allow(
        clippy::cast_sign_loss,
        clippy::cast_possible_wrap,
        clippy::cast_possible_truncation
    )]
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let list_arr = expanded[0]
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "array_filter: first arg must be List".into(),
                )
            })?;

        let lambda_str = scalar_string_value(&args.args[1])?;
        let flat_values = list_arr.values();
        let elem_type = flat_values.data_type().clone();

        // Evaluate the predicate once over the flattened child values as column `x`.
        let schema = Arc::new(Schema::new(vec![Field::new("x", elem_type.clone(), true)]));
        let batch = arrow_array::RecordBatch::try_new(schema, vec![Arc::clone(flat_values)])?;

        let mask_arr = eval_expr_on_batch(&lambda_str, &batch)?;
        let mask = mask_arr
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "array_filter: lambda must return Boolean".into(),
                )
            })?;

        // Collect the surviving value indices and rebuild per-row offsets.
        let mut offsets = vec![0i32];
        let mut filtered_indices: Vec<usize> = Vec::new();

        for row in 0..list_arr.len() {
            let start = list_arr.value_offsets()[row] as usize;
            let end = list_arr.value_offsets()[row + 1] as usize;

            for i in start..end {
                if !mask.is_null(i) && mask.value(i) {
                    filtered_indices.push(i);
                }
            }
            offsets.push(filtered_indices.len() as i32);
        }

        let indices = arrow_array::UInt32Array::from(
            filtered_indices
                .iter()
                .map(|&i| i as u32)
                .collect::<Vec<_>>(),
        );
        let filtered_values = arrow::compute::take(flat_values.as_ref(), &indices, None)?;

        let new_field = Arc::new(Field::new("item", elem_type, true));
        let new_offsets =
            arrow::buffer::OffsetBuffer::new(arrow::buffer::ScalarBuffer::from(offsets));
        let new_list = ListArray::try_new(
            new_field,
            new_offsets,
            filtered_values,
            list_arr.nulls().cloned(),
        )?;
        Ok(ColumnarValue::Array(Arc::new(new_list)))
    }
}

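/// `array_reduce(list, init, 'expr')`: left-folds each list row with the SQL
/// expression, binding the accumulator to `acc` and the element to `x`; the
/// result takes the type of `init`.
///
/// E.g. `array_reduce(xs, 0, 'acc + x')` would sum each list (sketch; how a
/// literal `init` is broadcast per row depends on `expand_args`).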
#[derive(Debug)]
pub struct ArrayReduce {
    signature: Signature,
}

impl ArrayReduce {
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
        }
    }
}

impl Default for ArrayReduce {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for ArrayReduce {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for ArrayReduce {}
impl Hash for ArrayReduce {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "array_reduce".hash(s);
    }
}

impl ScalarUDFImpl for ArrayReduce {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "array_reduce"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // The fold result has the accumulator's type (the second argument).
        Ok(arg_types.get(1).cloned().unwrap_or(DataType::Int64))
    }

    #[allow(clippy::cast_sign_loss)]
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let list_arr = expanded[0]
            .as_any()
            .downcast_ref::<ListArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "array_reduce: first arg must be List".into(),
                )
            })?;

        let init_arr = &expanded[1];
        let lambda_str = scalar_string_value(&args.args[2])?;

        let elem_type = list_arr.values().data_type().clone();
        let acc_type = init_arr.data_type().clone();

        let schema = Arc::new(Schema::new(vec![
            Field::new("acc", acc_type, true),
            Field::new("x", elem_type, true),
        ]));

        let mut result_builder: Vec<ArrayRef> = Vec::new();

        // Fold each row's elements left-to-right, one single-row batch per step.
        for row in 0..list_arr.len() {
            let start = list_arr.value_offsets()[row] as usize;
            let end = list_arr.value_offsets()[row + 1] as usize;

            let mut acc: ArrayRef = init_arr.slice(row, 1);

            for i in start..end {
                let x = list_arr.values().slice(i, 1);
                let batch = arrow_array::RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![Arc::clone(&acc), x],
                )?;
                acc = eval_expr_on_batch(&lambda_str, &batch)?;
            }

            result_builder.push(acc);
        }

        if result_builder.is_empty() {
            return Ok(ColumnarValue::Array(Arc::clone(init_arr)));
        }

        let refs: Vec<&dyn Array> = result_builder
            .iter()
            .map(std::convert::AsRef::as_ref)
            .collect();
        let result = arrow::compute::concat(&refs)?;
        Ok(ColumnarValue::Array(result))
    }
}

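/// `map_filter(map, 'pred')`: keeps the map entries for which the SQL
/// predicate evaluates to `true`, binding the key to `k` and the value to `v`.
///
/// E.g. `map_filter(m, 'v > 0')` keeps only entries with positive values.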
#[derive(Debug)]
pub struct MapFilter {
    signature: Signature,
}

impl MapFilter {
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

impl Default for MapFilter {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for MapFilter {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for MapFilter {}
impl Hash for MapFilter {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "map_filter".hash(s);
    }
}

impl ScalarUDFImpl for MapFilter {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "map_filter"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    #[allow(
        clippy::cast_sign_loss,
        clippy::cast_possible_wrap,
        clippy::cast_possible_truncation
    )]
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let map_arr = expanded[0]
            .as_any()
            .downcast_ref::<MapArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "map_filter: first arg must be Map".into(),
                )
            })?;

        let lambda_str = scalar_string_value(&args.args[1])?;

        let entries = map_arr.entries();
        let key_col = entries.column(0);
        let val_col = entries.column(1);

        let key_type = key_col.data_type().clone();
        let val_type = val_col.data_type().clone();

        // Evaluate the predicate once over the flattened entries as columns `k` and `v`.
        let schema = Arc::new(Schema::new(vec![
            Field::new("k", key_type.clone(), true),
            Field::new("v", val_type.clone(), true),
        ]));
        let batch = arrow_array::RecordBatch::try_new(
            schema,
            vec![Arc::clone(key_col), Arc::clone(val_col)],
        )?;

        let mask_arr = eval_expr_on_batch(&lambda_str, &batch)?;
        let mask_bool = mask_arr
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "map_filter: lambda must return Boolean".into(),
                )
            })?;

        // Collect the surviving entry indices and rebuild per-row offsets.
        let mut offsets = vec![0i32];
        let mut keep_indices: Vec<usize> = Vec::new();

        for row in 0..map_arr.len() {
            let start = map_arr.value_offsets()[row] as usize;
            let end = map_arr.value_offsets()[row + 1] as usize;

            for i in start..end {
                if !mask_bool.is_null(i) && mask_bool.value(i) {
                    keep_indices.push(i);
                }
            }
            offsets.push(keep_indices.len() as i32);
        }

        let indices = arrow_array::UInt32Array::from(
            keep_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let new_keys = arrow::compute::take(key_col.as_ref(), &indices, None)?;
        let new_vals = arrow::compute::take(val_col.as_ref(), &indices, None)?;

        let struct_fields = Fields::from(vec![
            Field::new("key", key_type, false),
            Field::new("value", val_type, true),
        ]);
        let new_entries = StructArray::try_new(struct_fields, vec![new_keys, new_vals], None)?;

        let entries_field = Field::new("entries", new_entries.data_type().clone(), false);
        let new_offsets =
            arrow::buffer::OffsetBuffer::new(arrow::buffer::ScalarBuffer::from(offsets));
        let new_map = MapArray::try_new(
            Arc::new(entries_field),
            new_offsets,
            new_entries,
            map_arr.nulls().cloned(),
            false,
        )?;
        Ok(ColumnarValue::Array(Arc::new(new_map)))
    }
}

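/// `map_transform_values(map, 'expr')`: rewrites each map value with the SQL
/// expression (key bound to `k`, value bound to `v`); keys and layout are kept.
///
/// E.g. `map_transform_values(m, 'v * 2')` doubles every value.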
#[derive(Debug)]
pub struct MapTransformValues {
    signature: Signature,
}

impl MapTransformValues {
    #[must_use]
    pub fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

impl Default for MapTransformValues {
    fn default() -> Self {
        Self::new()
    }
}
impl PartialEq for MapTransformValues {
    fn eq(&self, _: &Self) -> bool {
        true
    }
}
impl Eq for MapTransformValues {}
impl Hash for MapTransformValues {
    fn hash<H: Hasher>(&self, s: &mut H) {
        "map_transform_values".hash(s);
    }
}

impl ScalarUDFImpl for MapTransformValues {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &'static str {
        "map_transform_values"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let expanded = expand_args(&args.args)?;
        let map_arr = expanded[0]
            .as_any()
            .downcast_ref::<MapArray>()
            .ok_or_else(|| {
                datafusion_common::DataFusionError::Internal(
                    "map_transform_values: first arg must be Map".into(),
                )
            })?;

        let lambda_str = scalar_string_value(&args.args[1])?;

        let entries = map_arr.entries();
        let key_col = entries.column(0);
        let val_col = entries.column(1);

        let key_type = key_col.data_type().clone();

        // Evaluate the lambda once over the flattened entries as columns `k` and `v`.
        let schema = Arc::new(Schema::new(vec![
            Field::new("k", key_type.clone(), true),
            Field::new("v", val_col.data_type().clone(), true),
        ]));
        let batch = arrow_array::RecordBatch::try_new(
            schema,
            vec![Arc::clone(key_col), Arc::clone(val_col)],
        )?;

        let new_vals = eval_expr_on_batch(&lambda_str, &batch)?;

        // Keep the original keys, offsets, and null mask; swap in the new values.
        let struct_fields = Fields::from(vec![
            Field::new("key", key_type, false),
            Field::new("value", new_vals.data_type().clone(), true),
        ]);
        let new_entries =
            StructArray::try_new(struct_fields, vec![Arc::clone(key_col), new_vals], None)?;

        let entries_field = Field::new("entries", new_entries.data_type().clone(), false);
        let new_map = MapArray::try_new(
            Arc::new(entries_field),
            map_arr.offsets().clone(),
            new_entries,
            map_arr.nulls().cloned(),
            false,
        )?;
        Ok(ColumnarValue::Array(Arc::new(new_map)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use arrow_array::*;
    use datafusion_common::config::ConfigOptions;

    // `array_transform` should apply `x + 1` to every element of each row.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_array_transform_add_one() {
        let values = Int64Array::from(vec![1, 2, 3, 4, 5, 6]);
        let offsets =
            arrow::buffer::OffsetBuffer::new(arrow::buffer::ScalarBuffer::from(vec![0i32, 3, 6]));
        let list = ListArray::try_new(
            Arc::new(Field::new("item", DataType::Int64, true)),
            offsets,
            Arc::new(values),
            None,
        )
        .unwrap();

        let udf = ArrayTransform::new();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(list)),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some(
                        "x + 1".into(),
                    ))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new(
                    "output",
                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
                    true,
                )),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        if let ColumnarValue::Array(arr) = result {
            let la = arr.as_any().downcast_ref::<ListArray>().unwrap();
            assert_eq!(la.len(), 2);
            let row0 = la.value(0);
            let r0 = row0.as_any().downcast_ref::<Int64Array>().unwrap();
            assert_eq!(r0.value(0), 2);
            assert_eq!(r0.value(1), 3);
            assert_eq!(r0.value(2), 4);
        } else {
            panic!("expected Array");
        }
    }

    // `array_filter` should keep only the elements matching `x > 0`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_array_filter_positive() {
        let values = Int64Array::from(vec![-1, 2, -3, 4]);
        let offsets =
            arrow::buffer::OffsetBuffer::new(arrow::buffer::ScalarBuffer::from(vec![0i32, 4]));
        let list = ListArray::try_new(
            Arc::new(Field::new("item", DataType::Int64, true)),
            offsets,
            Arc::new(values),
            None,
        )
        .unwrap();

        let udf = ArrayFilter::new();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(list)),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some(
                        "x > 0".into(),
                    ))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new(
                    "output",
                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
                    true,
                )),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        if let ColumnarValue::Array(arr) = result {
            let la = arr.as_any().downcast_ref::<ListArray>().unwrap();
            let row0 = la.value(0);
            let r0 = row0.as_any().downcast_ref::<Int64Array>().unwrap();
            assert_eq!(r0.len(), 2);
            assert_eq!(r0.value(0), 2);
            assert_eq!(r0.value(1), 4);
        }
    }

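    // A hedged sketch of an `array_reduce` check, mirroring the two tests
    // above; the single-element init array, `acc + x` lambda, and expected
    // sum are illustrative assumptions, not taken from the original suite.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_array_reduce_sum() {
        let values = Int64Array::from(vec![1, 2, 3]);
        let offsets =
            arrow::buffer::OffsetBuffer::new(arrow::buffer::ScalarBuffer::from(vec![0i32, 3]));
        let list = ListArray::try_new(
            Arc::new(Field::new("item", DataType::Int64, true)),
            offsets,
            Arc::new(values),
            None,
        )
        .unwrap();

        let udf = ArrayReduce::new();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(list)),
                    // One init value per list row (a single row here).
                    ColumnarValue::Array(Arc::new(Int64Array::from(vec![0i64]))),
                    ColumnarValue::Scalar(datafusion_common::ScalarValue::Utf8(Some(
                        "acc + x".into(),
                    ))),
                ],
                number_rows: 0,
                arg_fields: vec![],
                return_field: Arc::new(Field::new("output", DataType::Int64, true)),
                config_options: Arc::new(ConfigOptions::default()),
            })
            .unwrap();

        if let ColumnarValue::Array(arr) = result {
            let r = arr.as_any().downcast_ref::<Int64Array>().unwrap();
            assert_eq!(r.value(0), 6); // 0 + 1 + 2 + 3
        } else {
            panic!("expected Array");
        }
    }
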
    // Registration should expose all five lambda UDFs by name.
    #[test]
    fn test_register_lambda_functions() {
        use datafusion::execution::FunctionRegistry;

        let ctx = create_session_context();
        register_lambda_functions(&ctx);
        assert!(ctx.udf("array_transform").is_ok());
        assert!(ctx.udf("array_filter").is_ok());
        assert!(ctx.udf("array_reduce").is_ok());
        assert!(ctx.udf("map_filter").is_ok());
        assert!(ctx.udf("map_transform_values").is_ok());
    }
}