Skip to main content

laminar_connectors/serde/
csv.rs

1//! CSV serialization and deserialization.
2//!
3//! Converts between CSV records and Arrow `RecordBatch`.
4
5use std::sync::Arc;
6
7use arrow_array::builder::{Float64Builder, Int64Builder, StringBuilder};
8use arrow_array::{ArrayRef, RecordBatch};
9use arrow_schema::{DataType, SchemaRef};
10
11use super::{Format, RecordDeserializer, RecordSerializer};
12use crate::error::SerdeError;
13
/// Deserializer for delimiter-separated (CSV) records.
///
/// A line is split into fields and the fields are matched to the schema's
/// columns by position. No header handling is performed here: whatever line
/// is handed to the deserializer is parsed as data.
///
/// Supported column types: `Int64`, `Float64` and `Utf8`.
#[derive(Clone, Debug)]
pub struct CsvDeserializer {
    /// Byte that separates fields within a line.
    delimiter: u8,
}
25
26impl CsvDeserializer {
27    /// Creates a new CSV deserializer with comma delimiter.
28    #[must_use]
29    pub fn new() -> Self {
30        Self { delimiter: b',' }
31    }
32
33    /// Creates a CSV deserializer with a custom delimiter.
34    #[must_use]
35    pub fn with_delimiter(delimiter: u8) -> Self {
36        Self { delimiter }
37    }
38
39    /// Splits a CSV line into fields, respecting quoted values.
40    fn split_fields<'a>(&self, line: &'a str) -> Vec<&'a str> {
41        let delim = self.delimiter as char;
42        let mut fields = Vec::new();
43        let mut start = 0;
44        let mut in_quotes = false;
45
46        for (i, ch) in line.char_indices() {
47            if ch == '"' {
48                in_quotes = !in_quotes;
49            } else if ch == delim && !in_quotes {
50                fields.push(line[start..i].trim().trim_matches('"'));
51                start = i + 1;
52            }
53        }
54        fields.push(line[start..].trim().trim_matches('"'));
55        fields
56    }
57}
58
59impl Default for CsvDeserializer {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl RecordDeserializer for CsvDeserializer {
66    fn deserialize(&self, data: &[u8], schema: &SchemaRef) -> Result<RecordBatch, SerdeError> {
67        let line = std::str::from_utf8(data)
68            .map_err(|e| SerdeError::Csv(format!("invalid UTF-8: {e}")))?;
69        let line = line.trim();
70        if line.is_empty() {
71            return Ok(RecordBatch::new_empty(schema.clone()));
72        }
73
74        let fields = self.split_fields(line);
75
76        if fields.len() != schema.fields().len() {
77            return Err(SerdeError::Csv(format!(
78                "expected {} fields, got {}",
79                schema.fields().len(),
80                fields.len()
81            )));
82        }
83
84        let mut columns: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
85
86        for (idx, field) in schema.fields().iter().enumerate() {
87            let raw = fields[idx];
88            let array = parse_csv_field(raw, field.data_type(), field.name(), field.is_nullable())?;
89            columns.push(array);
90        }
91
92        RecordBatch::try_new(schema.clone(), columns)
93            .map_err(|e| SerdeError::Csv(format!("failed to create RecordBatch: {e}")))
94    }
95
96    fn format(&self) -> Format {
97        Format::Csv
98    }
99}
100
/// Serializer that renders Arrow `RecordBatch` rows as CSV lines.
///
/// One line is produced per row, with column values joined by the
/// configured delimiter.
#[derive(Clone, Debug)]
pub struct CsvSerializer {
    /// Byte written between adjacent fields.
    delimiter: u8,
}
109
110impl CsvSerializer {
111    /// Creates a new CSV serializer with comma delimiter.
112    #[must_use]
113    pub fn new() -> Self {
114        Self { delimiter: b',' }
115    }
116
117    /// Creates a CSV serializer with a custom delimiter.
118    #[must_use]
119    pub fn with_delimiter(delimiter: u8) -> Self {
120        Self { delimiter }
121    }
122}
123
124impl Default for CsvSerializer {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
130impl RecordSerializer for CsvSerializer {
131    fn serialize(&self, batch: &RecordBatch) -> Result<Vec<Vec<u8>>, SerdeError> {
132        let delim = self.delimiter as char;
133        let schema = batch.schema();
134        let mut records = Vec::with_capacity(batch.num_rows());
135
136        for row in 0..batch.num_rows() {
137            let mut line = String::new();
138            for (col_idx, field) in schema.fields().iter().enumerate() {
139                if col_idx > 0 {
140                    line.push(delim);
141                }
142                let column = batch.column(col_idx);
143                if column.is_null(row) {
144                    // Empty field for null
145                } else {
146                    let s = arrow_column_to_csv_string(column, row, field.data_type())?;
147                    // Quote strings that contain delimiter or quotes
148                    if s.contains(delim) || s.contains('"') || s.contains('\n') {
149                        line.push('"');
150                        line.push_str(&s.replace('"', "\"\""));
151                        line.push('"');
152                    } else {
153                        line.push_str(&s);
154                    }
155                }
156            }
157            records.push(line.into_bytes());
158        }
159
160        Ok(records)
161    }
162
163    fn serialize_batch(&self, batch: &RecordBatch) -> Result<Vec<u8>, SerdeError> {
164        let records = self.serialize(batch)?;
165        let total_len: usize = records.iter().map(|r| r.len() + 1).sum();
166        let mut buf = Vec::with_capacity(total_len);
167        for record in &records {
168            buf.extend_from_slice(record);
169            buf.push(b'\n');
170        }
171        Ok(buf)
172    }
173
174    fn format(&self) -> Format {
175        Format::Csv
176    }
177}
178
179/// Parses a CSV field value into a single-element Arrow array.
180fn parse_csv_field(
181    raw: &str,
182    data_type: &DataType,
183    field_name: &str,
184    nullable: bool,
185) -> Result<ArrayRef, SerdeError> {
186    if raw.is_empty() || raw.eq_ignore_ascii_case("null") {
187        if !nullable {
188            return Err(SerdeError::MissingField(field_name.into()));
189        }
190        // Return a null array of the appropriate type
191        return match data_type {
192            DataType::Int64 => {
193                let mut b = Int64Builder::with_capacity(1);
194                b.append_null();
195                Ok(Arc::new(b.finish()))
196            }
197            DataType::Float64 => {
198                let mut b = Float64Builder::with_capacity(1);
199                b.append_null();
200                Ok(Arc::new(b.finish()))
201            }
202            DataType::Utf8 => {
203                let mut b = StringBuilder::with_capacity(1, 0);
204                b.append_null();
205                Ok(Arc::new(b.finish()))
206            }
207            _ => Err(SerdeError::UnsupportedFormat(format!(
208                "unsupported type for CSV null: {data_type}"
209            ))),
210        };
211    }
212
213    match data_type {
214        DataType::Int64 => {
215            let v: i64 = raw.parse().map_err(|e| SerdeError::TypeConversion {
216                field: field_name.into(),
217                expected: "Int64".into(),
218                message: format!("{e}"),
219            })?;
220            let mut b = Int64Builder::with_capacity(1);
221            b.append_value(v);
222            Ok(Arc::new(b.finish()))
223        }
224        DataType::Float64 => {
225            let v: f64 = raw.parse().map_err(|e| SerdeError::TypeConversion {
226                field: field_name.into(),
227                expected: "Float64".into(),
228                message: format!("{e}"),
229            })?;
230            let mut b = Float64Builder::with_capacity(1);
231            b.append_value(v);
232            Ok(Arc::new(b.finish()))
233        }
234        DataType::Utf8 => {
235            let mut b = StringBuilder::with_capacity(1, raw.len());
236            b.append_value(raw);
237            Ok(Arc::new(b.finish()))
238        }
239        other => Err(SerdeError::UnsupportedFormat(format!(
240            "unsupported type for CSV: {other}"
241        ))),
242    }
243}
244
245/// Converts an Arrow column value at `row` to a CSV string.
246fn arrow_column_to_csv_string(
247    column: &ArrayRef,
248    row: usize,
249    data_type: &DataType,
250) -> Result<String, SerdeError> {
251    use arrow_array::{BooleanArray, Float64Array, Int64Array, StringArray};
252
253    match data_type {
254        DataType::Int64 => {
255            let arr = column.as_any().downcast_ref::<Int64Array>().unwrap();
256            Ok(arr.value(row).to_string())
257        }
258        DataType::Float64 => {
259            let arr = column.as_any().downcast_ref::<Float64Array>().unwrap();
260            Ok(arr.value(row).to_string())
261        }
262        DataType::Utf8 => {
263            let arr = column.as_any().downcast_ref::<StringArray>().unwrap();
264            Ok(arr.value(row).to_string())
265        }
266        DataType::Boolean => {
267            let arr = column.as_any().downcast_ref::<BooleanArray>().unwrap();
268            Ok(arr.value(row).to_string())
269        }
270        other => Err(SerdeError::UnsupportedFormat(format!(
271            "unsupported type for CSV serialization: {other}"
272        ))),
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{Field, Schema};

    /// Three-column schema shared by most tests: required `id` and `name`,
    /// nullable `score`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("score", DataType::Float64, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Returns column `index` of `batch` downcast to the concrete array type `T`.
    fn column_as<T: 'static>(batch: &RecordBatch, index: usize) -> &T {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<T>()
            .expect("column has an unexpected array type")
    }

    #[test]
    fn test_csv_deserialize_basic() {
        let batch = CsvDeserializer::new()
            .deserialize(b"1,Alice,95.5", &test_schema())
            .unwrap();
        assert_eq!(batch.num_rows(), 1);

        assert_eq!(column_as::<Int64Array>(&batch, 0).value(0), 1);
        assert_eq!(column_as::<StringArray>(&batch, 1).value(0), "Alice");
    }

    #[test]
    fn test_csv_serialize_roundtrip() {
        let batch = CsvDeserializer::new()
            .deserialize(b"42,Charlie,88.5", &test_schema())
            .unwrap();

        let serialized = CsvSerializer::new().serialize(&batch).unwrap();
        assert_eq!(serialized.len(), 1);

        let line = std::str::from_utf8(&serialized[0]).unwrap();
        assert!(line.contains("42"));
        assert!(line.contains("Charlie"));
    }

    #[test]
    fn test_csv_null_handling() {
        // A trailing empty field maps to a null in the nullable `score` column.
        let batch = CsvDeserializer::new()
            .deserialize(b"1,Bob,", &test_schema())
            .unwrap();
        assert!(batch.column(2).is_null(0));
    }

    #[test]
    fn test_csv_wrong_field_count() {
        // Two fields against a three-column schema must be rejected.
        let outcome = CsvDeserializer::new().deserialize(b"1,Alice", &test_schema());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_csv_quoted_fields() {
        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("desc", DataType::Utf8, false),
        ]));

        // The comma inside the quotes belongs to the value, not the layout.
        let batch = CsvDeserializer::new()
            .deserialize(b"1,\"hello, world\"", &schema)
            .unwrap();
        assert_eq!(
            column_as::<StringArray>(&batch, 1).value(0),
            "hello, world"
        );
    }
}