1use std::any::Any;
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10
11use arrow::datatypes::DataType;
12use arrow_array::{Array, ArrayRef, LargeBinaryArray, StringArray};
13use arrow_schema::Field;
14use datafusion_common::{Result, ScalarValue};
15use datafusion_expr::function::AccumulatorArgs;
16use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility};
17
18use super::json_types;
19
20#[derive(Debug)]
29pub struct JsonAgg {
30 signature: Signature,
31}
32
33impl JsonAgg {
34 #[must_use]
36 pub fn new() -> Self {
37 Self {
38 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
39 }
40 }
41}
42
43impl Default for JsonAgg {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl PartialEq for JsonAgg {
50 fn eq(&self, _other: &Self) -> bool {
51 true
52 }
53}
54
55impl Eq for JsonAgg {}
56
57impl Hash for JsonAgg {
58 fn hash<H: Hasher>(&self, state: &mut H) {
59 "json_agg".hash(state);
60 }
61}
62
63impl AggregateUDFImpl for JsonAgg {
64 fn as_any(&self) -> &dyn Any {
65 self
66 }
67
68 fn name(&self) -> &'static str {
69 "json_agg"
70 }
71
72 fn signature(&self) -> &Signature {
73 &self.signature
74 }
75
76 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
77 Ok(DataType::LargeBinary)
78 }
79
80 fn state_fields(
81 &self,
82 _args: datafusion_expr::function::StateFieldsArgs,
83 ) -> Result<Vec<Arc<Field>>> {
84 Ok(vec![Arc::new(Field::new(
86 "json_agg_state",
87 DataType::LargeBinary,
88 true,
89 ))])
90 }
91
92 fn accumulator(&self, _args: AccumulatorArgs<'_>) -> Result<Box<dyn Accumulator>> {
93 Ok(Box::new(JsonAggAccumulator::new()))
94 }
95}
96
97#[derive(Debug)]
101struct JsonAggAccumulator {
102 values: Vec<serde_json::Value>,
103}
104
105impl JsonAggAccumulator {
106 fn new() -> Self {
107 Self { values: Vec::new() }
108 }
109}
110
111impl Accumulator for JsonAggAccumulator {
112 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
113 let arr = &values[0];
114 for i in 0..arr.len() {
115 if arr.is_null(i) {
116 self.values.push(serde_json::Value::Null);
117 } else {
118 self.values.push(array_value_to_json(arr, i));
119 }
120 }
121 Ok(())
122 }
123
124 fn evaluate(&mut self) -> Result<ScalarValue> {
125 let json_arr = serde_json::Value::Array(self.values.clone());
126 let bytes = json_types::encode_jsonb(&json_arr);
127 Ok(ScalarValue::LargeBinary(Some(bytes)))
128 }
129
130 fn size(&self) -> usize {
131 std::mem::size_of::<Self>()
132 + self.values.capacity() * std::mem::size_of::<serde_json::Value>()
133 }
134
135 fn state(&mut self) -> Result<Vec<ScalarValue>> {
136 let json_arr = serde_json::Value::Array(self.values.clone());
138 let bytes = json_types::encode_jsonb(&json_arr);
139 Ok(vec![ScalarValue::LargeBinary(Some(bytes))])
140 }
141
142 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
143 let arr = states[0]
144 .as_any()
145 .downcast_ref::<LargeBinaryArray>()
146 .ok_or_else(|| {
147 datafusion_common::DataFusionError::Internal(
148 "json_agg: merge state must be LargeBinary".into(),
149 )
150 })?;
151 for i in 0..arr.len() {
152 if !arr.is_null(i) {
153 let bytes = arr.value(i);
154 if let Some(json_str) = json_types::jsonb_to_text(bytes) {
156 if let Ok(serde_json::Value::Array(elems)) =
157 serde_json::from_str::<serde_json::Value>(&json_str)
158 {
159 self.values.extend(elems);
160 }
161 }
162 }
163 }
164 Ok(())
165 }
166}
167
168#[derive(Debug)]
177pub struct JsonObjectAgg {
178 signature: Signature,
179}
180
181impl JsonObjectAgg {
182 #[must_use]
184 pub fn new() -> Self {
185 Self {
186 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
187 }
188 }
189}
190
191impl Default for JsonObjectAgg {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197impl PartialEq for JsonObjectAgg {
198 fn eq(&self, _other: &Self) -> bool {
199 true
200 }
201}
202
203impl Eq for JsonObjectAgg {}
204
205impl Hash for JsonObjectAgg {
206 fn hash<H: Hasher>(&self, state: &mut H) {
207 "json_object_agg".hash(state);
208 }
209}
210
211impl AggregateUDFImpl for JsonObjectAgg {
212 fn as_any(&self) -> &dyn Any {
213 self
214 }
215
216 fn name(&self) -> &'static str {
217 "json_object_agg"
218 }
219
220 fn signature(&self) -> &Signature {
221 &self.signature
222 }
223
224 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
225 Ok(DataType::LargeBinary)
226 }
227
228 fn state_fields(
229 &self,
230 _args: datafusion_expr::function::StateFieldsArgs,
231 ) -> Result<Vec<Arc<Field>>> {
232 Ok(vec![Arc::new(Field::new(
233 "json_object_agg_state",
234 DataType::LargeBinary,
235 true,
236 ))])
237 }
238
239 fn accumulator(&self, _args: AccumulatorArgs<'_>) -> Result<Box<dyn Accumulator>> {
240 Ok(Box::new(JsonObjectAggAccumulator::new()))
241 }
242}
243
244#[derive(Debug)]
248struct JsonObjectAggAccumulator {
249 entries: serde_json::Map<String, serde_json::Value>,
250}
251
252impl JsonObjectAggAccumulator {
253 fn new() -> Self {
254 Self {
255 entries: serde_json::Map::new(),
256 }
257 }
258}
259
260impl Accumulator for JsonObjectAggAccumulator {
261 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
262 let key_arr = &values[0];
263 let val_arr = &values[1];
264
265 for i in 0..key_arr.len() {
266 if key_arr.is_null(i) {
267 continue; }
269 let key = array_value_to_string(key_arr, i)?;
270 let val = if val_arr.is_null(i) {
271 serde_json::Value::Null
272 } else {
273 array_value_to_json(val_arr, i)
274 };
275 self.entries.insert(key, val); }
277 Ok(())
278 }
279
280 fn evaluate(&mut self) -> Result<ScalarValue> {
281 let obj = serde_json::Value::Object(self.entries.clone());
282 let bytes = json_types::encode_jsonb(&obj);
283 Ok(ScalarValue::LargeBinary(Some(bytes)))
284 }
285
286 fn size(&self) -> usize {
287 std::mem::size_of::<Self>() + self.entries.len() * 64 }
289
290 fn state(&mut self) -> Result<Vec<ScalarValue>> {
291 let obj = serde_json::Value::Object(self.entries.clone());
292 let bytes = json_types::encode_jsonb(&obj);
293 Ok(vec![ScalarValue::LargeBinary(Some(bytes))])
294 }
295
296 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
297 let arr = states[0]
298 .as_any()
299 .downcast_ref::<LargeBinaryArray>()
300 .ok_or_else(|| {
301 datafusion_common::DataFusionError::Internal(
302 "json_object_agg: merge state must be LargeBinary".into(),
303 )
304 })?;
305 for i in 0..arr.len() {
306 if !arr.is_null(i) {
307 let bytes = arr.value(i);
308 if let Some(json_str) = json_types::jsonb_to_text(bytes) {
309 if let Ok(serde_json::Value::Object(map)) =
310 serde_json::from_str::<serde_json::Value>(&json_str)
311 {
312 for (k, v) in map {
313 self.entries.insert(k, v);
314 }
315 }
316 }
317 }
318 }
319 Ok(())
320 }
321}
322
323fn array_value_to_json(arr: &ArrayRef, row: usize) -> serde_json::Value {
327 if arr.is_null(row) {
328 return serde_json::Value::Null;
329 }
330 if let Some(a) = arr.as_any().downcast_ref::<StringArray>() {
331 return serde_json::Value::String(a.value(row).to_owned());
332 }
333 if let Some(a) = arr.as_any().downcast_ref::<arrow_array::Int64Array>() {
334 return serde_json::Value::Number(a.value(row).into());
335 }
336 if let Some(a) = arr.as_any().downcast_ref::<arrow_array::Int32Array>() {
337 return serde_json::Value::Number(i64::from(a.value(row)).into());
338 }
339 if let Some(a) = arr.as_any().downcast_ref::<arrow_array::Float64Array>() {
340 if let Some(n) = serde_json::Number::from_f64(a.value(row)) {
341 return serde_json::Value::Number(n);
342 }
343 return serde_json::Value::Null;
344 }
345 if let Some(a) = arr.as_any().downcast_ref::<arrow_array::BooleanArray>() {
346 return serde_json::Value::Bool(a.value(row));
347 }
348 let scalar = ScalarValue::try_from_array(arr, row).ok();
350 match scalar {
351 Some(s) => serde_json::Value::String(s.to_string()),
352 None => serde_json::Value::Null,
353 }
354}
355
356fn array_value_to_string(arr: &ArrayRef, row: usize) -> Result<String> {
358 if let Some(a) = arr.as_any().downcast_ref::<StringArray>() {
359 return Ok(a.value(row).to_owned());
360 }
361 let sv = ScalarValue::try_from_array(arr, row)?;
363 Ok(sv.to_string())
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;

    fn make_string_array(vals: &[&str]) -> StringArray {
        StringArray::from(vals.to_vec())
    }

    /// Unwraps the JSONB bytes from an accumulator's `evaluate`/`state` value.
    fn expect_jsonb(value: ScalarValue) -> Vec<u8> {
        match value {
            ScalarValue::LargeBinary(Some(bytes)) => bytes,
            other => panic!("Expected LargeBinary, got {other:?}"),
        }
    }

    /// Wraps an accumulator's single-column `state()` output in a one-row
    /// `LargeBinaryArray`, the shape `merge_batch` receives.
    fn state_to_array(state: &[ScalarValue]) -> ArrayRef {
        match &state[0] {
            ScalarValue::LargeBinary(Some(b)) => {
                Arc::new(LargeBinaryArray::from_iter_values(vec![b.as_slice()]))
            }
            other => panic!("expected LargeBinary state, got {other:?}"),
        }
    }

    #[test]
    fn test_json_agg_basic() {
        let mut acc = JsonAggAccumulator::new();
        let vals = Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef;
        acc.update_batch(&[vals]).unwrap();
        let bytes = expect_jsonb(acc.evaluate().unwrap());
        assert_eq!(json_types::jsonb_type_name(&bytes), Some("array"));
        let e0 = json_types::jsonb_array_get(&bytes, 0).unwrap();
        assert_eq!(json_types::jsonb_to_text(e0), Some("1".to_owned()));
        let e2 = json_types::jsonb_array_get(&bytes, 2).unwrap();
        assert_eq!(json_types::jsonb_to_text(e2), Some("3".to_owned()));
    }

    #[test]
    fn test_json_agg_strings() {
        let mut acc = JsonAggAccumulator::new();
        let vals = Arc::new(make_string_array(&["a", "b", "c"])) as ArrayRef;
        acc.update_batch(&[vals]).unwrap();
        let bytes = expect_jsonb(acc.evaluate().unwrap());
        let e0 = json_types::jsonb_array_get(&bytes, 0).unwrap();
        assert_eq!(json_types::jsonb_to_text(e0), Some("a".to_owned()));
    }

    #[test]
    fn test_json_agg_multiple_batches() {
        let mut acc = JsonAggAccumulator::new();
        let v1 = Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef;
        let v2 = Arc::new(Int64Array::from(vec![3])) as ArrayRef;
        acc.update_batch(&[v1]).unwrap();
        acc.update_batch(&[v2]).unwrap();
        let bytes = expect_jsonb(acc.evaluate().unwrap());
        let text = json_types::jsonb_to_text(&bytes).unwrap();
        assert_eq!(text, "[1,2,3]");
    }

    #[test]
    fn test_json_object_agg_basic() {
        let mut acc = JsonObjectAggAccumulator::new();
        let keys = Arc::new(make_string_array(&["a", "b", "c"])) as ArrayRef;
        let vals = Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef;
        acc.update_batch(&[keys, vals]).unwrap();
        let bytes = expect_jsonb(acc.evaluate().unwrap());
        assert_eq!(json_types::jsonb_type_name(&bytes), Some("object"));
        let a = json_types::jsonb_get_field(&bytes, "a").unwrap();
        assert_eq!(json_types::jsonb_to_text(a), Some("1".to_owned()));
        let c = json_types::jsonb_get_field(&bytes, "c").unwrap();
        assert_eq!(json_types::jsonb_to_text(c), Some("3".to_owned()));
    }

    #[test]
    fn test_json_object_agg_last_value_wins() {
        let mut acc = JsonObjectAggAccumulator::new();
        let keys = Arc::new(make_string_array(&["a", "a"])) as ArrayRef;
        let vals = Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef;
        acc.update_batch(&[keys, vals]).unwrap();
        let bytes = expect_jsonb(acc.evaluate().unwrap());
        let a = json_types::jsonb_get_field(&bytes, "a").unwrap();
        assert_eq!(json_types::jsonb_to_text(a), Some("2".to_owned()));
    }

    #[test]
    fn test_json_agg_state_merge() {
        let mut acc1 = JsonAggAccumulator::new();
        let v1 = Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef;
        acc1.update_batch(&[v1]).unwrap();
        let state = acc1.state().unwrap();

        let mut acc2 = JsonAggAccumulator::new();
        let v2 = Arc::new(Int64Array::from(vec![3])) as ArrayRef;
        acc2.update_batch(&[v2]).unwrap();
        acc2.merge_batch(&[state_to_array(&state)]).unwrap();

        let bytes = expect_jsonb(acc2.evaluate().unwrap());
        let text = json_types::jsonb_to_text(&bytes).unwrap();
        // Merged elements are appended after acc2's own rows.
        assert_eq!(text, "[3,1,2]");
    }

    #[test]
    fn test_json_agg_merge_skips_null_state_rows() {
        let mut acc1 = JsonAggAccumulator::new();
        let v1 = Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef;
        acc1.update_batch(&[v1]).unwrap();
        let state = expect_jsonb(acc1.state().unwrap().remove(0));

        let states: ArrayRef =
            Arc::new(LargeBinaryArray::from(vec![None, Some(state.as_slice())]));
        let mut acc2 = JsonAggAccumulator::new();
        acc2.merge_batch(&[states]).unwrap();

        let bytes = expect_jsonb(acc2.evaluate().unwrap());
        let text = json_types::jsonb_to_text(&bytes).unwrap();
        assert_eq!(text, "[1,2]");
    }

    #[test]
    fn test_json_object_agg_state_merge() {
        let mut acc1 = JsonObjectAggAccumulator::new();
        let k1 = Arc::new(make_string_array(&["a", "b"])) as ArrayRef;
        let v1 = Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef;
        acc1.update_batch(&[k1, v1]).unwrap();
        let state = acc1.state().unwrap();

        let mut acc2 = JsonObjectAggAccumulator::new();
        let k2 = Arc::new(make_string_array(&["b", "c"])) as ArrayRef;
        let v2 = Arc::new(Int64Array::from(vec![20, 3])) as ArrayRef;
        acc2.update_batch(&[k2, v2]).unwrap();
        acc2.merge_batch(&[state_to_array(&state)]).unwrap();

        let bytes = expect_jsonb(acc2.evaluate().unwrap());
        let a = json_types::jsonb_get_field(&bytes, "a").unwrap();
        assert_eq!(json_types::jsonb_to_text(a), Some("1".to_owned()));
        // The merged state is applied after acc2's own rows, so its `b` wins.
        let b = json_types::jsonb_get_field(&bytes, "b").unwrap();
        assert_eq!(json_types::jsonb_to_text(b), Some("2".to_owned()));
        let c = json_types::jsonb_get_field(&bytes, "c").unwrap();
        assert_eq!(json_types::jsonb_to_text(c), Some("3".to_owned()));
    }

    #[test]
    fn test_udaf_registration() {
        let json_agg = datafusion_expr::AggregateUDF::new_from_impl(JsonAgg::new());
        assert_eq!(json_agg.name(), "json_agg");

        let json_obj_agg = datafusion_expr::AggregateUDF::new_from_impl(JsonObjectAgg::new());
        assert_eq!(json_obj_agg.name(), "json_object_agg");
    }
}