// laminar_connectors/schema/parquet/provider.rs

1//! Parquet schema provider implementing [`SchemaProvider`].
2//!
3//! Reads the Parquet file footer to extract an authoritative Arrow schema
4//! with optional per-field metadata.
5
6use std::collections::HashMap;
7
8use arrow_schema::SchemaRef;
9use async_trait::async_trait;
10use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
11
12use crate::schema::error::{SchemaError, SchemaResult};
13use crate::schema::traits::SchemaProvider;
14use crate::schema::types::FieldMeta;
15
16/// Provides an Arrow schema extracted from a Parquet file's footer.
17///
18/// This provider is **authoritative** — Parquet files carry their schema
19/// in the footer, so the schema is exact (no inference needed).
20pub struct ParquetSchemaProvider {
21    /// The Parquet file bytes (complete file including footer).
22    data: Vec<u8>,
23}
24
25impl ParquetSchemaProvider {
26    /// Creates a new provider from complete Parquet file bytes.
27    #[must_use]
28    pub fn new(data: Vec<u8>) -> Self {
29        Self { data }
30    }
31}
32
33impl std::fmt::Debug for ParquetSchemaProvider {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("ParquetSchemaProvider")
36            .field("data_len", &self.data.len())
37            .finish()
38    }
39}
40
41#[async_trait]
42impl SchemaProvider for ParquetSchemaProvider {
43    async fn provide_schema(&self) -> SchemaResult<SchemaRef> {
44        let bytes = bytes::Bytes::copy_from_slice(&self.data);
45        let builder = ParquetRecordBatchReaderBuilder::try_new(bytes).map_err(|e| {
46            SchemaError::InferenceFailed(format!("cannot read Parquet footer: {e}"))
47        })?;
48
49        Ok(builder.schema().clone())
50    }
51
52    fn is_authoritative(&self) -> bool {
53        true
54    }
55
56    async fn field_metadata(&self) -> SchemaResult<HashMap<String, FieldMeta>> {
57        let bytes = bytes::Bytes::copy_from_slice(&self.data);
58        let builder = ParquetRecordBatchReaderBuilder::try_new(bytes).map_err(|e| {
59            SchemaError::InferenceFailed(format!("cannot read Parquet footer: {e}"))
60        })?;
61
62        let schema = builder.schema();
63        let parquet_meta = builder.metadata();
64        let file_meta = parquet_meta.file_metadata();
65
66        let mut result = HashMap::new();
67        for (idx, field) in schema.fields().iter().enumerate() {
68            let mut meta = FieldMeta::new()
69                .with_field_id(u32::try_from(idx).expect("field index overflow"))
70                .with_source_type(format!("{}", field.data_type()));
71
72            // Copy Arrow field metadata as properties.
73            for (k, v) in field.metadata() {
74                meta = meta.with_property(k, v);
75            }
76
77            // Add Parquet-level metadata if available.
78            if let Some(created_by) = file_meta.created_by() {
79                meta = meta.with_property("parquet.created_by", created_by);
80            }
81
82            result.insert(field.name().clone(), meta);
83        }
84
85        Ok(result)
86    }
87}
88
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use arrow_array::{Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use parquet::arrow::ArrowWriter;

    /// Builds a tiny two-column (`id: Int64`, `name: Utf8`) Parquet file
    /// entirely in memory.
    fn sample_parquet_bytes() -> Vec<u8> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["a", "b"])),
            ],
        )
        .unwrap();

        let mut buf = Vec::new();
        {
            // Scope the writer so the mutable borrow of `buf` ends before
            // the buffer is returned.
            let mut writer = ArrowWriter::try_new(&mut buf, schema, None).unwrap();
            writer.write(&batch).unwrap();
            writer.close().unwrap();
        }
        buf
    }

    #[tokio::test]
    async fn test_provide_schema() {
        let provider = ParquetSchemaProvider::new(sample_parquet_bytes());
        let schema = provider.provide_schema().await.unwrap();

        assert_eq!(schema.fields().len(), 2);
        let id_field = schema.field(0);
        let name_field = schema.field(1);
        assert_eq!(id_field.name(), "id");
        assert_eq!(id_field.data_type(), &DataType::Int64);
        assert_eq!(name_field.name(), "name");
        assert_eq!(name_field.data_type(), &DataType::Utf8);
    }

    #[test]
    fn test_is_authoritative() {
        // Authoritativeness is a static property — no valid file needed.
        assert!(ParquetSchemaProvider::new(Vec::new()).is_authoritative());
    }

    #[tokio::test]
    async fn test_field_metadata() {
        let provider = ParquetSchemaProvider::new(sample_parquet_bytes());
        let meta = provider.field_metadata().await.unwrap();

        assert!(meta.contains_key("id"));
        assert!(meta.contains_key("name"));

        let id_meta = &meta["id"];
        assert_eq!(id_meta.field_id, Some(0));
        let source_type = id_meta.source_type.as_ref().unwrap();
        assert!(source_type.contains("Int64"));
    }

    #[tokio::test]
    async fn test_invalid_bytes() {
        let provider = ParquetSchemaProvider::new(b"not parquet".to_vec());
        assert!(provider.provide_schema().await.is_err());
    }
}