Skip to main content

laminar_connectors/schema/
inference.rs

1//! Format inference registry and built-in format inferencers.
2//!
3//! Provides:
4//! - [`FormatInference`] trait for format-specific schema inference
5//! - [`FormatInferenceRegistry`] for registering and looking up inferencers
6//! - Built-in implementations for JSON, CSV, and raw formats
7//! - [`default_infer_from_samples`] free function used by the default
8//!   [`SchemaInferable`](super::traits::SchemaInferable) implementation
9use std::collections::HashMap;
10use std::sync::{Arc, LazyLock, RwLock};
11
12use arrow_schema::{DataType, Field, Schema};
13
14use super::error::{SchemaError, SchemaResult};
15use super::traits::{
16    FieldInferenceDetail, InferenceConfig, InferenceWarning, InferredSchema, WarningSeverity,
17};
18use super::types::RawRecord;
19
20/// Trait for format-specific schema inference.
21///
22/// Each implementation handles a single data format (JSON, CSV, etc.)
23/// and can infer an Arrow schema from sample records.
24pub trait FormatInference: Send + Sync {
25    /// Returns the format name this inferencer handles (e.g., `"json"`).
26    fn format_name(&self) -> &'static str;
27
28    /// Infers a schema from sample raw records.
29    ///
30    /// # Errors
31    ///
32    /// Returns [`SchemaError::InferenceFailed`] if the samples cannot be
33    /// parsed or the schema cannot be determined.
34    fn infer(
35        &self,
36        samples: &[RawRecord],
37        config: &InferenceConfig,
38    ) -> SchemaResult<InferredSchema>;
39}
40
41/// Registry of format inferencers.
42///
43/// Thread-safe registry that maps format names to [`FormatInference`]
44/// implementations. A global instance is available via
45/// [`FORMAT_INFERENCE_REGISTRY`].
46pub struct FormatInferenceRegistry {
47    inferencers: RwLock<HashMap<String, Arc<dyn FormatInference>>>,
48}
49
50impl FormatInferenceRegistry {
51    /// Creates a new empty registry.
52    #[must_use]
53    pub fn new() -> Self {
54        Self {
55            inferencers: RwLock::new(HashMap::new()),
56        }
57    }
58
59    /// Registers a format inferencer.
60    ///
61    /// If an inferencer for the same format already exists, it is replaced.
62    ///
63    /// # Panics
64    ///
65    /// Panics if the internal lock is poisoned.
66    pub fn register(&self, inferencer: Arc<dyn FormatInference>) {
67        let name = inferencer.format_name().to_string();
68        self.inferencers
69            .write()
70            .unwrap_or_else(std::sync::PoisonError::into_inner)
71            .insert(name, inferencer);
72    }
73
74    /// Gets the inferencer for a format, if registered.
75    #[must_use]
76    pub fn get(&self, format: &str) -> Option<Arc<dyn FormatInference>> {
77        self.inferencers
78            .read()
79            .unwrap_or_else(std::sync::PoisonError::into_inner)
80            .get(format)
81            .cloned()
82    }
83
84    /// Returns the names of all registered formats.
85    #[must_use]
86    pub fn registered_formats(&self) -> Vec<String> {
87        self.inferencers
88            .read()
89            .unwrap_or_else(std::sync::PoisonError::into_inner)
90            .keys()
91            .cloned()
92            .collect()
93    }
94}
95
96impl Default for FormatInferenceRegistry {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102/// Global format inference registry, pre-populated with JSON, CSV, and raw.
103pub static FORMAT_INFERENCE_REGISTRY: LazyLock<FormatInferenceRegistry> = LazyLock::new(|| {
104    let registry = FormatInferenceRegistry::new();
105    registry.register(Arc::new(JsonFormatInference));
106    registry.register(Arc::new(CsvFormatInference));
107    registry.register(Arc::new(RawFormatInference));
108    registry
109});
110
111/// Default inference implementation that delegates to the global registry.
112///
113/// Used by the default [`SchemaInferable::infer_from_samples`](super::traits::SchemaInferable::infer_from_samples).
114///
115/// # Errors
116///
117/// Returns [`SchemaError::InferenceFailed`] if no inferencer is registered
118/// for the requested format, or if inference itself fails.
119pub fn default_infer_from_samples(
120    samples: &[RawRecord],
121    config: &InferenceConfig,
122) -> SchemaResult<InferredSchema> {
123    let inferencer = FORMAT_INFERENCE_REGISTRY
124        .get(&config.format)
125        .ok_or_else(|| {
126            SchemaError::InferenceFailed(format!(
127                "no inferencer registered for format '{}'",
128                config.format
129            ))
130        })?;
131    inferencer.infer(samples, config)
132}
133
134// ── JSON inference ─────────────────────────────────────────────────
135
136/// JSON format schema inferencer.
137///
138/// Parses each record's value as a JSON object and merges field types
139/// across all samples. Ported from `sdk::schema::infer_schema_from_json`
140/// with added confidence scoring.
141pub struct JsonFormatInference;
142
143#[allow(clippy::cast_precision_loss)]
144impl FormatInference for JsonFormatInference {
145    fn format_name(&self) -> &'static str {
146        "json"
147    }
148
149    fn infer(
150        &self,
151        samples: &[RawRecord],
152        config: &InferenceConfig,
153    ) -> SchemaResult<InferredSchema> {
154        if samples.is_empty() {
155            return Err(SchemaError::InferenceFailed(
156                "cannot infer schema from zero samples".into(),
157            ));
158        }
159
160        let limit = samples.len().min(config.max_samples);
161        let samples = &samples[..limit];
162
163        let mut field_types: HashMap<String, Vec<InferredType>> = HashMap::new();
164        let mut field_order = Vec::new();
165        let mut warnings = Vec::new();
166
167        for (i, record) in samples.iter().enumerate() {
168            let value: serde_json::Value = serde_json::from_slice(&record.value).map_err(|e| {
169                SchemaError::InferenceFailed(format!("JSON parse error in sample {i}: {e}"))
170            })?;
171
172            let obj = value.as_object().ok_or_else(|| {
173                SchemaError::InferenceFailed(format!(
174                    "sample {i}: expected JSON object, got {}",
175                    json_type_name(&value)
176                ))
177            })?;
178
179            for (key, val) in obj {
180                if !field_types.contains_key(key) {
181                    field_order.push(key.clone());
182                    field_types.insert(key.clone(), Vec::with_capacity(samples.len()));
183                }
184                let inferred = infer_type_from_json(val, config.empty_as_null);
185                field_types.get_mut(key).unwrap().push(inferred);
186            }
187        }
188
189        let total = samples.len();
190        let mut fields = Vec::new();
191        let mut details = Vec::new();
192
193        for name in &field_order {
194            let types = &field_types[name];
195
196            // Check for a type hint.
197            let (data_type, hint_applied) = if let Some(hint) = config.type_hints.get(name) {
198                (hint.clone(), true)
199            } else {
200                (merge_types(types), false)
201            };
202
203            let non_null_count = types
204                .iter()
205                .filter(|t| !matches!(t, InferredType::Null))
206                .count();
207            let nullable =
208                types.iter().any(|t| matches!(t, InferredType::Null)) || types.len() < total;
209
210            let field_confidence = if types.is_empty() {
211                0.0
212            } else {
213                let consistent_count = types
214                    .iter()
215                    .filter(|t| !matches!(t, InferredType::Null))
216                    .filter(|t| inferred_to_arrow(t) == data_type)
217                    .count();
218                if non_null_count == 0 {
219                    0.5 // all nulls — low confidence
220                } else {
221                    consistent_count as f64 / non_null_count as f64
222                }
223            };
224
225            if field_confidence < config.min_confidence && !hint_applied {
226                warnings.push(InferenceWarning {
227                    field: Some(name.clone()),
228                    message: format!(
229                        "low confidence {field_confidence:.2} for field '{name}', \
230                         falling back to Utf8"
231                    ),
232                    severity: WarningSeverity::Warning,
233                });
234            }
235
236            fields.push(Field::new(name, data_type.clone(), nullable));
237            details.push(FieldInferenceDetail {
238                field_name: name.clone(),
239                inferred_type: data_type,
240                confidence: field_confidence,
241                non_null_count,
242                total_count: types.len(),
243                hint_applied,
244            });
245        }
246
247        if fields.is_empty() {
248            return Err(SchemaError::InferenceFailed(
249                "no fields could be inferred from JSON samples".into(),
250            ));
251        }
252
253        let overall_confidence = if details.is_empty() {
254            0.0
255        } else {
256            details.iter().map(|d| d.confidence).sum::<f64>() / details.len() as f64
257        };
258
259        Ok(InferredSchema {
260            schema: Arc::new(Schema::new(fields)),
261            confidence: overall_confidence,
262            sample_count: total,
263            field_details: details,
264            warnings,
265        })
266    }
267}
268
269// ── CSV inference ──────────────────────────────────────────────────
270
271/// CSV format schema inferencer.
272///
273/// Treats the first sample's value as a header row and subsequent samples
274/// as data rows. Ported from `sdk::schema::infer_schema_from_csv` with
275/// added confidence scoring.
276pub struct CsvFormatInference;
277
278#[allow(clippy::cast_precision_loss)]
279impl FormatInference for CsvFormatInference {
280    fn format_name(&self) -> &'static str {
281        "csv"
282    }
283
284    fn infer(
285        &self,
286        samples: &[RawRecord],
287        config: &InferenceConfig,
288    ) -> SchemaResult<InferredSchema> {
289        if samples.is_empty() {
290            return Err(SchemaError::InferenceFailed(
291                "cannot infer schema from zero samples".into(),
292            ));
293        }
294
295        let limit = samples.len().min(config.max_samples);
296        let samples = &samples[..limit];
297
298        // First sample contains headers.
299        let header_line = std::str::from_utf8(&samples[0].value).map_err(|e| {
300            SchemaError::InferenceFailed(format!("invalid UTF-8 in CSV header: {e}"))
301        })?;
302
303        let headers: Vec<String> = header_line
304            .split(',')
305            .map(|s| s.trim().to_string())
306            .collect();
307
308        let mut field_types: HashMap<String, Vec<InferredType>> = HashMap::new();
309        for header in &headers {
310            field_types.insert(header.clone(), Vec::with_capacity(samples.len()));
311        }
312
313        let mut warnings: Vec<InferenceWarning> = Vec::new();
314
315        for record in samples.iter().skip(1) {
316            let line = std::str::from_utf8(&record.value).map_err(|e| {
317                SchemaError::InferenceFailed(format!("invalid UTF-8 in CSV row: {e}"))
318            })?;
319
320            let values: Vec<&str> = line.split(',').map(str::trim).collect();
321
322            if values.len() != headers.len() {
323                warnings.push(InferenceWarning {
324                    field: None,
325                    message: format!(
326                        "column count mismatch: expected {}, got {}",
327                        headers.len(),
328                        values.len()
329                    ),
330                    severity: WarningSeverity::Warning,
331                });
332            }
333
334            for (i, value) in values.iter().enumerate() {
335                if let Some(header) = headers.get(i) {
336                    let inferred = infer_type_from_string(value, config.empty_as_null);
337                    if let Some(types) = field_types.get_mut(header) {
338                        types.push(inferred);
339                    }
340                }
341            }
342        }
343
344        let data_rows = samples.len().saturating_sub(1);
345        let mut fields: Vec<Field> = Vec::new();
346        let mut details: Vec<FieldInferenceDetail> = Vec::new();
347
348        for name in &headers {
349            let types = &field_types[name];
350
351            let (data_type, hint_applied) = if let Some(hint) = config.type_hints.get(name) {
352                (hint.clone(), true)
353            } else if types.is_empty() {
354                (DataType::Utf8, false) // no data rows — default to string
355            } else {
356                (merge_types(types), false)
357            };
358
359            let non_null_count = types
360                .iter()
361                .filter(|t| !matches!(t, InferredType::Null))
362                .count();
363            let nullable = types.iter().any(|t| matches!(t, InferredType::Null));
364
365            let field_confidence = if types.is_empty() {
366                0.5 // headers only
367            } else {
368                let consistent = types
369                    .iter()
370                    .filter(|t| !matches!(t, InferredType::Null))
371                    .filter(|t| inferred_to_arrow(t) == data_type)
372                    .count();
373                if non_null_count == 0 {
374                    0.5
375                } else {
376                    consistent as f64 / non_null_count as f64
377                }
378            };
379
380            fields.push(Field::new(name, data_type.clone(), nullable));
381            details.push(FieldInferenceDetail {
382                field_name: name.clone(),
383                inferred_type: data_type,
384                confidence: field_confidence,
385                non_null_count,
386                total_count: types.len(),
387                hint_applied,
388            });
389        }
390
391        let overall_confidence = if details.is_empty() {
392            0.0
393        } else {
394            details.iter().map(|d| d.confidence).sum::<f64>() / details.len() as f64
395        };
396
397        Ok(InferredSchema {
398            schema: Arc::new(Schema::new(fields)),
399            confidence: overall_confidence,
400            sample_count: data_rows,
401            field_details: details,
402            warnings,
403        })
404    }
405}
406
407// ── Raw inference ──────────────────────────────────────────────────
408
409/// Raw format inferencer.
410///
411/// Always returns a single `Binary` column named `"value"`.
412/// No actual inference is needed — the schema is fixed.
413pub struct RawFormatInference;
414
415impl FormatInference for RawFormatInference {
416    fn format_name(&self) -> &'static str {
417        "raw"
418    }
419
420    fn infer(
421        &self,
422        samples: &[RawRecord],
423        _config: &InferenceConfig,
424    ) -> SchemaResult<InferredSchema> {
425        let schema = Arc::new(Schema::new(vec![Field::new(
426            "value",
427            DataType::Binary,
428            false,
429        )]));
430
431        Ok(InferredSchema {
432            schema,
433            confidence: 1.0,
434            sample_count: samples.len(),
435            field_details: vec![FieldInferenceDetail {
436                field_name: "value".into(),
437                inferred_type: DataType::Binary,
438                confidence: 1.0,
439                non_null_count: samples.len(),
440                total_count: samples.len(),
441                hint_applied: false,
442            }],
443            warnings: vec![],
444        })
445    }
446}
447
448// ── Internal helpers ───────────────────────────────────────────────
449
/// Classification of a single observed value, before a column's observations
/// are merged into one Arrow data type.
// Unit-only enum: `Copy`/`Eq`/`Hash` are free and let it be passed by value
// or collected into sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum InferredType {
    /// A null, or an empty value when `empty_as_null` is enabled.
    Null,
    /// A boolean value.
    Bool,
    /// An integer number.
    Int,
    /// A number classified as floating point.
    Float,
    /// Text, or any value not classified as one of the other variants.
    String,
}
458
459fn json_type_name(value: &serde_json::Value) -> &'static str {
460    match value {
461        serde_json::Value::Null => "null",
462        serde_json::Value::Bool(_) => "boolean",
463        serde_json::Value::Number(_) => "number",
464        serde_json::Value::String(_) => "string",
465        serde_json::Value::Array(_) => "array",
466        serde_json::Value::Object(_) => "object",
467    }
468}
469
470fn infer_type_from_json(value: &serde_json::Value, empty_as_null: bool) -> InferredType {
471    match value {
472        serde_json::Value::Null => InferredType::Null,
473        serde_json::Value::Bool(_) => InferredType::Bool,
474        serde_json::Value::Number(n) => {
475            if n.is_f64() && !n.is_i64() && !n.is_u64() {
476                InferredType::Float
477            } else {
478                InferredType::Int
479            }
480        }
481        serde_json::Value::String(s) => {
482            if empty_as_null && s.is_empty() {
483                InferredType::Null
484            } else {
485                InferredType::String
486            }
487        }
488        // Arrays and objects are serialized as JSON strings.
489        serde_json::Value::Array(_) | serde_json::Value::Object(_) => InferredType::String,
490    }
491}
492
493fn infer_type_from_string(value: &str, empty_as_null: bool) -> InferredType {
494    if value.is_empty() {
495        return if empty_as_null {
496            InferredType::Null
497        } else {
498            InferredType::String
499        };
500    }
501
502    if value.eq_ignore_ascii_case("true") || value.eq_ignore_ascii_case("false") {
503        return InferredType::Bool;
504    }
505
506    if value.parse::<i64>().is_ok() {
507        return InferredType::Int;
508    }
509
510    if value.parse::<f64>().is_ok() {
511        return InferredType::Float;
512    }
513
514    InferredType::String
515}
516
517fn merge_types(types: &[InferredType]) -> DataType {
518    let non_null: Vec<_> = types
519        .iter()
520        .filter(|t| !matches!(t, InferredType::Null))
521        .collect();
522
523    if non_null.is_empty() {
524        return DataType::Utf8;
525    }
526
527    let first = non_null[0];
528    if non_null.iter().all(|t| *t == first) {
529        return inferred_to_arrow(first);
530    }
531
532    // Int + Float → Float.
533    let has_float = non_null.iter().any(|t| matches!(t, InferredType::Float));
534    let has_int = non_null.iter().any(|t| matches!(t, InferredType::Int));
535    let has_other = non_null
536        .iter()
537        .any(|t| !matches!(t, InferredType::Int | InferredType::Float));
538
539    if has_float && has_int && !has_other {
540        return DataType::Float64;
541    }
542
543    // Mixed types → Utf8.
544    DataType::Utf8
545}
546
547fn inferred_to_arrow(t: &InferredType) -> DataType {
548    match t {
549        InferredType::Bool => DataType::Boolean,
550        InferredType::Int => DataType::Int64,
551        InferredType::Float => DataType::Float64,
552        InferredType::Null | InferredType::String => DataType::Utf8,
553    }
554}
555
#[cfg(test)]
mod tests {
    // Unit tests for the registry, the three built-in inferencers, and the
    // private type-classification helpers.
    use super::*;

