1use arrow_array::builder::{
8 ArrayBuilder, BinaryBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int16Builder,
9 Int32Builder, Int64Builder, Int8Builder, StringBuilder, TimestampMicrosecondBuilder,
10 UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
11};
12use arrow_array::RecordBatch;
13use arrow_schema::{ArrowError, DataType, SchemaRef, TimeUnit};
14
15use super::row::{EventRow, FieldType, RowError, RowSchema};
16
17#[derive(Debug, thiserror::Error)]
19pub enum BridgeError {
20 #[error("bridge is full (capacity: {0})")]
22 Full(usize),
23 #[error("row schema error: {0}")]
25 Schema(#[from] RowError),
26 #[error("arrow error: {0}")]
28 Arrow(#[from] ArrowError),
29}
30
31pub struct RowBatchBridge {
37 schema: SchemaRef,
38 row_schema: RowSchema,
39 builders: Vec<Box<dyn ArrayBuilder>>,
40 row_count: usize,
41 capacity: usize,
42}
43
44impl RowBatchBridge {
45 pub fn new(schema: SchemaRef, capacity: usize) -> Result<Self, BridgeError> {
51 let row_schema = RowSchema::from_arrow(&schema)?;
52 let builders = create_builders(&schema, capacity);
53 Ok(Self {
54 schema,
55 row_schema,
56 builders,
57 row_count: 0,
58 capacity,
59 })
60 }
61
62 pub fn append_row(&mut self, row: &EventRow) -> Result<(), BridgeError> {
68 if self.row_count >= self.capacity {
69 return Err(BridgeError::Full(self.capacity));
70 }
71 for (i, layout) in self.row_schema.fields().iter().enumerate() {
72 append_field(&mut self.builders[i], layout.field_type, row, i);
73 }
74 self.row_count += 1;
75 Ok(())
76 }
77
78 #[must_use]
80 pub fn is_full(&self) -> bool {
81 self.row_count >= self.capacity
82 }
83
84 pub fn flush(&mut self) -> RecordBatch {
91 let arrays: Vec<_> = self.builders.iter_mut().map(ArrayBuilder::finish).collect();
92 let batch = RecordBatch::try_new(self.schema.clone(), arrays)
93 .expect("RowBatchBridge: schema/array mismatch in flush");
94 self.builders = create_builders(&self.schema, self.capacity);
95 self.row_count = 0;
96 batch
97 }
98
99 #[must_use]
101 pub fn row_count(&self) -> usize {
102 self.row_count
103 }
104
105 #[must_use]
107 pub fn capacity(&self) -> usize {
108 self.capacity
109 }
110}
111
112fn create_builders(schema: &SchemaRef, capacity: usize) -> Vec<Box<dyn ArrayBuilder>> {
114 schema
115 .fields()
116 .iter()
117 .map(|f| create_builder(f.data_type(), capacity))
118 .collect()
119}
120
121fn create_builder(dt: &DataType, capacity: usize) -> Box<dyn ArrayBuilder> {
122 match dt {
123 DataType::Boolean => Box::new(BooleanBuilder::with_capacity(capacity)),
124 DataType::Int8 => Box::new(Int8Builder::with_capacity(capacity)),
125 DataType::Int16 => Box::new(Int16Builder::with_capacity(capacity)),
126 DataType::Int32 => Box::new(Int32Builder::with_capacity(capacity)),
127 DataType::Int64 => Box::new(Int64Builder::with_capacity(capacity)),
128 DataType::UInt8 => Box::new(UInt8Builder::with_capacity(capacity)),
129 DataType::UInt16 => Box::new(UInt16Builder::with_capacity(capacity)),
130 DataType::UInt32 => Box::new(UInt32Builder::with_capacity(capacity)),
131 DataType::UInt64 => Box::new(UInt64Builder::with_capacity(capacity)),
132 DataType::Float32 => Box::new(Float32Builder::with_capacity(capacity)),
133 DataType::Float64 => Box::new(Float64Builder::with_capacity(capacity)),
134 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
135 let builder = TimestampMicrosecondBuilder::with_capacity(capacity)
136 .with_data_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone()));
137 Box::new(builder)
138 }
139 DataType::Utf8 => Box::new(StringBuilder::with_capacity(capacity, capacity * 32)),
140 DataType::Binary => Box::new(BinaryBuilder::with_capacity(capacity, capacity * 32)),
141 other => unreachable!(
142 "unsupported data type in RowBatchBridge: {other} (should be caught by RowSchema::from_arrow)"
143 ),
144 }
145}
146
147#[allow(clippy::too_many_lines)]
149fn append_field(
150 builder: &mut Box<dyn ArrayBuilder>,
151 field_type: FieldType,
152 row: &EventRow,
153 field_idx: usize,
154) {
155 let is_null = row.is_null(field_idx);
156 match field_type {
157 FieldType::Bool => {
158 let b = builder
159 .as_any_mut()
160 .downcast_mut::<BooleanBuilder>()
161 .unwrap();
162 if is_null {
163 b.append_null();
164 } else {
165 b.append_value(row.get_bool(field_idx));
166 }
167 }
168 FieldType::Int8 => {
169 let b = builder.as_any_mut().downcast_mut::<Int8Builder>().unwrap();
170 if is_null {
171 b.append_null();
172 } else {
173 b.append_value(row.get_i8(field_idx));
174 }
175 }
176 FieldType::Int16 => {
177 let b = builder.as_any_mut().downcast_mut::<Int16Builder>().unwrap();
178 if is_null {
179 b.append_null();
180 } else {
181 b.append_value(row.get_i16(field_idx));
182 }
183 }
184 FieldType::Int32 => {
185 let b = builder.as_any_mut().downcast_mut::<Int32Builder>().unwrap();
186 if is_null {
187 b.append_null();
188 } else {
189 b.append_value(row.get_i32(field_idx));
190 }
191 }
192 FieldType::Int64 => {
193 let b = builder.as_any_mut().downcast_mut::<Int64Builder>().unwrap();
194 if is_null {
195 b.append_null();
196 } else {
197 b.append_value(row.get_i64(field_idx));
198 }
199 }
200 FieldType::UInt8 => {
201 let b = builder.as_any_mut().downcast_mut::<UInt8Builder>().unwrap();
202 if is_null {
203 b.append_null();
204 } else {
205 b.append_value(row.get_u8(field_idx));
206 }
207 }
208 FieldType::UInt16 => {
209 let b = builder
210 .as_any_mut()
211 .downcast_mut::<UInt16Builder>()
212 .unwrap();
213 if is_null {
214 b.append_null();
215 } else {
216 b.append_value(row.get_u16(field_idx));
217 }
218 }
219 FieldType::UInt32 => {
220 let b = builder
221 .as_any_mut()
222 .downcast_mut::<UInt32Builder>()
223 .unwrap();
224 if is_null {
225 b.append_null();
226 } else {
227 b.append_value(row.get_u32(field_idx));
228 }
229 }
230 FieldType::UInt64 => {
231 let b = builder
232 .as_any_mut()
233 .downcast_mut::<UInt64Builder>()
234 .unwrap();
235 if is_null {
236 b.append_null();
237 } else {
238 b.append_value(row.get_u64(field_idx));
239 }
240 }
241 FieldType::Float32 => {
242 let b = builder
243 .as_any_mut()
244 .downcast_mut::<Float32Builder>()
245 .unwrap();
246 if is_null {
247 b.append_null();
248 } else {
249 b.append_value(row.get_f32(field_idx));
250 }
251 }
252 FieldType::Float64 => {
253 let b = builder
254 .as_any_mut()
255 .downcast_mut::<Float64Builder>()
256 .unwrap();
257 if is_null {
258 b.append_null();
259 } else {
260 b.append_value(row.get_f64(field_idx));
261 }
262 }
263 FieldType::TimestampMicros => {
264 let b = builder
265 .as_any_mut()
266 .downcast_mut::<TimestampMicrosecondBuilder>()
267 .unwrap();
268 if is_null {
269 b.append_null();
270 } else {
271 b.append_value(row.get_i64(field_idx));
272 }
273 }
274 FieldType::Utf8 => {
275 let b = builder
276 .as_any_mut()
277 .downcast_mut::<StringBuilder>()
278 .unwrap();
279 if is_null {
280 b.append_null();
281 } else {
282 b.append_value(row.get_str(field_idx));
283 }
284 }
285 FieldType::Binary => {
286 let b = builder
287 .as_any_mut()
288 .downcast_mut::<BinaryBuilder>()
289 .unwrap();
290 if is_null {
291 b.append_null();
292 } else {
293 b.append_value(row.get_bytes(field_idx));
294 }
295 }
296 }
297}
298
/// Unit tests for `RowBatchBridge`: appending rows, capacity handling, null
/// propagation and flush/reset behavior.
#[cfg(test)]
// `approx_constant`: the tests use 3.14 as an arbitrary float, not as pi.
// `cast_possible_wrap`: small `usize` loop indices are compared against `i64` values.
#[allow(clippy::approx_constant, clippy::cast_possible_wrap)]
mod tests {
    use super::*;
    use crate::compiler::row::MutableEventRow;
    use arrow_array::{
        Array, BooleanArray, Float64Array, Int64Array, StringArray, TimestampMicrosecondArray,
        UInt32Array,
    };
    use arrow_schema::{DataType, Field, Schema, TimeUnit};
    use bumpalo::Bump;
    use std::sync::Arc;

