// laminar_connectors/serde/csv.rs
use std::sync::Arc;

use arrow_array::builder::{Float64Builder, Int64Builder, StringBuilder};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::{DataType, SchemaRef};

use super::{Format, RecordDeserializer, RecordSerializer};
use crate::error::SerdeError;

14#[derive(Debug, Clone)]
21pub struct CsvDeserializer {
22 delimiter: u8,
24}
25
26impl CsvDeserializer {
27 #[must_use]
29 pub fn new() -> Self {
30 Self { delimiter: b',' }
31 }
32
33 #[must_use]
35 pub fn with_delimiter(delimiter: u8) -> Self {
36 Self { delimiter }
37 }
38
39 fn split_fields<'a>(&self, line: &'a str) -> Vec<&'a str> {
41 let delim = self.delimiter as char;
42 let mut fields = Vec::new();
43 let mut start = 0;
44 let mut in_quotes = false;
45
46 for (i, ch) in line.char_indices() {
47 if ch == '"' {
48 in_quotes = !in_quotes;
49 } else if ch == delim && !in_quotes {
50 fields.push(line[start..i].trim().trim_matches('"'));
51 start = i + 1;
52 }
53 }
54 fields.push(line[start..].trim().trim_matches('"'));
55 fields
56 }
57}
58
59impl Default for CsvDeserializer {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl RecordDeserializer for CsvDeserializer {
66 fn deserialize(&self, data: &[u8], schema: &SchemaRef) -> Result<RecordBatch, SerdeError> {
67 let line = std::str::from_utf8(data)
68 .map_err(|e| SerdeError::Csv(format!("invalid UTF-8: {e}")))?;
69 let line = line.trim();
70 if line.is_empty() {
71 return Ok(RecordBatch::new_empty(schema.clone()));
72 }
73
74 let fields = self.split_fields(line);
75
76 if fields.len() != schema.fields().len() {
77 return Err(SerdeError::Csv(format!(
78 "expected {} fields, got {}",
79 schema.fields().len(),
80 fields.len()
81 )));
82 }
83
84 let mut columns: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
85
86 for (idx, field) in schema.fields().iter().enumerate() {
87 let raw = fields[idx];
88 let array = parse_csv_field(raw, field.data_type(), field.name(), field.is_nullable())?;
89 columns.push(array);
90 }
91
92 RecordBatch::try_new(schema.clone(), columns)
93 .map_err(|e| SerdeError::Csv(format!("failed to create RecordBatch: {e}")))
94 }
95
96 fn format(&self) -> Format {
97 Format::Csv
98 }
99}
100
/// Writes Arrow record batches as delimited text, one line per row.
#[derive(Debug, Clone)]
pub struct CsvSerializer {
    /// Byte written between fields.
    delimiter: u8,
}

impl CsvSerializer {
    /// Returns a serializer that separates fields with `,`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_delimiter(b',')
    }

    /// Returns a serializer that separates fields with `delimiter`.
    #[must_use]
    pub fn with_delimiter(delimiter: u8) -> Self {
        CsvSerializer { delimiter }
    }
}

impl Default for CsvSerializer {
    /// Equivalent to [`CsvSerializer::new`].
    fn default() -> Self {
        CsvSerializer::new()
    }
}

