Skip to main content

laminar_connectors/schema/
csv.rs

//! CSV format decoder implementing [`FormatDecoder`].
//!
//! Converts raw CSV byte payloads into Arrow `RecordBatch`es.
//! Constructed once at `CREATE SOURCE` time with a frozen Arrow schema
//! and CSV format configuration. After construction the configuration and
//! the per-column coercion plan never change (only a diagnostic parse-error
//! counter is updated), so the Ring 1 hot path does no per-record type lookups.
//!
//! Uses the `csv` crate's `ByteRecord` API for zero-copy field access
//! where possible. Type coercion (string → int, string → timestamp, etc.)
//! is performed during the Arrow builder append phase.
//!
//! The module also provides [`CsvEncoder`] for the reverse direction
//! (Arrow `RecordBatch` → CSV records).
11
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::Arc;
14
15use arrow_array::builder::{
16    BooleanBuilder, Date32Builder, Float64Builder, Int64Builder, StringBuilder,
17    TimestampNanosecondBuilder,
18};
19use arrow_array::{ArrayRef, RecordBatch};
20use arrow_schema::{DataType, SchemaRef, TimeUnit};
21
22use crate::schema::error::{SchemaError, SchemaResult};
23use crate::schema::traits::{FormatDecoder, FormatEncoder};
24use crate::schema::types::RawRecord;
25
/// Strategy for rows whose field count differs from the schema's column count.
///
/// `Null` is the default, matching `CsvDecoderConfig::default()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum FieldCountMismatchStrategy {
    /// Pad missing fields with null, ignore extra fields. Default.
    #[default]
    Null,
    /// Skip the malformed row entirely (the skip is counted as a parse error).
    Skip,
    /// Return a decode error on the first malformed row.
    Reject,
}
36
37/// CSV decoder configuration.
38///
39/// Maps directly to the SQL `FORMAT CSV (...)` options.
40/// All fields have sensible defaults matching RFC 4180.
41#[derive(Debug, Clone)]
42pub struct CsvDecoderConfig {
43    /// Field delimiter character. Default: `','` (comma).
44    /// Common alternatives: `'\t'` (tab), `'|'` (pipe), `';'` (semicolon).
45    pub delimiter: u8,
46
47    /// Quote character for fields containing delimiters or newlines.
48    /// Default: `'"'` (double quote). Set to `None` to disable quoting.
49    pub quote: Option<u8>,
50
51    /// Escape character within quoted fields.
52    /// Default: `None` (RFC 4180 uses doubled quote chars for escaping).
53    /// Set to `Some(b'\\')` for backslash-escaped CSVs.
54    pub escape: Option<u8>,
55
56    /// Whether the first row is a header row with column names.
57    /// Default: `true`.
58    pub has_header: bool,
59
60    /// String value to interpret as SQL NULL.
61    /// Default: `""` (empty string). Common alternatives: `"NA"`, `"null"`, `"\\N"`.
62    pub null_string: String,
63
64    /// Comment line prefix. Lines starting with this character are skipped.
65    /// Default: `None` (no comment support).
66    pub comment: Option<u8>,
67
68    /// Number of rows to skip at the beginning of the data (after header).
69    /// Default: `0`.
70    pub skip_rows: usize,
71
72    /// Timestamp format pattern for parsing timestamp columns.
73    /// Default: `"%Y-%m-%d %H:%M:%S%.f"`.
74    pub timestamp_format: String,
75
76    /// Date format pattern for parsing date columns.
77    /// Default: `"%Y-%m-%d"`.
78    pub date_format: String,
79
80    /// How to handle rows with wrong number of fields.
81    /// Default: `Null` (pad missing fields with null, truncate extra).
82    pub field_count_mismatch: FieldCountMismatchStrategy,
83}
84
85impl Default for CsvDecoderConfig {
86    fn default() -> Self {
87        Self {
88            delimiter: b',',
89            quote: Some(b'"'),
90            escape: None,
91            has_header: true,
92            null_string: String::new(),
93            comment: None,
94            skip_rows: 0,
95            timestamp_format: "%Y-%m-%d %H:%M:%S%.f".into(),
96            date_format: "%Y-%m-%d".into(),
97            field_count_mismatch: FieldCountMismatchStrategy::Null,
98        }
99    }
100}
101
/// How the raw text of one CSV column becomes an Arrow value.
///
/// One strategy is chosen per column when the decoder is constructed; the
/// format strings are copied out of the config so decoding needs no config
/// lookups per value.
#[derive(Debug, Clone)]
enum CsvCoercion {
    /// ASCII case-insensitive `true`/`false`, `t`/`f`, `yes`/`no`, `y`/`n`,
    /// `1`/`0`.
    Boolean,
    /// Any integer column, parsed as `i64`.
    Int64,
    /// Any floating-point column, parsed as `f64`.
    Float64,
    /// Timestamp parsed with the carried chrono format (naive values taken
    /// as UTC) into epoch nanoseconds, with an ISO 8601 fallback.
    Timestamp(String),
    /// Date parsed with the carried chrono format into days since
    /// 1970-01-01 (`Date32`).
    Date(String),
    /// Kept verbatim as a UTF-8 string.
    Utf8,
}
118
119/// Decodes CSV byte payloads into Arrow `RecordBatch`es.
120///
121/// # Ring Placement
122///
123/// - **Ring 1**: `decode_batch()` — parse CSV, build columnar Arrow output
124/// - **Ring 2**: Construction (`new` / `with_config`) — one-time setup
125pub struct CsvDecoder {
126    /// Frozen output schema.
127    schema: SchemaRef,
128    /// CSV format configuration.
129    config: CsvDecoderConfig,
130    /// Per-column type coercion functions, indexed by column position.
131    /// Pre-computed at construction time to avoid per-record dispatch.
132    coercions: Vec<CsvCoercion>,
133    /// Cumulative count of parse errors (for diagnostics).
134    parse_error_count: AtomicU64,
135}
136
137#[allow(clippy::missing_fields_in_debug)]
138impl std::fmt::Debug for CsvDecoder {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        f.debug_struct("CsvDecoder")
141            .field("schema", &self.schema)
142            .field("config", &self.config)
143            .field(
144                "parse_error_count",
145                &self.parse_error_count.load(Ordering::Relaxed),
146            )
147            .finish()
148    }
149}
150
151impl CsvDecoder {
152    /// Creates a new CSV decoder for the given Arrow schema with default config.
153    #[must_use]
154    pub fn new(schema: SchemaRef) -> Self {
155        Self::with_config(schema, CsvDecoderConfig::default())
156    }
157
158    /// Creates a new CSV decoder with custom configuration.
159    #[must_use]
160    pub fn with_config(schema: SchemaRef, config: CsvDecoderConfig) -> Self {
161        let coercions: Vec<CsvCoercion> = schema
162            .fields()
163            .iter()
164            .map(|field| Self::coercion_for_type(field.data_type(), &config))
165            .collect();
166
167        Self {
168            schema,
169            config,
170            coercions,
171            parse_error_count: AtomicU64::new(0),
172        }
173    }
174
175    /// Returns the cumulative parse error count.
176    pub fn parse_error_count(&self) -> u64 {
177        self.parse_error_count.load(Ordering::Relaxed)
178    }
179
180    /// Determines the coercion strategy for an Arrow data type.
181    fn coercion_for_type(data_type: &DataType, config: &CsvDecoderConfig) -> CsvCoercion {
182        match data_type {
183            DataType::Boolean => CsvCoercion::Boolean,
184            DataType::Int8
185            | DataType::Int16
186            | DataType::Int32
187            | DataType::Int64
188            | DataType::UInt8
189            | DataType::UInt16
190            | DataType::UInt32
191            | DataType::UInt64 => CsvCoercion::Int64,
192            DataType::Float16 | DataType::Float32 | DataType::Float64 => CsvCoercion::Float64,
193            DataType::Timestamp(_, _) => CsvCoercion::Timestamp(config.timestamp_format.clone()),
194            DataType::Date32 | DataType::Date64 => CsvCoercion::Date(config.date_format.clone()),
195            _ => CsvCoercion::Utf8,
196        }
197    }
198
199    /// Builds a `csv::ReaderBuilder` from the decoder config.
200    fn make_reader_builder(&self) -> csv::ReaderBuilder {
201        let mut rb = csv::ReaderBuilder::new();
202        rb.delimiter(self.config.delimiter)
203            .has_headers(false) // We handle headers ourselves
204            .flexible(true); // Allow variable field counts
205
206        if let Some(q) = self.config.quote {
207            rb.quote(q);
208        }
209        if let Some(e) = self.config.escape {
210            rb.escape(Some(e));
211        }
212        if let Some(c) = self.config.comment {
213            rb.comment(Some(c));
214        }
215
216        rb
217    }
218}
219
220impl FormatDecoder for CsvDecoder {
221    fn output_schema(&self) -> SchemaRef {
222        self.schema.clone()
223    }
224
225    /// Decodes a batch of raw CSV records into an Arrow `RecordBatch`.
226    ///
227    /// Each `RawRecord.value` contains one or more CSV lines (typically one
228    /// line per record for streaming sources; may contain multiple lines
229    /// for file-based sources).
230    ///
231    /// # Algorithm
232    ///
233    /// 1. Initialize one Arrow `ArrayBuilder` per schema column.
234    /// 2. Concatenate all raw record bytes into a single buffer.
235    /// 3. Create a `csv::Reader` with pre-configured settings.
236    /// 4. For each CSV row:
237    ///    - Skip rows per `skip_rows` config.
238    ///    - For each field: apply the pre-computed `CsvCoercion`.
239    ///    - Handle field count mismatches per config.
240    /// 5. Finish all builders and assemble into `RecordBatch`.
241    fn decode_batch(&self, records: &[RawRecord]) -> SchemaResult<RecordBatch> {
242        if records.is_empty() {
243            return Ok(RecordBatch::new_empty(self.schema.clone()));
244        }
245
246        let num_fields = self.schema.fields().len();
247        let capacity = records.len();
248
249        // Initialize one builder per schema column.
250        let mut builders = create_builders(&self.schema, capacity);
251
252        // Concatenate all raw record bytes, ensuring newline separation.
253        let mut combined = Vec::with_capacity(records.iter().map(|r| r.value.len() + 1).sum());
254        for record in records {
255            combined.extend_from_slice(&record.value);
256            if !record.value.ends_with(b"\n") {
257                combined.push(b'\n');
258            }
259        }
260
261        let rb = self.make_reader_builder();
262        let mut reader = rb.from_reader(combined.as_slice());
263
264        let mut rows_skipped = 0usize;
265        let mut header_skipped = false;
266        let mut row_count = 0usize;
267
268        let mut byte_record = csv::ByteRecord::new();
269        while reader
270            .read_byte_record(&mut byte_record)
271            .map_err(|e| SchemaError::DecodeError(format!("CSV parse error: {e}")))?
272        {
273            // Skip header row if configured.
274            if self.config.has_header && !header_skipped {
275                header_skipped = true;
276                continue;
277            }
278
279            // Skip initial data rows per config.
280            if rows_skipped < self.config.skip_rows {
281                rows_skipped += 1;
282                continue;
283            }
284
285            let field_count = byte_record.len();
286
287            // Handle field count mismatch.
288            if field_count != num_fields {
289                match self.config.field_count_mismatch {
290                    FieldCountMismatchStrategy::Reject => {
291                        return Err(SchemaError::DecodeError(format!(
292                            "field count mismatch: expected {num_fields}, got {field_count}"
293                        )));
294                    }
295                    FieldCountMismatchStrategy::Skip => {
296                        self.parse_error_count.fetch_add(1, Ordering::Relaxed);
297                        continue;
298                    }
299                    FieldCountMismatchStrategy::Null => {
300                        // Will pad/truncate below.
301                    }
302                }
303            }
304
305            // Process each column.
306            for col_idx in 0..num_fields {
307                if col_idx >= field_count {
308                    // Missing field — append null.
309                    append_null(&mut builders[col_idx]);
310                    continue;
311                }
312
313                let raw_field = &byte_record[col_idx];
314                let field_str = std::str::from_utf8(raw_field).unwrap_or("");
315                let trimmed = field_str.trim();
316
317                // Check for null string.
318                if trimmed == self.config.null_string {
319                    append_null(&mut builders[col_idx]);
320                    continue;
321                }
322
323                // Apply coercion.
324                let ok = append_coerced(&mut builders[col_idx], &self.coercions[col_idx], trimmed);
325
326                if !ok {
327                    self.parse_error_count.fetch_add(1, Ordering::Relaxed);
328                    append_null(&mut builders[col_idx]);
329                }
330            }
331
332            row_count += 1;
333        }
334
335        // If no data rows were processed, return empty batch.
336        if row_count == 0 {
337            return Ok(RecordBatch::new_empty(self.schema.clone()));
338        }
339
340        // Finish all builders into arrays.
341        let columns: Vec<ArrayRef> = builders.into_iter().map(|mut b| b.finish()).collect();
342
343        RecordBatch::try_new(self.schema.clone(), columns)
344            .map_err(|e| SchemaError::DecodeError(format!("RecordBatch construction: {e}")))
345    }
346
347    #[allow(clippy::unnecessary_literal_bound)]
348    fn format_name(&self) -> &str {
349        "csv"
350    }
351}
352
353// ── Builder helpers ────────────────────────────────────────────────
354
355/// Trait-object wrapper so we can store heterogeneous builders in a `Vec`.
356trait ColumnBuilder: Send {
357    fn finish(&mut self) -> ArrayRef;
358    fn append_null_value(&mut self);
359    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
360}
361
362macro_rules! impl_column_builder {
363    ($builder:ty) => {
364        impl ColumnBuilder for $builder {
365            fn finish(&mut self) -> ArrayRef {
366                Arc::new(<$builder>::finish(self))
367            }
368            fn append_null_value(&mut self) {
369                self.append_null();
370            }
371            fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
372                self
373            }
374        }
375    };
376}
377
378impl_column_builder!(BooleanBuilder);
379impl_column_builder!(Int64Builder);
380impl_column_builder!(Float64Builder);
381impl_column_builder!(StringBuilder);
382impl_column_builder!(TimestampNanosecondBuilder);
383impl_column_builder!(Date32Builder);
384
385fn create_builders(schema: &SchemaRef, capacity: usize) -> Vec<Box<dyn ColumnBuilder>> {
386    schema
387        .fields()
388        .iter()
389        .map(|f| create_builder(f.data_type(), capacity))
390        .collect()
391}
392
393fn create_builder(data_type: &DataType, capacity: usize) -> Box<dyn ColumnBuilder> {
394    match data_type {
395        DataType::Boolean => Box::new(BooleanBuilder::with_capacity(capacity)),
396        DataType::Int8
397        | DataType::Int16
398        | DataType::Int32
399        | DataType::Int64
400        | DataType::UInt8
401        | DataType::UInt16
402        | DataType::UInt32
403        | DataType::UInt64 => Box::new(Int64Builder::with_capacity(capacity)),
404        DataType::Float16 | DataType::Float32 | DataType::Float64 => {
405            Box::new(Float64Builder::with_capacity(capacity))
406        }
407        DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
408            let builder =
409                TimestampNanosecondBuilder::with_capacity(capacity).with_timezone_opt(tz.clone());
410            Box::new(builder)
411        }
412        DataType::Date32 | DataType::Date64 => Box::new(Date32Builder::with_capacity(capacity)),
413        // Fallback: store as UTF-8 string.
414        _ => Box::new(StringBuilder::with_capacity(capacity, capacity * 32)),
415    }
416}
417
418fn append_null(builder: &mut Box<dyn ColumnBuilder>) {
419    builder.append_null_value();
420}
421
422/// Appends a coerced value to the appropriate builder. Returns `true` on
423/// success, `false` if the value could not be parsed.
424fn append_coerced(
425    builder: &mut Box<dyn ColumnBuilder>,
426    coercion: &CsvCoercion,
427    value: &str,
428) -> bool {
429    match coercion {
430        CsvCoercion::Boolean => {
431            let b = builder
432                .as_any_mut()
433                .downcast_mut::<BooleanBuilder>()
434                .unwrap();
435            match value.to_ascii_lowercase().as_str() {
436                "true" | "1" | "yes" | "t" | "y" => {
437                    b.append_value(true);
438                    true
439                }
440                "false" | "0" | "no" | "f" | "n" => {
441                    b.append_value(false);
442                    true
443                }
444                _ => false,
445            }
446        }
447        CsvCoercion::Int64 => {
448            let b = builder.as_any_mut().downcast_mut::<Int64Builder>().unwrap();
449            match value.parse::<i64>() {
450                Ok(v) => {
451                    b.append_value(v);
452                    true
453                }
454                Err(_) => false,
455            }
456        }
457        CsvCoercion::Float64 => {
458            let b = builder
459                .as_any_mut()
460                .downcast_mut::<Float64Builder>()
461                .unwrap();
462            match value.parse::<f64>() {
463                Ok(v) => {
464                    b.append_value(v);
465                    true
466                }
467                Err(_) => false,
468            }
469        }
470        CsvCoercion::Timestamp(fmt) => {
471            let b = builder
472                .as_any_mut()
473                .downcast_mut::<TimestampNanosecondBuilder>()
474                .unwrap();
475            // Try the configured format first.
476            if let Ok(ndt) = chrono::NaiveDateTime::parse_from_str(value, fmt) {
477                let nanos = ndt.and_utc().timestamp_nanos_opt().unwrap_or(0);
478                b.append_value(nanos);
479                return true;
480            }
481            // Try ISO 8601 fallback.
482            if let Ok(nanos) = arrow_cast::parse::string_to_timestamp_nanos(value) {
483                b.append_value(nanos);
484                return true;
485            }
486            false
487        }
488        CsvCoercion::Date(fmt) => {
489            let b = builder
490                .as_any_mut()
491                .downcast_mut::<Date32Builder>()
492                .unwrap();
493            if let Ok(date) = chrono::NaiveDate::parse_from_str(value, fmt) {
494                // Date32 stores days since epoch (1970-01-01).
495                let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
496                let days = (date - epoch).num_days();
497                #[allow(clippy::cast_possible_truncation)]
498                {
499                    b.append_value(days as i32);
500                }
501                return true;
502            }
503            false
504        }
505        CsvCoercion::Utf8 => {
506            let b = builder
507                .as_any_mut()
508                .downcast_mut::<StringBuilder>()
509                .unwrap();
510            b.append_value(value);
511            true
512        }
513    }
514}
515
/// Options controlling [`CsvEncoder`] output.
#[derive(Debug, Clone)]
pub struct CsvEncoderConfig {
    /// Byte written between fields. Default `b','`.
    pub delimiter: u8,
    /// Emit a header line before the data rows. Default `false`.
    pub has_header: bool,
}