    /// Builds a schema from `(name, data type, nullable)` triples.
    fn make_schema(fields: Vec<(&str, DataType, bool)>) -> SchemaRef {
        Arc::new(Schema::new(
            fields
                .into_iter()
                .map(|(name, dt, nullable)| Field::new(name, dt, nullable))
                .collect::<Vec<_>>(),
        ))
    }

    #[test]
    fn bridge_single_row() {
        let schema = make_schema(vec![
            ("id", DataType::Int64, false),
            ("val", DataType::Float64, true),
        ]);
        let row_schema = RowSchema::from_arrow(&schema).unwrap();
        let mut bridge = RowBatchBridge::new(schema, 16).unwrap();

        let arena = Bump::new();
        let mut row = MutableEventRow::new_in(&arena, &row_schema, 0);
        row.set_i64(0, 100);
        row.set_f64(1, 3.14);
        let row = row.freeze();

        bridge.append_row(&row).unwrap();
        assert_eq!(bridge.row_count(), 1);

        let batch = bridge.flush();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 2);

        let col0 = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col0.value(0), 100);

        let col1 = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((col1.value(0) - 3.14).abs() < f64::EPSILON);
    }

    #[test]
    fn bridge_batch_accumulation() {
        let schema = make_schema(vec![("x", DataType::Int64, false)]);
        let row_schema = RowSchema::from_arrow(&schema).unwrap();
        let mut bridge = RowBatchBridge::new(schema, 100).unwrap();

        let arena = Bump::new();
        for i in 0..10 {
            let mut row = MutableEventRow::new_in(&arena, &row_schema, 0);
            row.set_i64(0, i);
            let row = row.freeze();
            bridge.append_row(&row).unwrap();
        }
        assert_eq!(bridge.row_count(), 10);
        assert!(!bridge.is_full());

        let batch = bridge.flush();
        assert_eq!(batch.num_rows(), 10);

        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        for i in 0..10 {
            assert_eq!(col.value(i), i as i64);
        }

        // Flushing empties the bridge.
        assert_eq!(bridge.row_count(), 0);
    }

    #[test]
    fn bridge_capacity_full_error() {
        let schema = make_schema(vec![("x", DataType::Int64, false)]);
        let row_schema = RowSchema::from_arrow(&schema).unwrap();
        let mut bridge = RowBatchBridge::new(schema, 2).unwrap();
        assert_eq!(bridge.capacity(), 2);

        let arena = Bump::new();
        for i in 0..2 {
            let mut row = MutableEventRow::new_in(&arena, &row_schema, 0);
            row.set_i64(0, i);
            bridge.append_row(&row.freeze()).unwrap();
        }
        assert!(bridge.is_full());

        // One row past capacity must be rejected without being counted.
        let mut row = MutableEventRow::new_in(&arena, &row_schema, 0);
        row.set_i64(0, 999);
        let err = bridge.append_row(&row.freeze()).unwrap_err();
        assert!(matches!(err, BridgeError::Full(2)));
        assert_eq!(bridge.row_count(), 2);
    }

    #[test]
    fn bridge_mixed_types() {
        let schema = make_schema(vec![
            ("flag", DataType::Boolean, false),
            ("count", DataType::UInt32, false),
            ("name", DataType::Utf8, true),
            (
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
        ]);
        let row_schema = RowSchema::from_arrow(&schema).unwrap();
        let mut bridge = RowBatchBridge::new(schema, 16).unwrap();

        let arena = Bump::new();
        let mut row = MutableEventRow::new_in(&arena, &row_schema, 64);
        row.set_bool(0, true);
        row.set_u32(1, 42);
        row.set_str(2, "test");
        row.set_i64(3, 1_000_000);
        bridge.append_row(&row.freeze()).unwrap();

        let batch = bridge.flush();
        assert_eq!(batch.num_rows(), 1);

        let bools = batch
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(bools.value(0));

        let uints = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        assert_eq!(uints.value(0), 42);

        let strs = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(strs.value(0), "test");

        // The timestamp column is written above, so verify it round-trips too.
        let timestamps = batch
            .column(3)
            .as_any()
            .downcast_ref::<TimestampMicrosecondArray>()
            .unwrap();
        assert_eq!(timestamps.value(0), 1_000_000);
    }

    #[test]
    fn bridge_null_propagation() {
        let schema = make_schema(vec![
            ("a", DataType::Int64, true),
            ("b", DataType::Utf8, true),
        ]);
        let row_schema = RowSchema::from_arrow(&schema).unwrap();
        let mut bridge = RowBatchBridge::new(schema, 16).unwrap();

        let arena = Bump::new();
        let mut row = MutableEventRow::new_in(&arena, &row_schema, 64);
        row.set_null(0, true);
        row.set_str(1, "hello");
        bridge.append_row(&row.freeze()).unwrap();

        let mut row2 = MutableEventRow::new_in(&arena, &row_schema, 64);
        row2.set_i64(0, 99);
        row2.set_null(1, true);
        bridge.append_row(&row2.freeze()).unwrap();

        let batch = bridge.flush();
        assert_eq!(batch.num_rows(), 2);

        let ints = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert!(ints.is_null(0));
        assert_eq!(ints.value(1), 99);

        let strs = batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(strs.value(0), "hello");
        assert!(strs.is_null(1));
    }

    #[test]
    fn bridge_flush_resets() {
        let schema = make_schema(vec![("x", DataType::Int64, false)]);
        let row_schema = RowSchema::from_arrow(&schema).unwrap();
        let mut bridge = RowBatchBridge::new(schema, 4).unwrap();

        let arena = Bump::new();
        let mut row = MutableEventRow::new_in(&arena, &row_schema, 0);
        row.set_i64(0, 1);
        bridge.append_row(&row.freeze()).unwrap();

        let batch1 = bridge.flush();
        assert_eq!(batch1.num_rows(), 1);
        assert_eq!(bridge.row_count(), 0);
        assert!(!bridge.is_full());

        // The bridge must be reusable after a flush.
        let mut row = MutableEventRow::new_in(&arena, &row_schema, 0);
        row.set_i64(0, 2);
        bridge.append_row(&row.freeze()).unwrap();

        let batch2 = bridge.flush();
        assert_eq!(batch2.num_rows(), 1);
        let col = batch2
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col.value(0), 2);
    }

    #[test]
    fn bridge_empty_flush() {
        let schema = make_schema(vec![("x", DataType::Int64, false)]);
        let mut bridge = RowBatchBridge::new(schema, 4).unwrap();

        let batch = bridge.flush();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 1);
    }
}