    // Wraps a JSON document's UTF-8 bytes in a RawRecord.
    fn json_record(json: &str) -> RawRecord {
        RawRecord::new(json.as_bytes().to_vec())
    }

    // Wraps one CSV line (header or data row) in a RawRecord.
    fn csv_record(line: &str) -> RawRecord {
        RawRecord::new(line.as_bytes().to_vec())
    }

    // ── Registry tests ─────────────────────────────────────────

    // The global registry is seeded with all three built-in formats.
    #[test]
    fn test_registry_has_builtins() {
        let formats = FORMAT_INFERENCE_REGISTRY.registered_formats();
        assert!(formats.contains(&"json".to_string()));
        assert!(formats.contains(&"csv".to_string()));
        assert!(formats.contains(&"raw".to_string()));
    }

    // Lookup returns the inferencer registered under that name.
    #[test]
    fn test_registry_get_json() {
        let inf = FORMAT_INFERENCE_REGISTRY.get("json");
        assert!(inf.is_some());
        assert_eq!(inf.unwrap().format_name(), "json");
    }

    // Unregistered formats yield None rather than an error.
    #[test]
    fn test_registry_get_unknown() {
        assert!(FORMAT_INFERENCE_REGISTRY.get("protobuf").is_none());
    }

    // The free-function entry point reports the missing format by name.
    #[test]
    fn test_default_infer_unknown_format() {
        let cfg = InferenceConfig::new("xml");
        let result = default_infer_from_samples(&[], &cfg);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("xml"));
    }

    // ── JSON inference tests ───────────────────────────────────

    // Homogeneous objects: one field per key, fully consistent types.
    #[test]
    fn test_json_infer_basic() {
        let samples = vec![
            json_record(r#"{"id": 1, "name": "Alice"}"#),
            json_record(r#"{"id": 2, "name": "Bob"}"#),
        ];

        let cfg = InferenceConfig::new("json");
        let result = JsonFormatInference.infer(&samples, &cfg).unwrap();

        assert_eq!(result.schema.fields().len(), 2);
        assert_eq!(
            result.schema.field_with_name("id").unwrap().data_type(),
            &DataType::Int64
        );
        assert_eq!(
            result.schema.field_with_name("name").unwrap().data_type(),
            &DataType::Utf8
        );
        assert_eq!(result.sample_count, 2);
        assert!(result.confidence > 0.9);
    }

    // Each JSON scalar kind maps to its Arrow counterpart.
    #[test]
    fn test_json_infer_types() {
        let samples = vec![json_record(
            r#"{"int": 42, "float": 3.14, "bool": true, "str": "hello"}"#,
        )];

        let cfg = InferenceConfig::new("json");
        let result = JsonFormatInference.infer(&samples, &cfg).unwrap();

        assert_eq!(
            result.schema.field_with_name("int").unwrap().data_type(),
            &DataType::Int64
        );
        assert_eq!(
            result.schema.field_with_name("float").unwrap().data_type(),
            &DataType::Float64
        );
        assert_eq!(
            result.schema.field_with_name("bool").unwrap().data_type(),
            &DataType::Boolean
        );
        assert_eq!(
            result.schema.field_with_name("str").unwrap().data_type(),
            &DataType::Utf8
        );
    }

    // An explicit JSON null makes the field nullable.
    #[test]
    fn test_json_infer_nullable() {
        let samples = vec![
            json_record(r#"{"value": 1}"#),
            json_record(r#"{"value": null}"#),
        ];

        let cfg = InferenceConfig::new("json");
        let result = JsonFormatInference.infer(&samples, &cfg).unwrap();
        assert!(result
            .schema
            .field_with_name("value")
            .unwrap()
            .is_nullable());
    }

    // Int and Float observations for one field widen to Float64.
    #[test]
    fn test_json_infer_mixed_int_float() {
        let samples = vec![
            json_record(r#"{"value": 1}"#),
            json_record(r#"{"value": 2.5}"#),
        ];

        let cfg = InferenceConfig::new("json");
        let result = JsonFormatInference.infer(&samples, &cfg).unwrap();
        assert_eq!(
            result.schema.field_with_name("value").unwrap().data_type(),
            &DataType::Float64
        );
    }

    // A type hint overrides the inferred type and is recorded in the details.
    #[test]
    fn test_json_infer_type_hint() {
        let samples = vec![json_record(r#"{"id": 42}"#)];
        let cfg = InferenceConfig::new("json").with_type_hint("id", DataType::Int32);
        let result = JsonFormatInference.infer(&samples, &cfg).unwrap();
        assert_eq!(
            result.schema.field_with_name("id").unwrap().data_type(),
            &DataType::Int32
        );
        assert!(result.field_details[0].hint_applied);
    }

    // No samples is an error, not an empty schema.
    #[test]
    fn test_json_infer_empty_error() {
        let cfg = InferenceConfig::new("json");
        let result = JsonFormatInference.infer(&[], &cfg);
        assert!(result.is_err());
    }

    // With `empty_as_null`, an empty string counts as a null.
    #[test]
    fn test_json_infer_empty_as_null() {
        let samples = vec![
            json_record(r#"{"value": ""}"#),
            json_record(r#"{"value": "text"}"#),
        ];
        let cfg = InferenceConfig::new("json").with_empty_as_null();
        let result = JsonFormatInference.infer(&samples, &cfg).unwrap();
        assert!(result
            .schema
            .field_with_name("value")
            .unwrap()
            .is_nullable());
    }

    // Fully consistent fields get confidence 1.0 and correct counts.
    #[test]
    fn test_json_infer_confidence_details() {
        let samples = vec![
            json_record(r#"{"a": 1, "b": "x"}"#),
            json_record(r#"{"a": 2, "b": "y"}"#),
            json_record(r#"{"a": 3, "b": "z"}"#),
        ];

        let cfg = InferenceConfig::new("json");
        let result = JsonFormatInference.infer(&samples, &cfg).unwrap();

        assert_eq!(result.field_details.len(), 2);
        for detail in &result.field_details {
            assert!((detail.confidence - 1.0).abs() < f64::EPSILON);
            assert_eq!(detail.non_null_count, 3);
            assert_eq!(detail.total_count, 3);
        }
    }

    // A key absent from some samples is treated as nullable.
    #[test]
    fn test_json_infer_missing_field_nullable() {
        let samples = vec![
            json_record(r#"{"a": 1, "b": 2}"#),
            json_record(r#"{"a": 3}"#), // "b" missing
        ];

        let cfg = InferenceConfig::new("json");
        let result = JsonFormatInference.infer(&samples, &cfg).unwrap();
        // "b" is only seen once out of 2 samples → nullable
        assert!(result.schema.field_with_name("b").unwrap().is_nullable());
    }

    // ── CSV inference tests ────────────────────────────────────

    // First record is the header; the rest are data rows.
    #[test]
    fn test_csv_infer_basic() {
        let samples = vec![
            csv_record("id,name,age"),
            csv_record("1,Alice,30"),
            csv_record("2,Bob,25"),
        ];

        let cfg = InferenceConfig::new("csv");
        let result = CsvFormatInference.infer(&samples, &cfg).unwrap();

        assert_eq!(result.schema.fields().len(), 3);
        assert_eq!(
            result.schema.field_with_name("id").unwrap().data_type(),
            &DataType::Int64
        );
        assert_eq!(
            result.schema.field_with_name("name").unwrap().data_type(),
            &DataType::Utf8
        );
        assert_eq!(
            result.schema.field_with_name("age").unwrap().data_type(),
            &DataType::Int64
        );
        assert_eq!(result.sample_count, 2); // 2 data rows
    }

    // Cell text is classified as int, float, bool, or string.
    #[test]
    fn test_csv_infer_types() {
        let samples = vec![
            csv_record("int_col,float_col,bool_col,str_col"),
            csv_record("42,3.14,true,hello"),
            csv_record("100,2.71,false,world"),
        ];

        let cfg = InferenceConfig::new("csv");
        let result = CsvFormatInference.infer(&samples, &cfg).unwrap();

        assert_eq!(
            result
                .schema
                .field_with_name("int_col")
                .unwrap()
                .data_type(),
            &DataType::Int64
        );
        assert_eq!(
            result
                .schema
                .field_with_name("float_col")
                .unwrap()
                .data_type(),
            &DataType::Float64
        );
        assert_eq!(
            result
                .schema
                .field_with_name("bool_col")
                .unwrap()
                .data_type(),
            &DataType::Boolean
        );
        assert_eq!(
            result
                .schema
                .field_with_name("str_col")
                .unwrap()
                .data_type(),
            &DataType::Utf8
        );
    }

    // A header with no data rows still yields a schema, all Utf8.
    #[test]
    fn test_csv_infer_headers_only() {
        let samples = vec![csv_record("a,b,c")];

        let cfg = InferenceConfig::new("csv");
        let result = CsvFormatInference.infer(&samples, &cfg).unwrap();
        assert_eq!(result.schema.fields().len(), 3);
        // All default to Utf8 with no data.
        for field in result.schema.fields() {
            assert_eq!(field.data_type(), &DataType::Utf8);
        }
    }

    // No samples at all (not even a header) is an error.
    #[test]
    fn test_csv_infer_empty_error() {
        let cfg = InferenceConfig::new("csv");
        assert!(CsvFormatInference.infer(&[], &cfg).is_err());
    }

    // ── Raw inference tests ────────────────────────────────────

    // Raw always produces the fixed single Binary "value" column.
    #[test]
    fn test_raw_infer() {
        let samples = vec![
            RawRecord::new(b"hello".to_vec()),
            RawRecord::new(b"world".to_vec()),
        ];

        let cfg = InferenceConfig::new("raw");
        let result = RawFormatInference.infer(&samples, &cfg).unwrap();

        assert_eq!(result.schema.fields().len(), 1);
        assert_eq!(
            result.schema.field_with_name("value").unwrap().data_type(),
            &DataType::Binary
        );
        assert!((result.confidence - 1.0).abs() < f64::EPSILON);
        assert_eq!(result.sample_count, 2);
    }

    // ── default_infer_from_samples tests ───────────────────────

    // Dispatch through the global registry reaches the JSON inferencer.
    #[test]
    fn test_default_infer_json() {
        let samples = vec![json_record(r#"{"x": 1}"#)];
        let cfg = InferenceConfig::new("json");
        let result = default_infer_from_samples(&samples, &cfg).unwrap();
        assert_eq!(result.schema.fields().len(), 1);
    }

    // Dispatch through the global registry reaches the CSV inferencer.
    #[test]
    fn test_default_infer_csv() {
        let samples = vec![csv_record("col1"), csv_record("42")];
        let cfg = InferenceConfig::new("csv");
        let result = default_infer_from_samples(&samples, &cfg).unwrap();
        assert_eq!(result.schema.fields().len(), 1);
    }

    // ── Helper unit tests ──────────────────────────────────────

    // Uniform observations map straight to their Arrow type.
    #[test]
    fn test_merge_types_all_same() {
        let types = vec![InferredType::Int, InferredType::Int, InferredType::Int];
        assert_eq!(merge_types(&types), DataType::Int64);
    }

    // Int + Float widens to Float64.
    #[test]
    fn test_merge_types_int_float() {
        let types = vec![InferredType::Int, InferredType::Float];
        assert_eq!(merge_types(&types), DataType::Float64);
    }

    // Any other mix falls back to Utf8.
    #[test]
    fn test_merge_types_mixed() {
        let types = vec![InferredType::Int, InferredType::String];
        assert_eq!(merge_types(&types), DataType::Utf8);
    }

    // Only nulls gives no type evidence, so Utf8 is used.
    #[test]
    fn test_merge_types_all_null() {
        let types = vec![InferredType::Null, InferredType::Null];
        assert_eq!(merge_types(&types), DataType::Utf8);
    }

    // CSV cell classification, including the `empty_as_null` switch.
    #[test]
    fn test_infer_type_from_string_cases() {
        assert_eq!(infer_type_from_string("42", false), InferredType::Int);
        assert_eq!(infer_type_from_string("3.14", false), InferredType::Float);
        assert_eq!(infer_type_from_string("true", false), InferredType::Bool);
        assert_eq!(infer_type_from_string("hello", false), InferredType::String);
        assert_eq!(infer_type_from_string("", false), InferredType::String);
        assert_eq!(infer_type_from_string("", true), InferredType::Null);
    }
}