laminar_connectors/schema/parquet/encoder.rs
use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use parquet::arrow::ArrowWriter;
use parquet::basic::Compression;
use parquet::file::properties::{WriterProperties, WriterVersion};

use crate::schema::error::{SchemaError, SchemaResult};
use crate::schema::traits::FormatEncoder;

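/// Configuration for a [`ParquetEncoder`].
///
/// A minimal builder-style sketch; the `laminar_connectors` import path is
/// an assumption about this crate's layout, so the example is not compiled:
///
/// ```ignore
/// use parquet::basic::Compression;
/// use laminar_connectors::schema::parquet::ParquetEncoderConfig;
///
/// let config = ParquetEncoderConfig::default()
///     .with_compression(Compression::SNAPPY)
///     .with_max_row_group_size(100_000)
///     .with_statistics(true);
/// ```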
#[derive(Debug, Clone)]
pub struct ParquetEncoderConfig {
    /// Compression codec applied to all columns.
    pub compression: Compression,

    /// Parquet format writer version (1 or 2).
    pub writer_version: i32,

    /// Maximum number of rows per row group.
    pub max_row_group_size: usize,

    /// Whether to write column statistics.
    pub write_statistics: bool,
}

impl Default for ParquetEncoderConfig {
    fn default() -> Self {
        Self {
            compression: Compression::SNAPPY,
            writer_version: 2,
            max_row_group_size: 1_000_000,
            write_statistics: true,
        }
    }
}

impl ParquetEncoderConfig {
    /// Sets the compression codec.
    #[must_use]
    pub fn with_compression(mut self, compression: Compression) -> Self {
        self.compression = compression;
        self
    }

    /// Sets the Parquet writer version (1 or 2).
    #[must_use]
    pub fn with_writer_version(mut self, version: i32) -> Self {
        self.writer_version = version;
        self
    }

    /// Sets the maximum number of rows per row group.
    #[must_use]
    pub fn with_max_row_group_size(mut self, size: usize) -> Self {
        self.max_row_group_size = size;
        self
    }

    /// Enables or disables column statistics.
    #[must_use]
    pub fn with_statistics(mut self, enabled: bool) -> Self {
        self.write_statistics = enabled;
        self
    }
}
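
/// Encodes Arrow [`RecordBatch`]es into Parquet, producing one complete
/// in-memory Parquet file per encoded batch.
///
/// A roundtrip sketch; as above, the `laminar_connectors` import paths are
/// assumptions, so the example is not compiled:
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow_array::{Int64Array, RecordBatch};
/// use arrow_schema::{DataType, Field, Schema};
/// use laminar_connectors::schema::parquet::ParquetEncoder;
/// use laminar_connectors::schema::traits::FormatEncoder;
///
/// let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
/// let batch = RecordBatch::try_new(
///     schema.clone(),
///     vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
/// ).unwrap();
///
/// let encoder = ParquetEncoder::new(schema);
/// // One Vec<u8> per call: a self-contained Parquet file.
/// let files = encoder.encode_batch(&batch).unwrap();
/// assert_eq!(files.len(), 1);
/// ```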
#[derive(Debug)]
pub struct ParquetEncoder {
    schema: SchemaRef,
    config: ParquetEncoderConfig,
}

impl ParquetEncoder {
    /// Creates an encoder for `schema` with the default configuration.
    #[must_use]
    pub fn new(schema: SchemaRef) -> Self {
        Self::with_config(schema, ParquetEncoderConfig::default())
    }

    /// Creates an encoder for `schema` with an explicit configuration.
    #[must_use]
    pub fn with_config(schema: SchemaRef, config: ParquetEncoderConfig) -> Self {
        Self { schema, config }
    }
}

impl FormatEncoder for ParquetEncoder {
    fn input_schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn encode_batch(&self, batch: &RecordBatch) -> SchemaResult<Vec<Vec<u8>>> {
        // Nothing to write: emit no file rather than an empty Parquet file.
        if batch.num_rows() == 0 {
            return Ok(vec![]);
        }

        // Apply the configured writer version; anything other than 1 selects
        // the version-2 format, matching the default of 2.
        let writer_version = if self.config.writer_version == 1 {
            WriterVersion::PARQUET_1_0
        } else {
            WriterVersion::PARQUET_2_0
        };

        let mut props_builder = WriterProperties::builder()
            .set_compression(self.config.compression)
            .set_writer_version(writer_version)
            .set_max_row_group_size(self.config.max_row_group_size);

        if !self.config.write_statistics {
            props_builder = props_builder
                .set_statistics_enabled(parquet::file::properties::EnabledStatistics::None);
        }

        let props = props_builder.build();

        let mut buf = Vec::new();
        let mut writer = ArrowWriter::try_new(&mut buf, self.schema.clone(), Some(props))
            .map_err(|e| SchemaError::DecodeError(format!("Parquet writer init: {e}")))?;

        writer
            .write(batch)
            .map_err(|e| SchemaError::DecodeError(format!("Parquet write error: {e}")))?;

        writer
            .close()
            .map_err(|e| SchemaError::DecodeError(format!("Parquet close error: {e}")))?;

        Ok(vec![buf])
    }

    fn format_name(&self) -> &'static str {
        "parquet"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
    use parquet::basic::{GzipLevel, ZstdLevel};
    use std::sync::Arc;

    fn make_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    #[test]
    fn test_encode_empty_batch() {
        let schema = make_schema();
        let batch = RecordBatch::new_empty(schema.clone());
        let encoder = ParquetEncoder::new(schema);
        let result = encoder.encode_batch(&batch).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_encode_roundtrip() {
        let schema = make_schema();
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
            ],
        )
        .unwrap();

        let encoder = ParquetEncoder::new(schema);
        let result = encoder.encode_batch(&batch).unwrap();
        assert_eq!(result.len(), 1);

        let bytes = bytes::Bytes::from(result.into_iter().next().unwrap());
        let reader = ParquetRecordBatchReaderBuilder::try_new(bytes)
            .unwrap()
            .build()
            .unwrap();

        let batches: Vec<RecordBatch> = reader.map(Result::unwrap).collect();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[test]
    fn test_encode_with_compression() {
        let schema = make_schema();
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1])),
                Arc::new(StringArray::from(vec!["x"])),
            ],
        )
        .unwrap();

        let config = ParquetEncoderConfig::default()
            .with_compression(Compression::GZIP(GzipLevel::default()));
        let encoder = ParquetEncoder::with_config(schema, config);
        let result = encoder.encode_batch(&batch).unwrap();
        assert_eq!(result.len(), 1);
        assert!(!result[0].is_empty());
    }

    #[test]
    fn test_format_name() {
        let schema = make_schema();
        let encoder = ParquetEncoder::new(schema);
        assert_eq!(encoder.format_name(), "parquet");
    }

    #[test]
    fn test_config_builder() {
        let config = ParquetEncoderConfig::default()
            .with_compression(Compression::ZSTD(ZstdLevel::default()))
            .with_writer_version(1)
            .with_max_row_group_size(500)
            .with_statistics(false);

        assert!(matches!(config.compression, Compression::ZSTD(_)));
        assert_eq!(config.writer_version, 1);
        assert_eq!(config.max_row_group_size, 500);
        assert!(!config.write_statistics);
    }
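
    // Added sketch: a version-1 file produced via `with_writer_version(1)`
    // should decode through the same reader path as test_encode_roundtrip.
    #[test]
    fn test_encode_writer_version_v1_roundtrip() {
        let schema = make_schema();
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![10, 20])),
                Arc::new(StringArray::from(vec!["p", "q"])),
            ],
        )
        .unwrap();

        let config = ParquetEncoderConfig::default().with_writer_version(1);
        let encoder = ParquetEncoder::with_config(schema, config);
        let result = encoder.encode_batch(&batch).unwrap();
        assert_eq!(result.len(), 1);

        let bytes = bytes::Bytes::from(result.into_iter().next().unwrap());
        let reader = ParquetRecordBatchReaderBuilder::try_new(bytes)
            .unwrap()
            .build()
            .unwrap();
        let total_rows: usize = reader.map(|b| b.unwrap().num_rows()).sum();
        assert_eq!(total_rows, 2);
    }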
}