Skip to main content

laminar_connectors/files/
arrow_ipc_codec.rs

1//! Arrow IPC file format decoder and encoder.
2
3use std::io::Cursor;
4
5use arrow_array::RecordBatch;
6use arrow_schema::SchemaRef;
7
8use crate::schema::error::{SchemaError, SchemaResult};
9use crate::schema::traits::{FormatDecoder, FormatEncoder};
10use crate::schema::types::RawRecord;
11
/// Decodes Arrow IPC file bytes into `RecordBatch`es.
///
/// The schema passed to the constructor is what `output_schema()` reports and
/// what the empty batch carries when `decode_batch` is given no records.
/// Batches that are actually decoded carry the file's embedded schema
/// (same contract as `ParquetDecoder`).
pub struct ArrowIpcDecoder {
    /// Declared schema; never checked against the schema embedded in the
    /// decoded file.
    schema: SchemaRef,
}
20
impl ArrowIpcDecoder {
    /// Creates a decoder with the given declared schema.
    ///
    /// `schema` is stored as-is; it is returned by `output_schema()` and used
    /// for the empty batch produced when there is nothing to decode.
    #[must_use]
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}
28
29impl std::fmt::Debug for ArrowIpcDecoder {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("ArrowIpcDecoder")
32            .field("schema", &self.schema)
33            .finish()
34    }
35}
36
37impl FormatDecoder for ArrowIpcDecoder {
38    fn output_schema(&self) -> SchemaRef {
39        self.schema.clone()
40    }
41
42    fn decode_batch(&self, records: &[RawRecord]) -> SchemaResult<RecordBatch> {
43        if records.is_empty() {
44            return Ok(RecordBatch::new_empty(self.schema.clone()));
45        }
46
47        let mut combined = Vec::with_capacity(records.iter().map(|r| r.value.len()).sum());
48        for record in records {
49            combined.extend_from_slice(&record.value);
50        }
51
52        let cursor = Cursor::new(&combined);
53        let reader = arrow_ipc::reader::FileReader::try_new(cursor, None)
54            .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC read error: {e}")))?;
55
56        let file_schema = reader.schema();
57
58        let mut batches = Vec::new();
59        for batch_result in reader {
60            let batch = batch_result
61                .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC batch error: {e}")))?;
62            batches.push(batch);
63        }
64
65        if batches.is_empty() {
66            return Ok(RecordBatch::new_empty(file_schema));
67        }
68
69        if batches.len() == 1 {
70            return Ok(batches.into_iter().next().unwrap());
71        }
72
73        arrow_select::concat::concat_batches(&file_schema, &batches)
74            .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC concat error: {e}")))
75    }
76
77    fn format_name(&self) -> &str {
78        "arrow_ipc"
79    }
80}
81
/// Encodes `RecordBatch`es into Arrow IPC file format bytes.
///
/// The schema passed to the constructor is what `input_schema()` reports.
/// Encoding itself writes each batch with the schema attached to that batch.
#[derive(Debug)]
pub struct ArrowIpcEncoder {
    /// Declared input schema; not compared against the batches being encoded.
    schema: SchemaRef,
}
87
impl ArrowIpcEncoder {
    /// Creates a new Arrow IPC encoder for the given schema.
    ///
    /// `schema` is stored as-is and returned by `input_schema()`.
    #[must_use]
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}
95
96impl FormatEncoder for ArrowIpcEncoder {
97    fn input_schema(&self) -> SchemaRef {
98        self.schema.clone()
99    }
100
101    fn encode_batch(&self, batch: &RecordBatch) -> SchemaResult<Vec<Vec<u8>>> {
102        if batch.num_rows() == 0 {
103            return Ok(Vec::new());
104        }
105
106        let mut buf = Vec::new();
107        {
108            let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batch.schema())
109                .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC writer init: {e}")))?;
110            writer
111                .write(batch)
112                .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC write error: {e}")))?;
113            writer
114                .finish()
115                .map_err(|e| SchemaError::DecodeError(format!("Arrow IPC finish error: {e}")))?;
116        }
117
118        Ok(vec![buf])
119    }
120
121    fn format_name(&self) -> &str {
122        "arrow_ipc"
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Two-column schema: non-nullable `id: Int64` and nullable `name: Utf8`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    /// Three-row batch for `test_schema()`; the second `name` is null.
    fn test_batch(schema: &SchemaRef) -> RecordBatch {
        RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let schema = test_schema();
        let batch = test_batch(&schema);

        // Encode
        let encoder = ArrowIpcEncoder::new(schema.clone());
        let encoded = encoder.encode_batch(&batch).unwrap();
        assert_eq!(encoded.len(), 1);

        // Decode
        let decoder = ArrowIpcDecoder::new(schema);
        let record = RawRecord::new(encoded.into_iter().next().unwrap());
        let decoded = decoder.decode_batch(&[record]).unwrap();

        assert_eq!(decoded.num_rows(), 3);
        assert_eq!(decoded.num_columns(), 2);
        assert_eq!(
            decoded
                .column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap()
                .value(0),
            1
        );
        assert!(decoded.column(1).is_null(1));
    }

    /// One IPC file split across two records must decode to the same batch
    /// as the unsplit file, since record payloads are joined before parsing.
    #[test]
    fn test_decode_chunked_records() {
        let schema = test_schema();
        let batch = test_batch(&schema);

        let encoder = ArrowIpcEncoder::new(schema.clone());
        let bytes = encoder
            .encode_batch(&batch)
            .unwrap()
            .into_iter()
            .next()
            .unwrap();
        let (head, tail) = bytes.split_at(bytes.len() / 2);

        let decoder = ArrowIpcDecoder::new(schema);
        let chunked = decoder
            .decode_batch(&[RawRecord::new(head.to_vec()), RawRecord::new(tail.to_vec())])
            .unwrap();
        let whole = decoder
            .decode_batch(&[RawRecord::new(bytes.clone())])
            .unwrap();

        assert_eq!(chunked.num_rows(), 3);
        assert_eq!(chunked, whole);
    }

    #[test]
    fn test_encode_empty_batch() {
        let schema = test_schema();
        let batch = RecordBatch::new_empty(schema.clone());
        let encoder = ArrowIpcEncoder::new(schema);
        let encoded = encoder.encode_batch(&batch).unwrap();
        assert!(encoded.is_empty());
    }

    #[test]
    fn test_decode_empty_records() {
        let schema = test_schema();
        let decoder = ArrowIpcDecoder::new(schema.clone());
        let batch = decoder.decode_batch(&[]).unwrap();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.schema(), schema);
    }
}