130impl RecordSerializer for CsvSerializer {
131 fn serialize(&self, batch: &RecordBatch) -> Result<Vec<Vec<u8>>, SerdeError> {
132 let delim = self.delimiter as char;
133 let schema = batch.schema();
134 let mut records = Vec::with_capacity(batch.num_rows());
135
136 for row in 0..batch.num_rows() {
137 let mut line = String::new();
138 for (col_idx, field) in schema.fields().iter().enumerate() {
139 if col_idx > 0 {
140 line.push(delim);
141 }
142 let column = batch.column(col_idx);
143 if column.is_null(row) {
144 } else {
146 let s = arrow_column_to_csv_string(column, row, field.data_type())?;
147 if s.contains(delim) || s.contains('"') || s.contains('\n') {
149 line.push('"');
150 line.push_str(&s.replace('"', "\"\""));
151 line.push('"');
152 } else {
153 line.push_str(&s);
154 }
155 }
156 }
157 records.push(line.into_bytes());
158 }
159
160 Ok(records)
161 }
162
163 fn serialize_batch(&self, batch: &RecordBatch) -> Result<Vec<u8>, SerdeError> {
164 let records = self.serialize(batch)?;
165 let total_len: usize = records.iter().map(|r| r.len() + 1).sum();
166 let mut buf = Vec::with_capacity(total_len);
167 for record in &records {
168 buf.extend_from_slice(record);
169 buf.push(b'\n');
170 }
171 Ok(buf)
172 }
173
174 fn format(&self) -> Format {
175 Format::Csv
176 }
177}
178
179fn parse_csv_field(
181 raw: &str,
182 data_type: &DataType,
183 field_name: &str,
184 nullable: bool,
185) -> Result<ArrayRef, SerdeError> {
186 if raw.is_empty() || raw.eq_ignore_ascii_case("null") {
187 if !nullable {
188 return Err(SerdeError::MissingField(field_name.into()));
189 }
190 return match data_type {
192 DataType::Int64 => {
193 let mut b = Int64Builder::with_capacity(1);
194 b.append_null();
195 Ok(Arc::new(b.finish()))
196 }
197 DataType::Float64 => {
198 let mut b = Float64Builder::with_capacity(1);
199 b.append_null();
200 Ok(Arc::new(b.finish()))
201 }
202 DataType::Utf8 => {
203 let mut b = StringBuilder::with_capacity(1, 0);
204 b.append_null();
205 Ok(Arc::new(b.finish()))
206 }
207 _ => Err(SerdeError::UnsupportedFormat(format!(
208 "unsupported type for CSV null: {data_type}"
209 ))),
210 };
211 }
212
213 match data_type {
214 DataType::Int64 => {
215 let v: i64 = raw.parse().map_err(|e| SerdeError::TypeConversion {
216 field: field_name.into(),
217 expected: "Int64".into(),
218 message: format!("{e}"),
219 })?;
220 let mut b = Int64Builder::with_capacity(1);
221 b.append_value(v);
222 Ok(Arc::new(b.finish()))
223 }
224 DataType::Float64 => {
225 let v: f64 = raw.parse().map_err(|e| SerdeError::TypeConversion {
226 field: field_name.into(),
227 expected: "Float64".into(),
228 message: format!("{e}"),
229 })?;
230 let mut b = Float64Builder::with_capacity(1);
231 b.append_value(v);
232 Ok(Arc::new(b.finish()))
233 }
234 DataType::Utf8 => {
235 let mut b = StringBuilder::with_capacity(1, raw.len());
236 b.append_value(raw);
237 Ok(Arc::new(b.finish()))
238 }
239 other => Err(SerdeError::UnsupportedFormat(format!(
240 "unsupported type for CSV: {other}"
241 ))),
242 }
243}
244
245fn arrow_column_to_csv_string(
247 column: &ArrayRef,
248 row: usize,
249 data_type: &DataType,
250) -> Result<String, SerdeError> {
251 use arrow_array::{BooleanArray, Float64Array, Int64Array, StringArray};
252
253 match data_type {
254 DataType::Int64 => {
255 let arr = column.as_any().downcast_ref::<Int64Array>().unwrap();
256 Ok(arr.value(row).to_string())
257 }
258 DataType::Float64 => {
259 let arr = column.as_any().downcast_ref::<Float64Array>().unwrap();
260 Ok(arr.value(row).to_string())
261 }
262 DataType::Utf8 => {
263 let arr = column.as_any().downcast_ref::<StringArray>().unwrap();
264 Ok(arr.value(row).to_string())
265 }
266 DataType::Boolean => {
267 let arr = column.as_any().downcast_ref::<BooleanArray>().unwrap();
268 Ok(arr.value(row).to_string())
269 }
270 other => Err(SerdeError::UnsupportedFormat(format!(
271 "unsupported type for CSV serialization: {other}"
272 ))),
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{Field, Schema};

    /// `id: Int64` (required), `name: Utf8` (required), `score: Float64` (nullable).
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("score", DataType::Float64, true),
        ]))
    }

    /// `id: Int64` (required), `desc: Utf8` (required).
    fn id_desc_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("desc", DataType::Utf8, false),
        ]))
    }

    /// Serializes `batch` with `ser` and returns each record as a `String`.
    fn serialize_lines(ser: &CsvSerializer, batch: &RecordBatch) -> Vec<String> {
        ser.serialize(batch)
            .unwrap()
            .into_iter()
            .map(|bytes| String::from_utf8(bytes).unwrap())
            .collect()
    }

    #[test]
    fn test_csv_deserialize_basic() {
        let deser = CsvDeserializer::new();
        let schema = test_schema();
        let data = b"1,Alice,95.5";

        let batch = deser.deserialize(data, &schema).unwrap();
        assert_eq!(batch.num_rows(), 1);

        let ids = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 1);

        let names = batch
            .column(1)
            .as_any()
            .downcast_ref::<arrow_array::StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "Alice");

        let scores = batch
            .column(2)
            .as_any()
            .downcast_ref::<arrow_array::Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 95.5);
    }

    #[test]
    fn test_csv_serialize_roundtrip() {
        let deser = CsvDeserializer::new();
        let ser = CsvSerializer::new();
        let schema = test_schema();

        let batch = deser.deserialize(b"42,Charlie,88.5", &schema).unwrap();

        let lines = serialize_lines(&ser, &batch);
        assert_eq!(lines, vec!["42,Charlie,88.5".to_string()]);

        // Feeding the output back in must reproduce the same record.
        let back = deser.deserialize(lines[0].as_bytes(), &schema).unwrap();
        assert_eq!(serialize_lines(&ser, &back), lines);
    }

    #[test]
    fn test_csv_serialize_batch_appends_newlines() {
        let deser = CsvDeserializer::new();
        let batch = deser.deserialize(b"7,Dora,1.5", &test_schema()).unwrap();

        let buf = CsvSerializer::new().serialize_batch(&batch).unwrap();
        assert_eq!(buf, b"7,Dora,1.5\n".to_vec());
    }

    #[test]
    fn test_csv_null_handling() {
        let deser = CsvDeserializer::new();
        let schema = test_schema();
        let data = b"1,Bob,";

        let batch = deser.deserialize(data, &schema).unwrap();
        assert!(batch.column(2).is_null(0));
    }

    #[test]
    fn test_csv_null_literal_is_case_insensitive() {
        let deser = CsvDeserializer::new();
        let batch = deser.deserialize(b"1,Bob,NULL", &test_schema()).unwrap();
        assert!(batch.column(2).is_null(0));
    }

    #[test]
    fn test_csv_serialize_null_as_empty_field() {
        let deser = CsvDeserializer::new();
        let batch = deser.deserialize(b"1,Bob,", &test_schema()).unwrap();

        let lines = serialize_lines(&CsvSerializer::new(), &batch);
        assert_eq!(lines, vec!["1,Bob,".to_string()]);
    }

    #[test]
    fn test_csv_null_in_non_nullable_field_is_error() {
        let deser = CsvDeserializer::new();
        let result = deser.deserialize(b",Alice,1.0", &test_schema());
        assert!(matches!(result, Err(SerdeError::MissingField(_))));
    }

    #[test]
    fn test_csv_type_mismatch_is_error() {
        let deser = CsvDeserializer::new();
        let result = deser.deserialize(b"abc,Alice,1.0", &test_schema());
        assert!(matches!(result, Err(SerdeError::TypeConversion { .. })));
    }

    #[test]
    fn test_csv_blank_input_gives_empty_batch() {
        let deser = CsvDeserializer::new();
        let batch = deser.deserialize(b"  \r\n", &test_schema()).unwrap();
        assert_eq!(batch.num_rows(), 0);
    }

    #[test]
    fn test_csv_wrong_field_count() {
        let deser = CsvDeserializer::new();
        let schema = test_schema();
        let data = b"1,Alice";

        let result = deser.deserialize(data, &schema);
        assert!(result.is_err());
    }

    #[test]
    fn test_csv_quoted_fields() {
        let deser = CsvDeserializer::new();
        let schema = id_desc_schema();
        let data = b"1,\"hello, world\"";

        let batch = deser.deserialize(data, &schema).unwrap();
        let descs = batch
            .column(1)
            .as_any()
            .downcast_ref::<arrow_array::StringArray>()
            .unwrap();
        assert_eq!(descs.value(0), "hello, world");
    }

    #[test]
    fn test_csv_serialize_quotes_embedded_delimiter() {
        let deser = CsvDeserializer::new();
        let batch = deser
            .deserialize(b"1,\"hello, world\"", &id_desc_schema())
            .unwrap();

        let lines = serialize_lines(&CsvSerializer::new(), &batch);
        assert_eq!(lines, vec!["1,\"hello, world\"".to_string()]);
    }

    #[test]
    fn test_csv_custom_delimiter() {
        let deser = CsvDeserializer::with_delimiter(b';');
        let batch = deser.deserialize(b"7;Dora;1.5", &test_schema()).unwrap();

        let ids = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow_array::Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 7);

        let lines = serialize_lines(&CsvSerializer::with_delimiter(b';'), &batch);
        assert_eq!(lines, vec!["7;Dora;1.5".to_string()]);
    }
}