laminar_connectors/files/
arrow_ipc_codec.rs1use std::io::Cursor;
4
5use arrow_array::RecordBatch;
6use arrow_schema::SchemaRef;
7
8use crate::schema::error::{SchemaError, SchemaResult};
9use crate::schema::traits::{FormatDecoder, FormatEncoder};
10use crate::schema::types::RawRecord;
11
12pub struct ArrowIpcDecoder {
18 schema: SchemaRef,
19}
20
21impl ArrowIpcDecoder {
22 #[must_use]
24 pub fn new(schema: SchemaRef) -> Self {
25 Self { schema }
26 }
27}
28
29impl std::fmt::Debug for ArrowIpcDecoder {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("ArrowIpcDecoder")
32 .field("schema", &self.schema)
33 .finish()
34 }
35}
36
37impl FormatDecoder for ArrowIpcDecoder {
38 fn output_schema(&self) -> SchemaRef {
39 self.schema.clone()
40 }
41
42 fn decode_batch(&self, records: &[RawRecord]) -> SchemaResult<RecordBatch> {
43 if records.is_empty() {
44 return Ok(RecordBatch::new_empty(self.schema.clone()));
45 }
46
47 let mut combined = Vec::with_capacity(records.iter().map(|r| r.value.len()).sum());
48 for record in records {
49 combined.extend_from_slice(&record.value);
50 }
51
52 let cursor = Cursor::new(&combined);
53 let reader = arrow_ipc::reader::FileReader::try_new(cursor, None)
54 .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC read error: {e}")))?;
55
56 let file_schema = reader.schema();
57
58 let mut batches = Vec::new();
59 for batch_result in reader {
60 let batch = batch_result
61 .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC batch error: {e}")))?;
62 batches.push(batch);
63 }
64
65 if batches.is_empty() {
66 return Ok(RecordBatch::new_empty(file_schema));
67 }
68
69 if batches.len() == 1 {
70 return Ok(batches.into_iter().next().unwrap());
71 }
72
73 arrow_select::concat::concat_batches(&file_schema, &batches)
74 .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC concat error: {e}")))
75 }
76
77 fn format_name(&self) -> &str {
78 "arrow_ipc"
79 }
80}
81
82#[derive(Debug)]
84pub struct ArrowIpcEncoder {
85 schema: SchemaRef,
86}
87
88impl ArrowIpcEncoder {
89 #[must_use]
91 pub fn new(schema: SchemaRef) -> Self {
92 Self { schema }
93 }
94}
95
96impl FormatEncoder for ArrowIpcEncoder {
97 fn input_schema(&self) -> SchemaRef {
98 self.schema.clone()
99 }
100
101 fn encode_batch(&self, batch: &RecordBatch) -> SchemaResult<Vec<Vec<u8>>> {
102 if batch.num_rows() == 0 {
103 return Ok(Vec::new());
104 }
105
106 let mut buf = Vec::new();
107 {
108 let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batch.schema())
109 .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC writer init: {e}")))?;
110 writer
111 .write(batch)
112 .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC write error: {e}")))?;
113 writer
114 .finish()
115 .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC finish error: {e}")))?;
116 }
117
118 Ok(vec![buf])
119 }
120
121 fn format_name(&self) -> &str {
122 "arrow_ipc"
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Two-column schema shared by the tests: non-nullable `id` (Int64) and
    /// nullable `name` (Utf8).
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Three-row batch for `schema`; the second `name` value is null.
    fn test_batch(schema: &SchemaRef) -> RecordBatch {
        let ids = Int64Array::from(vec![1, 2, 3]);
        let names = StringArray::from(vec![Some("a"), None, Some("c")]);
        RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(names)]).unwrap()
    }

    /// Encoding a batch and decoding the resulting bytes preserves shape, the
    /// first `id` value and the null in `name`.
    #[test]
    fn test_encode_decode_roundtrip() {
        let schema = test_schema();
        let source_batch = test_batch(&schema);

        let mut frames = ArrowIpcEncoder::new(schema.clone())
            .encode_batch(&source_batch)
            .unwrap();
        assert_eq!(frames.len(), 1);

        // Exactly one frame was asserted above, so `pop` yields that frame.
        let payload = frames.pop().unwrap();
        let decoded = ArrowIpcDecoder::new(schema)
            .decode_batch(&[RawRecord::new(payload)])
            .unwrap();

        assert_eq!(decoded.num_rows(), 3);
        assert_eq!(decoded.num_columns(), 2);

        let ids = decoded
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 1);
        assert!(decoded.column(1).is_null(1));
    }

    /// A zero-row batch encodes to no output buffers.
    #[test]
    fn test_encode_empty_batch() {
        let schema = test_schema();
        let empty = RecordBatch::new_empty(schema.clone());
        let frames = ArrowIpcEncoder::new(schema).encode_batch(&empty).unwrap();
        assert!(frames.is_empty());
    }

    /// Decoding no records yields an empty batch with the configured schema.
    #[test]
    fn test_decode_empty_records() {
        let schema = test_schema();
        let decoded = ArrowIpcDecoder::new(schema.clone())
            .decode_batch(&[])
            .unwrap();
        assert_eq!(decoded.num_rows(), 0);
        assert_eq!(decoded.schema(), schema);
    }
}