1use std::sync::Arc;
15
16use arrow::datatypes::{DataType, Field, Schema};
17use arrow::record_batch::RecordBatch;
18use arrow_array::{Int64Array, LargeBinaryArray, StringArray};
19use datafusion::catalog::{TableFunctionImpl, TableProvider};
20use datafusion::datasource::MemTable;
21use datafusion_common::{plan_err, Result, ScalarValue};
22use datafusion_expr::Expr;
23
24use super::json_types;
25
/// Builds the values of the 1-based `ordinality` column: `[1, 2, ..., n]`.
///
/// A position that does not fit in `i64` saturates to `i64::MAX`.
fn ordinality_vec(n: usize) -> Vec<i64> {
    let mut positions = Vec::with_capacity(n);
    for position in 1..=n {
        positions.push(i64::try_from(position).unwrap_or(i64::MAX));
    }
    positions
}
34
/// Leading type-tag bytes of the JSONB encoding that the decoders in this
/// file dispatch on.
mod tags {
    /// First byte of an encoded JSONB array.
    pub const ARRAY: u8 = 0x06;

    /// First byte of an encoded JSONB object.
    pub const OBJECT: u8 = 0x07;
}
40
41fn extract_jsonb_literal(expr: &Expr) -> Result<Option<Vec<u8>>> {
43 match expr {
44 Expr::Literal(ScalarValue::LargeBinary(bytes), _) => Ok(bytes.clone()),
45 Expr::Literal(ScalarValue::Null | ScalarValue::Utf8(None), _) => Ok(None),
46 Expr::Literal(ScalarValue::Utf8(Some(s)), _) => {
48 let json_val: serde_json::Value = serde_json::from_str(s).map_err(|e| {
49 datafusion_common::DataFusionError::Plan(format!("invalid JSON literal: {e}"))
50 })?;
51 Ok(Some(json_types::encode_jsonb(&json_val)))
52 }
53 other => plan_err!(
54 "JSON TVF argument must be a JSONB (LargeBinary) or JSON string literal, got {other:?}"
55 ),
56 }
57}
58
59fn jsonb_array_elements_iter(data: &[u8]) -> Option<Vec<Vec<u8>>> {
63 if data.is_empty() || data[0] != tags::ARRAY {
64 return None;
65 }
66 if data.len() < 5 {
67 return None;
68 }
69 let count = u32::from_le_bytes([data[1], data[2], data[3], data[4]]) as usize;
70 let offsets_start = 5;
72 let data_start = offsets_start + count * 4;
73 if data.len() < data_start {
74 return None;
75 }
76
77 let mut elements = Vec::with_capacity(count);
78 for i in 0..count {
79 let off_pos = offsets_start + i * 4;
80 let offset = u32::from_le_bytes([
81 data[off_pos],
82 data[off_pos + 1],
83 data[off_pos + 2],
84 data[off_pos + 3],
85 ]) as usize;
86
87 let abs_start = data_start + offset;
88 let abs_end = if i + 1 < count {
90 let next_pos = offsets_start + (i + 1) * 4;
91 data_start
92 + u32::from_le_bytes([
93 data[next_pos],
94 data[next_pos + 1],
95 data[next_pos + 2],
96 data[next_pos + 3],
97 ]) as usize
98 } else {
99 data.len()
100 };
101
102 if abs_start <= abs_end && abs_end <= data.len() {
103 elements.push(data[abs_start..abs_end].to_vec());
104 }
105 }
106 Some(elements)
107}
108
109fn jsonb_object_entries(data: &[u8]) -> Option<Vec<(String, Vec<u8>)>> {
113 if data.is_empty() || data[0] != tags::OBJECT {
114 return None;
115 }
116 if data.len() < 5 {
117 return None;
118 }
119 let count = u32::from_le_bytes([data[1], data[2], data[3], data[4]]) as usize;
120 let offsets_start = 5;
122 let data_start = offsets_start + count * 8;
123 if data.len() < data_start {
124 return None;
125 }
126
127 let mut entries = Vec::with_capacity(count);
128 for i in 0..count {
129 let base = offsets_start + i * 8;
130 let key_off =
131 u32::from_le_bytes([data[base], data[base + 1], data[base + 2], data[base + 3]])
132 as usize;
133 let val_off = u32::from_le_bytes([
134 data[base + 4],
135 data[base + 5],
136 data[base + 6],
137 data[base + 7],
138 ]) as usize;
139
140 let key_abs = data_start + key_off;
142 if key_abs + 2 > data.len() {
143 continue;
144 }
145 let key_len = u16::from_le_bytes([data[key_abs], data[key_abs + 1]]) as usize;
146 let key_start = key_abs + 2;
147 let key_end = key_start + key_len;
148 if key_end > data.len() {
149 continue;
150 }
151 let key = String::from_utf8_lossy(&data[key_start..key_end]).to_string();
152
153 let val_abs = data_start + val_off;
155 let val_end = if i + 1 < count {
156 let next_base = offsets_start + (i + 1) * 8;
157 data_start
158 + u32::from_le_bytes([
159 data[next_base],
160 data[next_base + 1],
161 data[next_base + 2],
162 data[next_base + 3],
163 ]) as usize
164 } else {
165 data.len()
166 };
167
168 if val_abs <= val_end && val_end <= data.len() {
169 entries.push((key, data[val_abs..val_end].to_vec()));
170 }
171 }
172 Some(entries)
173}
174
175#[derive(Debug)]
181pub struct JsonbArrayElementsTvf;
182
183impl TableFunctionImpl for JsonbArrayElementsTvf {
184 fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
185 if args.len() != 1 {
186 return plan_err!("jsonb_array_elements requires exactly 1 argument");
187 }
188 let schema = Arc::new(Schema::new(vec![
189 Field::new("value", DataType::LargeBinary, true),
190 Field::new("ordinality", DataType::Int64, false),
191 ]));
192
193 let bytes = extract_jsonb_literal(&args[0])?;
194 let elements = bytes.as_deref().and_then(jsonb_array_elements_iter);
195
196 match elements {
197 Some(elems) if !elems.is_empty() => {
198 let values: Vec<Option<&[u8]>> = elems.iter().map(|e| Some(e.as_slice())).collect();
199 let ordinality = ordinality_vec(elems.len());
200 let batch = RecordBatch::try_new(
201 Arc::clone(&schema),
202 vec![
203 Arc::new(LargeBinaryArray::from(values)),
204 Arc::new(Int64Array::from(ordinality)),
205 ],
206 )?;
207 Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
208 }
209 _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
210 }
211 }
212}
213
214#[derive(Debug)]
220pub struct JsonbArrayElementsTextTvf;
221
222impl TableFunctionImpl for JsonbArrayElementsTextTvf {
223 fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
224 if args.len() != 1 {
225 return plan_err!("jsonb_array_elements_text requires exactly 1 argument");
226 }
227 let schema = Arc::new(Schema::new(vec![
228 Field::new("value", DataType::Utf8, true),
229 Field::new("ordinality", DataType::Int64, false),
230 ]));
231
232 let bytes = extract_jsonb_literal(&args[0])?;
233 let elements = bytes.as_deref().and_then(jsonb_array_elements_iter);
234
235 match elements {
236 Some(elems) if !elems.is_empty() => {
237 let texts: Vec<Option<String>> =
238 elems.iter().map(|e| json_types::jsonb_to_text(e)).collect();
239 let ordinality = ordinality_vec(elems.len());
240 let batch = RecordBatch::try_new(
241 Arc::clone(&schema),
242 vec![
243 Arc::new(StringArray::from(texts)),
244 Arc::new(Int64Array::from(ordinality)),
245 ],
246 )?;
247 Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
248 }
249 _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
250 }
251 }
252}
253
254#[derive(Debug)]
260pub struct JsonbEachTvf;
261
262impl TableFunctionImpl for JsonbEachTvf {
263 fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
264 if args.len() != 1 {
265 return plan_err!("jsonb_each requires exactly 1 argument");
266 }
267 let schema = Arc::new(Schema::new(vec![
268 Field::new("key", DataType::Utf8, false),
269 Field::new("value", DataType::LargeBinary, true),
270 Field::new("ordinality", DataType::Int64, false),
271 ]));
272
273 let bytes = extract_jsonb_literal(&args[0])?;
274 let entries = bytes.as_deref().and_then(jsonb_object_entries);
275
276 match entries {
277 Some(kvs) if !kvs.is_empty() => {
278 let keys: Vec<&str> = kvs.iter().map(|(k, _)| k.as_str()).collect();
279 let values: Vec<Option<&[u8]>> =
280 kvs.iter().map(|(_, v)| Some(v.as_slice())).collect();
281 let ordinality = ordinality_vec(kvs.len());
282 let batch = RecordBatch::try_new(
283 Arc::clone(&schema),
284 vec![
285 Arc::new(StringArray::from(keys)),
286 Arc::new(LargeBinaryArray::from(values)),
287 Arc::new(Int64Array::from(ordinality)),
288 ],
289 )?;
290 Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
291 }
292 _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
293 }
294 }
295}
296
297#[derive(Debug)]
303pub struct JsonbEachTextTvf;
304
305impl TableFunctionImpl for JsonbEachTextTvf {
306 fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
307 if args.len() != 1 {
308 return plan_err!("jsonb_each_text requires exactly 1 argument");
309 }
310 let schema = Arc::new(Schema::new(vec![
311 Field::new("key", DataType::Utf8, false),
312 Field::new("value", DataType::Utf8, true),
313 Field::new("ordinality", DataType::Int64, false),
314 ]));
315
316 let bytes = extract_jsonb_literal(&args[0])?;
317 let entries = bytes.as_deref().and_then(jsonb_object_entries);
318
319 match entries {
320 Some(kvs) if !kvs.is_empty() => {
321 let keys: Vec<&str> = kvs.iter().map(|(k, _)| k.as_str()).collect();
322 let texts: Vec<Option<String>> = kvs
323 .iter()
324 .map(|(_, v)| json_types::jsonb_to_text(v))
325 .collect();
326 let ordinality = ordinality_vec(kvs.len());
327 let batch = RecordBatch::try_new(
328 Arc::clone(&schema),
329 vec![
330 Arc::new(StringArray::from(keys)),
331 Arc::new(StringArray::from(texts)),
332 Arc::new(Int64Array::from(ordinality)),
333 ],
334 )?;
335 Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
336 }
337 _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
338 }
339 }
340}
341
342#[derive(Debug)]
348pub struct JsonbObjectKeysTvf;
349
350impl TableFunctionImpl for JsonbObjectKeysTvf {
351 fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
352 if args.len() != 1 {
353 return plan_err!("jsonb_object_keys requires exactly 1 argument");
354 }
355 let schema = Arc::new(Schema::new(vec![
356 Field::new("key", DataType::Utf8, false),
357 Field::new("ordinality", DataType::Int64, false),
358 ]));
359
360 let bytes = extract_jsonb_literal(&args[0])?;
361 let entries = bytes.as_deref().and_then(jsonb_object_entries);
362
363 match entries {
364 Some(kvs) if !kvs.is_empty() => {
365 let keys: Vec<&str> = kvs.iter().map(|(k, _)| k.as_str()).collect();
366 let ordinality = ordinality_vec(kvs.len());
367 let batch = RecordBatch::try_new(
368 Arc::clone(&schema),
369 vec![
370 Arc::new(StringArray::from(keys)),
371 Arc::new(Int64Array::from(ordinality)),
372 ],
373 )?;
374 Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
375 }
376 _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
377 }
378 }
379}
380
381pub fn register_json_table_functions(ctx: &datafusion::prelude::SessionContext) {
385 ctx.register_udtf("jsonb_array_elements", Arc::new(JsonbArrayElementsTvf));
386 ctx.register_udtf(
387 "jsonb_array_elements_text",
388 Arc::new(JsonbArrayElementsTextTvf),
389 );
390 ctx.register_udtf("jsonb_each", Arc::new(JsonbEachTvf));
391 ctx.register_udtf("jsonb_each_text", Arc::new(JsonbEachTextTvf));
392 ctx.register_udtf("jsonb_object_keys", Arc::new(JsonbObjectKeysTvf));
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;

    /// Encodes `json_str` to JSONB and wraps it in a `LargeBinary` literal.
    fn make_jsonb_expr(json_str: &str) -> Expr {
        let val: serde_json::Value = serde_json::from_str(json_str).unwrap();
        let bytes = json_types::encode_jsonb(&val);
        Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), None)
    }

    /// Executes `sql` on a fresh session with the JSON TVFs registered.
    async fn run_sql(sql: &str) -> Vec<RecordBatch> {
        let ctx = create_session_context();
        register_json_table_functions(&ctx);
        ctx.sql(sql).await.unwrap().collect().await.unwrap()
    }

    /// Total number of rows across `batches`.
    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(RecordBatch::num_rows).sum()
    }

    #[test]
    fn test_array_elements_basic() {
        let tvf = JsonbArrayElementsTvf;
        let provider = tvf.call(&[make_jsonb_expr("[1, 2, 3]")]).unwrap();
        let schema = provider.schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "value");
        assert_eq!(schema.field(1).name(), "ordinality");
    }

    #[tokio::test]
    async fn test_array_elements_via_sql() {
        let batches =
            run_sql("SELECT value, ordinality FROM jsonb_array_elements('[10, 20, 30]')").await;
        assert_eq!(total_rows(&batches), 3);

        let ord = batches[0]
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ord.value(0), 1);
        assert_eq!(ord.value(1), 2);
        assert_eq!(ord.value(2), 3);
    }

    #[tokio::test]
    async fn test_array_elements_empty() {
        let batches = run_sql("SELECT value FROM jsonb_array_elements('[]')").await;
        assert_eq!(total_rows(&batches), 0);
    }

    #[tokio::test]
    async fn test_array_elements_not_array() {
        // A non-array argument yields an empty table rather than an error.
        let batches = run_sql("SELECT value FROM jsonb_array_elements('{\"a\":1}')").await;
        assert_eq!(total_rows(&batches), 0);
    }

    #[tokio::test]
    async fn test_array_elements_text_strings() {
        let batches =
            run_sql("SELECT value FROM jsonb_array_elements_text('[\"a\", \"b\", \"c\"]')").await;
        assert_eq!(total_rows(&batches), 3);

        let vals = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(vals.value(0), "a");
        assert_eq!(vals.value(1), "b");
        assert_eq!(vals.value(2), "c");
    }

    #[tokio::test]
    async fn test_array_elements_text_mixed() {
        let batches =
            run_sql("SELECT value FROM jsonb_array_elements_text('[1, \"hello\", true]')").await;
        assert_eq!(total_rows(&batches), 3);
    }

    #[tokio::test]
    async fn test_each_basic() {
        let batches = run_sql("SELECT key, ordinality FROM jsonb_each('{\"a\":1,\"b\":2}')").await;
        assert_eq!(total_rows(&batches), 2);

        let keys = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(keys.value(0), "a");
        assert_eq!(keys.value(1), "b");

        let ord = batches[0]
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ord.value(0), 1);
        assert_eq!(ord.value(1), 2);
    }

    #[tokio::test]
    async fn test_each_empty() {
        let batches = run_sql("SELECT key FROM jsonb_each('{}')").await;
        assert_eq!(total_rows(&batches), 0);
    }

    #[tokio::test]
    async fn test_each_text_basic() {
        let batches =
            run_sql("SELECT key, value FROM jsonb_each_text('{\"x\":\"hello\",\"y\":42}')").await;
        assert_eq!(total_rows(&batches), 2);
    }

    #[tokio::test]
    async fn test_object_keys_basic() {
        let batches =
            run_sql("SELECT key FROM jsonb_object_keys('{\"a\":1,\"b\":2,\"c\":3}')").await;
        assert_eq!(total_rows(&batches), 3);

        let keys = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(keys.value(0), "a");
        assert_eq!(keys.value(1), "b");
        assert_eq!(keys.value(2), "c");
    }

    #[tokio::test]
    async fn test_object_keys_empty() {
        let batches = run_sql("SELECT key FROM jsonb_object_keys('{}')").await;
        assert_eq!(total_rows(&batches), 0);
    }

    #[test]
    fn test_decoders_reject_truncated_input() {
        // A huge declared count with no offset table must yield None, not panic.
        assert!(jsonb_array_elements_iter(&[tags::ARRAY, 0xff, 0xff, 0xff, 0xff]).is_none());
        assert!(jsonb_object_entries(&[tags::OBJECT, 0xff, 0xff, 0xff, 0xff]).is_none());
        // Input shorter than tag + count header.
        assert!(jsonb_array_elements_iter(&[tags::ARRAY, 1, 0]).is_none());
        assert!(jsonb_object_entries(&[tags::OBJECT]).is_none());
        assert!(jsonb_array_elements_iter(&[]).is_none());
        assert!(jsonb_object_entries(&[]).is_none());
    }

    #[test]
    fn test_decoders_match_encoder() {
        let arr = json_types::encode_jsonb(&serde_json::json!([1, "two", [3]]));
        assert_eq!(jsonb_array_elements_iter(&arr).map(|e| e.len()), Some(3));
        assert!(jsonb_object_entries(&arr).is_none());

        let obj = json_types::encode_jsonb(&serde_json::json!({"k": {"x": 1}}));
        let entries = jsonb_object_entries(&obj).unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].0, "k");
        assert!(jsonb_array_elements_iter(&obj).is_none());
    }

    #[test]
    fn test_registration() {
        let ctx = create_session_context();
        register_json_table_functions(&ctx);
        assert!(ctx.table_function("jsonb_array_elements").is_ok());
        assert!(ctx.table_function("jsonb_array_elements_text").is_ok());
        assert!(ctx.table_function("jsonb_each").is_ok());
        assert!(ctx.table_function("jsonb_each_text").is_ok());
        assert!(ctx.table_function("jsonb_object_keys").is_ok());
    }
}