impl Default for CsvEncoderConfig {
    fn default() -> Self {
        let delimiter = b',';
        let has_header = false;
        CsvEncoderConfig {
            delimiter,
            has_header,
        }
    }
}
533
534/// Encodes Arrow `RecordBatch`es into CSV byte records via `arrow_csv::writer`.
535#[derive(Debug)]
536pub struct CsvEncoder {
537    schema: SchemaRef,
538    config: CsvEncoderConfig,
539}
540
541impl CsvEncoder {
542    /// Creates a new CSV encoder for the given schema with default config.
543    #[must_use]
544    pub fn new(schema: SchemaRef) -> Self {
545        Self::with_config(schema, CsvEncoderConfig::default())
546    }
547
548    /// Creates a new CSV encoder with custom configuration.
549    #[must_use]
550    pub fn with_config(schema: SchemaRef, config: CsvEncoderConfig) -> Self {
551        Self { schema, config }
552    }
553}
554
555impl FormatEncoder for CsvEncoder {
556    fn input_schema(&self) -> SchemaRef {
557        self.schema.clone()
558    }
559
560    fn encode_batch(&self, batch: &RecordBatch) -> SchemaResult<Vec<Vec<u8>>> {
561        if batch.num_rows() == 0 {
562            return Ok(Vec::new());
563        }
564
565        let mut buf = Vec::new();
566        {
567            let writer = arrow_csv::writer::WriterBuilder::new()
568                .with_header(self.config.has_header)
569                .with_delimiter(self.config.delimiter);
570            let mut csv_writer = writer.build(&mut buf);
571            csv_writer
572                .write(batch)
573                .map_err(|e| SchemaError::DecodeError(format!("CSV encode error: {e}")))?;
574        }
575
576        let output: Vec<Vec<u8>> = buf
577            .split(|&b| b == b'\n')
578            .filter(|line| !line.is_empty())
579            .map(<[u8]>::to_vec)
580            .collect();
581
582        Ok(output)
583    }
584
585    #[allow(clippy::unnecessary_literal_bound)]
586    fn format_name(&self) -> &str {
587        "csv"
588    }
589}
590
591#[cfg(test)]
592mod tests {
593    use super::*;
594    use crate::schema::traits::FormatEncoder;
595    use arrow_array::cast::AsArray;
596    use arrow_schema::{Field, Schema};
597
598    fn make_schema(fields: Vec<(&str, DataType, bool)>) -> SchemaRef {
599        Arc::new(Schema::new(
600            fields
601                .into_iter()
602                .map(|(name, dt, nullable)| Field::new(name, dt, nullable))
603                .collect::<Vec<_>>(),
604        ))
605    }
606
607    fn csv_record(line: &str) -> RawRecord {
608        RawRecord::new(line.as_bytes().to_vec())
609    }
610
611    fn csv_block(lines: &str) -> RawRecord {
612        RawRecord::new(lines.as_bytes().to_vec())
613    }
614
615    // ── Basic decode tests ────────────────────────────────────
616
617    #[test]
618    fn test_decode_empty_batch() {
619        let schema = make_schema(vec![("id", DataType::Int64, false)]);
620        let decoder = CsvDecoder::new(schema.clone());
621        let batch = decoder.decode_batch(&[]).unwrap();
622        assert_eq!(batch.num_rows(), 0);
623        assert_eq!(batch.schema(), schema);
624    }
625
626    #[test]
627    fn test_decode_single_row_with_header() {
628        let schema = make_schema(vec![
629            ("id", DataType::Int64, false),
630            ("name", DataType::Utf8, true),
631        ]);
632        let decoder = CsvDecoder::new(schema);
633        let records = vec![csv_block("id,name\n42,Alice")];
634        let batch = decoder.decode_batch(&records).unwrap();
635
636        assert_eq!(batch.num_rows(), 1);
637        assert_eq!(
638            batch
639                .column(0)
640                .as_primitive::<arrow_array::types::Int64Type>()
641                .value(0),
642            42
643        );
644        assert_eq!(batch.column(1).as_string::<i32>().value(0), "Alice");
645    }
646
647    #[test]
648    fn test_decode_multiple_rows() {
649        let schema = make_schema(vec![
650            ("x", DataType::Int64, false),
651            ("y", DataType::Float64, false),
652        ]);
653        let decoder = CsvDecoder::new(schema);
654        let records = vec![csv_block("x,y\n1,1.5\n2,2.5\n3,3.5")];
655        let batch = decoder.decode_batch(&records).unwrap();
656
657        assert_eq!(batch.num_rows(), 3);
658        let x_col = batch
659            .column(0)
660            .as_primitive::<arrow_array::types::Int64Type>();
661        assert_eq!(x_col.value(0), 1);
662        assert_eq!(x_col.value(1), 2);
663        assert_eq!(x_col.value(2), 3);
664    }
665
666    #[test]
667    fn test_decode_all_types() {
668        let schema = make_schema(vec![
669            ("bool_col", DataType::Boolean, false),
670            ("int_col", DataType::Int64, false),
671            ("float_col", DataType::Float64, false),
672            ("str_col", DataType::Utf8, false),
673        ]);
674        let decoder = CsvDecoder::new(schema);
675        let records = vec![csv_block(
676            "bool_col,int_col,float_col,str_col\ntrue,42,3.14,hello",
677        )];
678        let batch = decoder.decode_batch(&records).unwrap();
679
680        assert_eq!(batch.num_rows(), 1);
681        assert!(batch.column(0).as_boolean().value(0));
682        assert_eq!(
683            batch
684                .column(1)
685                .as_primitive::<arrow_array::types::Int64Type>()
686                .value(0),
687            42
688        );
689        let f = batch
690            .column(2)
691            .as_primitive::<arrow_array::types::Float64Type>()
692            .value(0);
693        assert!((f - 3.14).abs() < f64::EPSILON);
694        assert_eq!(batch.column(3).as_string::<i32>().value(0), "hello");
695    }
696
697    // ── Null handling ─────────────────────────────────────────
698
699    #[test]
700    fn test_decode_null_string_default() {
701        // Default null_string is empty string.
702        let schema = make_schema(vec![
703            ("a", DataType::Int64, true),
704            ("b", DataType::Utf8, true),
705        ]);
706        let decoder = CsvDecoder::new(schema);
707        let records = vec![csv_block("a,b\n,")];
708        let batch = decoder.decode_batch(&records).unwrap();
709
710        assert!(batch.column(0).is_null(0));
711        assert!(batch.column(1).is_null(0));
712    }
713
714    #[test]
715    fn test_decode_null_string_custom() {
716        let schema = make_schema(vec![("val", DataType::Int64, true)]);
717        let config = CsvDecoderConfig {
718            null_string: "NA".into(),
719            ..Default::default()
720        };
721        let decoder = CsvDecoder::with_config(schema, config);
722        let records = vec![csv_block("val\nNA\n42")];
723        let batch = decoder.decode_batch(&records).unwrap();
724
725        assert_eq!(batch.num_rows(), 2);
726        assert!(batch.column(0).is_null(0));
727        assert_eq!(
728            batch
729                .column(0)
730                .as_primitive::<arrow_array::types::Int64Type>()
731                .value(1),
732            42
733        );
734    }
735
736    // ── Field count mismatch strategies ───────────────────────
737
738    #[test]
739    fn test_mismatch_null_strategy() {
740        let schema = make_schema(vec![
741            ("a", DataType::Int64, true),
742            ("b", DataType::Utf8, true),
743            ("c", DataType::Int64, true),
744        ]);
745        let decoder = CsvDecoder::new(schema);
746        // Row only has 2 fields, schema expects 3.
747        let records = vec![csv_block("a,b,c\n1,hello")];
748        let batch = decoder.decode_batch(&records).unwrap();
749
750        assert_eq!(batch.num_rows(), 1);
751        assert_eq!(
752            batch
753                .column(0)
754                .as_primitive::<arrow_array::types::Int64Type>()
755                .value(0),
756            1
757        );
758        assert_eq!(batch.column(1).as_string::<i32>().value(0), "hello");
759        assert!(batch.column(2).is_null(0)); // padded with null
760    }
761
762    #[test]
763    fn test_mismatch_skip_strategy() {
764        let schema = make_schema(vec![
765            ("a", DataType::Int64, false),
766            ("b", DataType::Int64, false),
767        ]);
768        let config = CsvDecoderConfig {
769            field_count_mismatch: FieldCountMismatchStrategy::Skip,
770            ..Default::default()
771        };
772        let decoder = CsvDecoder::with_config(schema, config);
773        // One good row, one bad row (too few fields).
774        let records = vec![csv_block("a,b\n1,2\n3")];
775        let batch = decoder.decode_batch(&records).unwrap();
776
777        assert_eq!(batch.num_rows(), 1); // bad row skipped
778        assert_eq!(
779            batch
780                .column(0)
781                .as_primitive::<arrow_array::types::Int64Type>()
782                .value(0),
783            1
784        );
785    }
786
787    #[test]
788    fn test_mismatch_reject_strategy() {
789        let schema = make_schema(vec![
790            ("a", DataType::Int64, false),
791            ("b", DataType::Int64, false),
792        ]);
793        let config = CsvDecoderConfig {
794            field_count_mismatch: FieldCountMismatchStrategy::Reject,
795            ..Default::default()
796        };
797        let decoder = CsvDecoder::with_config(schema, config);
798        let records = vec![csv_block("a,b\n1")]; // too few fields
799        let result = decoder.decode_batch(&records);
800
801        assert!(result.is_err());
802        assert!(result
803            .unwrap_err()
804            .to_string()
805            .contains("field count mismatch"));
806    }
807
808    // ── Delimiter options ─────────────────────────────────────
809
810    #[test]
811    fn test_pipe_delimiter() {
812        let schema = make_schema(vec![
813            ("a", DataType::Int64, false),
814            ("b", DataType::Utf8, false),
815        ]);
816        let config = CsvDecoderConfig {
817            delimiter: b'|',
818            ..Default::default()
819        };
820        let decoder = CsvDecoder::with_config(schema, config);
821        let records = vec![csv_block("a|b\n42|hello")];
822        let batch = decoder.decode_batch(&records).unwrap();
823
824        assert_eq!(
825            batch
826                .column(0)
827                .as_primitive::<arrow_array::types::Int64Type>()
828                .value(0),
829            42
830        );
831        assert_eq!(batch.column(1).as_string::<i32>().value(0), "hello");
832    }
833
834    #[test]
835    fn test_tab_delimiter() {
836        let schema = make_schema(vec![
837            ("a", DataType::Int64, false),
838            ("b", DataType::Utf8, false),
839        ]);
840        let config = CsvDecoderConfig {
841            delimiter: b'\t',
842            ..Default::default()
843        };
844        let decoder = CsvDecoder::with_config(schema, config);
845        let records = vec![csv_block("a\tb\n42\thello")];
846        let batch = decoder.decode_batch(&records).unwrap();
847
848        assert_eq!(
849            batch
850                .column(0)
851                .as_primitive::<arrow_array::types::Int64Type>()
852                .value(0),
853            42
854        );
855        assert_eq!(batch.column(1).as_string::<i32>().value(0), "hello");
856    }
857
858    #[test]
859    fn test_semicolon_delimiter() {
860        let schema = make_schema(vec![
861            ("a", DataType::Int64, false),
862            ("b", DataType::Utf8, false),
863        ]);
864        let config = CsvDecoderConfig {
865            delimiter: b';',
866            ..Default::default()
867        };
868        let decoder = CsvDecoder::with_config(schema, config);
869        let records = vec![csv_block("a;b\n99;world")];
870        let batch = decoder.decode_batch(&records).unwrap();
871
872        assert_eq!(
873            batch
874                .column(0)
875                .as_primitive::<arrow_array::types::Int64Type>()
876                .value(0),
877            99
878        );
879        assert_eq!(batch.column(1).as_string::<i32>().value(0), "world");
880    }
881
882    // ── Comment lines ─────────────────────────────────────────
883
884    #[test]
885    fn test_comment_lines() {
886        let schema = make_schema(vec![("val", DataType::Int64, false)]);
887        let config = CsvDecoderConfig {
888            comment: Some(b'#'),
889            ..Default::default()
890        };
891        let decoder = CsvDecoder::with_config(schema, config);
892        let records = vec![csv_block("val\n# this is a comment\n42\n# another\n99")];
893        let batch = decoder.decode_batch(&records).unwrap();
894
895        assert_eq!(batch.num_rows(), 2);
896        let col = batch
897            .column(0)
898            .as_primitive::<arrow_array::types::Int64Type>();
899        assert_eq!(col.value(0), 42);
900        assert_eq!(col.value(1), 99);
901    }
902
903    // ── Skip rows ─────────────────────────────────────────────
904
905    #[test]
906    fn test_skip_rows() {
907        let schema = make_schema(vec![("val", DataType::Int64, false)]);
908        let config = CsvDecoderConfig {
909            skip_rows: 2,
910            ..Default::default()
911        };
912        let decoder = CsvDecoder::with_config(schema, config);
913        let records = vec![csv_block("val\nskip1\nskip2\n42\n99")];
914        let batch = decoder.decode_batch(&records).unwrap();
915
916        // "skip1" and "skip2" are skipped (parse errors counted), then 42 and 99.
917        // Actually skip_rows skips first N data rows; skip1/skip2 aren't valid i64
918        // so they'd be parse errors. But the skip_rows logic skips before type
919        // coercion, so they won't generate errors.
920        assert_eq!(batch.num_rows(), 2);
921    }
922
923    // ── No header mode ────────────────────────────────────────
924
925    #[test]
926    fn test_no_header() {
927        let schema = make_schema(vec![
928            ("col0", DataType::Int64, false),
929            ("col1", DataType::Utf8, false),
930        ]);
931        let config = CsvDecoderConfig {
932            has_header: false,
933            ..Default::default()
934        };
935        let decoder = CsvDecoder::with_config(schema, config);
936        let records = vec![csv_block("1,alpha\n2,beta")];
937        let batch = decoder.decode_batch(&records).unwrap();
938
939        assert_eq!(batch.num_rows(), 2);
940        let col0 = batch
941            .column(0)
942            .as_primitive::<arrow_array::types::Int64Type>();
943        assert_eq!(col0.value(0), 1);
944        assert_eq!(col0.value(1), 2);
945    }
946
947    // ── Multiple records (streaming) ──────────────────────────
948
949    #[test]
950    fn test_multiple_raw_records() {
951        // Simulate streaming: each RawRecord is one CSV line (no header).
952        let schema = make_schema(vec![
953            ("id", DataType::Int64, false),
954            ("val", DataType::Float64, false),
955        ]);
956        let config = CsvDecoderConfig {
957            has_header: false,
958            ..Default::default()
959        };
960        let decoder = CsvDecoder::with_config(schema, config);
961        let records = vec![
962            csv_record("1,1.5"),
963            csv_record("2,2.5"),
964            csv_record("3,3.5"),
965        ];
966        let batch = decoder.decode_batch(&records).unwrap();
967
968        assert_eq!(batch.num_rows(), 3);
969        let id_col = batch
970            .column(0)
971            .as_primitive::<arrow_array::types::Int64Type>();
972        let val_col = batch
973            .column(1)
974            .as_primitive::<arrow_array::types::Float64Type>();
975        assert_eq!(id_col.value(0), 1);
976        assert_eq!(id_col.value(2), 3);
977        assert!((val_col.value(1) - 2.5).abs() < f64::EPSILON);
978    }
979
980    // ── Quoted fields ─────────────────────────────────────────
981
982    #[test]
983    fn test_quoted_fields_with_delimiter() {
984        let schema = make_schema(vec![
985            ("name", DataType::Utf8, false),
986            ("desc", DataType::Utf8, false),
987        ]);
988        let decoder = CsvDecoder::new(schema);
989        let records = vec![csv_block("name,desc\n\"Smith, John\",\"A, B\"")];
990        let batch = decoder.decode_batch(&records).unwrap();
991
992        assert_eq!(batch.num_rows(), 1);
993        assert_eq!(batch.column(0).as_string::<i32>().value(0), "Smith, John");
994        assert_eq!(batch.column(1).as_string::<i32>().value(0), "A, B");
995    }
996
997    #[test]
998    fn test_quoted_fields_with_newline() {
999        let schema = make_schema(vec![
1000            ("id", DataType::Int64, false),
1001            ("text", DataType::Utf8, false),
1002        ]);
1003        let decoder = CsvDecoder::new(schema);
1004        let records = vec![csv_block("id,text\n1,\"line1\nline2\"")];
1005        let batch = decoder.decode_batch(&records).unwrap();
1006
1007        assert_eq!(batch.num_rows(), 1);
1008        assert_eq!(batch.column(1).as_string::<i32>().value(0), "line1\nline2");
1009    }
1010
1011    #[test]
1012    fn test_escaped_quotes_rfc4180() {
1013        // RFC 4180: doubled quotes within quoted field.
1014        let schema = make_schema(vec![("val", DataType::Utf8, false)]);
1015        let decoder = CsvDecoder::new(schema);
1016        let records = vec![csv_block("val\n\"She said \"\"hello\"\"\"")];
1017        let batch = decoder.decode_batch(&records).unwrap();
1018
1019        assert_eq!(batch.num_rows(), 1);
1020        assert_eq!(
1021            batch.column(0).as_string::<i32>().value(0),
1022            "She said \"hello\""
1023        );
1024    }
1025
1026    // ── Timestamp parsing ─────────────────────────────────────
1027
1028    #[test]
1029    fn test_decode_timestamp() {
1030        let schema = make_schema(vec![(
1031            "ts",
1032            DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
1033            false,
1034        )]);
1035        let decoder = CsvDecoder::new(schema);
1036        let records = vec![csv_block("ts\n2025-01-15 10:30:00.000")];
1037        let batch = decoder.decode_batch(&records).unwrap();
1038
1039        assert_eq!(batch.num_rows(), 1);
1040        assert!(!batch.column(0).is_null(0));
1041    }
1042
1043    #[test]
1044    fn test_decode_timestamp_iso8601_fallback() {
1045        let schema = make_schema(vec![(
1046            "ts",
1047            DataType::Timestamp(TimeUnit::Nanosecond, None),
1048            false,
1049        )]);
1050        let decoder = CsvDecoder::new(schema);
1051        let records = vec![csv_block("ts\n2025-01-15T10:30:00Z")];
1052        let batch = decoder.decode_batch(&records).unwrap();
1053
1054        assert_eq!(batch.num_rows(), 1);
1055        assert!(!batch.column(0).is_null(0));
1056    }
1057
1058    // ── Date parsing ──────────────────────────────────────────
1059
1060    #[test]
1061    fn test_decode_date() {
1062        let schema = make_schema(vec![("d", DataType::Date32, false)]);
1063        let decoder = CsvDecoder::new(schema);
1064        let records = vec![csv_block("d\n2025-06-15")];
1065        let batch = decoder.decode_batch(&records).unwrap();
1066
1067        assert_eq!(batch.num_rows(), 1);
1068        assert!(!batch.column(0).is_null(0));
1069        // 2025-06-15 is day 20254 since epoch.
1070        let days = batch
1071            .column(0)
1072            .as_primitive::<arrow_array::types::Date32Type>()
1073            .value(0);
1074        let expected = chrono::NaiveDate::from_ymd_opt(2025, 6, 15)
1075            .unwrap()
1076            .signed_duration_since(chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap())
1077            .num_days();
1078        #[allow(clippy::cast_possible_truncation)]
1079        {
1080            assert_eq!(days, expected as i32);
1081        }
1082    }
1083
1084    // ── Boolean parsing ───────────────────────────────────────
1085
1086    #[test]
1087    fn test_decode_boolean_variants() {
1088        let schema = make_schema(vec![("b", DataType::Boolean, false)]);
1089        let config = CsvDecoderConfig {
1090            has_header: false,
1091            ..Default::default()
1092        };
1093        let decoder = CsvDecoder::with_config(schema, config);
1094        let records = vec![csv_block("true\nfalse\n1\n0\nyes\nno\nt\nf\ny\nn")];
1095        let batch = decoder.decode_batch(&records).unwrap();
1096
1097        assert_eq!(batch.num_rows(), 10);
1098        let col = batch.column(0).as_boolean();
1099        assert!(col.value(0)); // true
1100        assert!(!col.value(1)); // false
1101        assert!(col.value(2)); // 1
1102        assert!(!col.value(3)); // 0
1103        assert!(col.value(4)); // yes
1104        assert!(!col.value(5)); // no
1105        assert!(col.value(6)); // t
1106        assert!(!col.value(7)); // f
1107        assert!(col.value(8)); // y
1108        assert!(!col.value(9)); // n
1109    }
1110
1111    // ── Parse error counting ──────────────────────────────────
1112
1113    #[test]
1114    fn test_parse_error_count() {
1115        let schema = make_schema(vec![("val", DataType::Int64, true)]);
1116        let decoder = CsvDecoder::new(schema);
1117        let records = vec![csv_block("val\nnot_a_number\n42\nalso_bad")];
1118        let batch = decoder.decode_batch(&records).unwrap();
1119
1120        assert_eq!(batch.num_rows(), 3);
1121        assert!(batch.column(0).is_null(0));
1122        assert_eq!(
1123            batch
1124                .column(0)
1125                .as_primitive::<arrow_array::types::Int64Type>()
1126                .value(1),
1127            42
1128        );
1129        assert!(batch.column(0).is_null(2));
1130        assert_eq!(decoder.parse_error_count(), 2);
1131    }
1132
1133    // ── Extra fields ignored ──────────────────────────────────
1134
1135    #[test]
1136    fn test_extra_fields_truncated() {
1137        let schema = make_schema(vec![("a", DataType::Int64, false)]);
1138        let decoder = CsvDecoder::new(schema);
1139        // Row has 3 fields but schema only has 1.
1140        let records = vec![csv_block("a\n42,extra1,extra2")];
1141        let batch = decoder.decode_batch(&records).unwrap();
1142
1143        // Extra fields silently ignored (flexible mode).
1144        // field_count (3) != num_fields (1), but Null strategy just pads/truncates.
1145        assert_eq!(batch.num_rows(), 1);
1146        assert_eq!(
1147            batch
1148                .column(0)
1149                .as_primitive::<arrow_array::types::Int64Type>()
1150                .value(0),
1151            42
1152        );
1153    }
1154
1155    // ── FormatDecoder trait ───────────────────────────────────
1156
1157    #[test]
1158    fn test_format_name() {
1159        let schema = make_schema(vec![("a", DataType::Int64, false)]);
1160        let decoder = CsvDecoder::new(schema);
1161        assert_eq!(decoder.format_name(), "csv");
1162    }
1163
1164    #[test]
1165    fn test_output_schema() {
1166        let schema = make_schema(vec![
1167            ("a", DataType::Int64, false),
1168            ("b", DataType::Utf8, true),
1169        ]);
1170        let decoder = CsvDecoder::new(schema.clone());
1171        assert_eq!(decoder.output_schema(), schema);
1172    }
1173
1174    #[test]
1175    fn test_decode_one() {
1176        let schema = make_schema(vec![("x", DataType::Int64, false)]);
1177        let config = CsvDecoderConfig {
1178            has_header: false,
1179            ..Default::default()
1180        };
1181        let decoder = CsvDecoder::with_config(schema, config);
1182        let record = csv_record("99");
1183        let batch = decoder.decode_one(&record).unwrap();
1184        assert_eq!(batch.num_rows(), 1);
1185        assert_eq!(
1186            batch
1187                .column(0)
1188                .as_primitive::<arrow_array::types::Int64Type>()
1189                .value(0),
1190            99
1191        );
1192    }
1193
1194    // ── Edge cases ────────────────────────────────────────────
1195
1196    #[test]
1197    fn test_mixed_line_endings() {
1198        let schema = make_schema(vec![("val", DataType::Int64, false)]);
1199        let config = CsvDecoderConfig {
1200            has_header: false,
1201            ..Default::default()
1202        };
1203        let decoder = CsvDecoder::with_config(schema, config);
1204        let records = vec![csv_block("1\r\n2\n3\r\n")];
1205        let batch = decoder.decode_batch(&records).unwrap();
1206        assert_eq!(batch.num_rows(), 3);
1207    }
1208
1209    #[test]
1210    fn test_unicode_values() {
1211        let schema = make_schema(vec![("name", DataType::Utf8, false)]);
1212        let decoder = CsvDecoder::new(schema);
1213        let records = vec![csv_block("name\nこんにちは\nüber\nnaïve")];
1214        let batch = decoder.decode_batch(&records).unwrap();
1215
1216        assert_eq!(batch.num_rows(), 3);
1217        assert_eq!(batch.column(0).as_string::<i32>().value(0), "こんにちは");
1218        assert_eq!(batch.column(0).as_string::<i32>().value(1), "über");
1219        assert_eq!(batch.column(0).as_string::<i32>().value(2), "naïve");
1220    }
1221
1222    #[test]
1223    fn test_trailing_comma() {
1224        // Trailing comma creates an extra empty field.
1225        let schema = make_schema(vec![
1226            ("a", DataType::Int64, false),
1227            ("b", DataType::Int64, true),
1228        ]);
1229        let decoder = CsvDecoder::new(schema);
1230        let records = vec![csv_block("a,b\n1,")];
1231        let batch = decoder.decode_batch(&records).unwrap();
1232
1233        assert_eq!(batch.num_rows(), 1);
1234        assert_eq!(
1235            batch
1236                .column(0)
1237                .as_primitive::<arrow_array::types::Int64Type>()
1238                .value(0),
1239            1
1240        );
1241        // Empty string matches default null_string → null.
1242        assert!(batch.column(1).is_null(0));
1243    }
1244
1245    #[test]
1246    fn test_backslash_escape() {
1247        let schema = make_schema(vec![("val", DataType::Utf8, false)]);
1248        let config = CsvDecoderConfig {
1249            escape: Some(b'\\'),
1250            ..Default::default()
1251        };
1252        let decoder = CsvDecoder::with_config(schema, config);
1253        let records = vec![csv_block("val\n\"hello \\\"world\\\"\"")];
1254        let batch = decoder.decode_batch(&records).unwrap();
1255
1256        assert_eq!(batch.num_rows(), 1);
1257        assert_eq!(
1258            batch.column(0).as_string::<i32>().value(0),
1259            "hello \"world\""
1260        );
1261    }
1262
1263    // ── CsvEncoder tests ──────────────────────────────────────
1264
1265    #[test]
1266    fn test_csv_encode_basic() {
1267        let schema = make_schema(vec![
1268            ("id", DataType::Int64, false),
1269            ("name", DataType::Utf8, false),
1270        ]);
1271
1272        let batch = RecordBatch::try_new(
1273            schema.clone(),
1274            vec![
1275                Arc::new(arrow_array::Int64Array::from(vec![1, 2])),
1276                Arc::new(arrow_array::StringArray::from(vec!["Alice", "Bob"])),
1277            ],
1278        )
1279        .unwrap();
1280
1281        let encoder = CsvEncoder::new(schema);
1282        let records = encoder.encode_batch(&batch).unwrap();
1283
1284        assert_eq!(records.len(), 2);
1285        assert_eq!(std::str::from_utf8(&records[0]).unwrap(), "1,Alice");
1286        assert_eq!(std::str::from_utf8(&records[1]).unwrap(), "2,Bob");
1287    }
1288
1289    #[test]
1290    fn test_csv_encode_with_header() {
1291        let schema = make_schema(vec![
1292            ("id", DataType::Int64, false),
1293            ("name", DataType::Utf8, false),
1294        ]);
1295
1296        let batch = RecordBatch::try_new(
1297            schema.clone(),
1298            vec![
1299                Arc::new(arrow_array::Int64Array::from(vec![1])),
1300                Arc::new(arrow_array::StringArray::from(vec!["Alice"])),
1301            ],
1302        )
1303        .unwrap();
1304
1305        let config = CsvEncoderConfig {
1306            has_header: true,
1307            ..Default::default()
1308        };
1309        let encoder = CsvEncoder::with_config(schema, config);
1310        let records = encoder.encode_batch(&batch).unwrap();
1311
1312        assert_eq!(records.len(), 2); // header + 1 data row
1313        assert_eq!(std::str::from_utf8(&records[0]).unwrap(), "id,name");
1314        assert_eq!(std::str::from_utf8(&records[1]).unwrap(), "1,Alice");
1315    }
1316
1317    #[test]
1318    fn test_csv_encode_tab_delimiter() {
1319        let schema = make_schema(vec![
1320            ("a", DataType::Int64, false),
1321            ("b", DataType::Utf8, false),
1322        ]);
1323
1324        let batch = RecordBatch::try_new(
1325            schema.clone(),
1326            vec![
1327                Arc::new(arrow_array::Int64Array::from(vec![42])),
1328                Arc::new(arrow_array::StringArray::from(vec!["hello"])),
1329            ],
1330        )
1331        .unwrap();
1332
1333        let config = CsvEncoderConfig {
1334            delimiter: b'\t',
1335            ..Default::default()
1336        };
1337        let encoder = CsvEncoder::with_config(schema, config);
1338        let records = encoder.encode_batch(&batch).unwrap();
1339
1340        assert_eq!(records.len(), 1);
1341        assert_eq!(std::str::from_utf8(&records[0]).unwrap(), "42\thello");
1342    }
1343
1344    #[test]
1345    fn test_csv_encode_empty_batch() {
1346        let schema = make_schema(vec![("x", DataType::Int64, false)]);
1347        let batch = RecordBatch::new_empty(schema.clone());
1348        let encoder = CsvEncoder::new(schema);
1349        let records = encoder.encode_batch(&batch).unwrap();
1350        assert!(records.is_empty());
1351    }
1352
1353    #[test]
1354    fn test_csv_encode_nulls() {
1355        let schema = make_schema(vec![
1356            ("id", DataType::Int64, false),
1357            ("value", DataType::Int64, true),
1358        ]);
1359
1360        let batch = RecordBatch::try_new(
1361            schema.clone(),
1362            vec![
1363                Arc::new(arrow_array::Int64Array::from(vec![1, 2])),
1364                Arc::new(arrow_array::Int64Array::from(vec![Some(10), None])),
1365            ],
1366        )
1367        .unwrap();
1368
1369        let encoder = CsvEncoder::new(schema);
1370        let records = encoder.encode_batch(&batch).unwrap();
1371
1372        assert_eq!(records.len(), 2);
1373        assert_eq!(std::str::from_utf8(&records[0]).unwrap(), "1,10");
1374        assert_eq!(std::str::from_utf8(&records[1]).unwrap(), "2,");
1375    }
1376}