laminar_connectors/kafka/
avro_serializer.rs1use std::sync::atomic::{AtomicU32, Ordering};
9use std::sync::Arc;
10
11use arrow_array::RecordBatch;
12use arrow_avro::schema::FingerprintStrategy;
13use arrow_avro::writer::format::AvroSoeFormat;
14use arrow_avro::writer::WriterBuilder;
15use arrow_schema::SchemaRef;
16
17use crate::error::SerdeError;
18use crate::kafka::schema_registry::SchemaRegistryClient;
19use crate::serde::{Format, RecordSerializer};
20
21pub struct AvroSerializer {
26 schema_id: Arc<AtomicU32>,
28 schema: SchemaRef,
30 schema_registry: Option<Arc<SchemaRegistryClient>>,
32}
33
34impl AvroSerializer {
35 #[must_use]
39 pub fn new(schema: SchemaRef, schema_id: u32) -> Self {
40 Self {
41 schema_id: Arc::new(AtomicU32::new(schema_id)),
42 schema,
43 schema_registry: None,
44 }
45 }
46
47 #[must_use]
52 pub fn with_shared_schema_id(
53 schema: SchemaRef,
54 schema_id: Arc<AtomicU32>,
55 registry: Option<Arc<SchemaRegistryClient>>,
56 ) -> Self {
57 Self {
58 schema_id,
59 schema,
60 schema_registry: registry,
61 }
62 }
63
64 #[must_use]
66 pub fn with_schema_registry(
67 schema: SchemaRef,
68 schema_id: u32,
69 registry: Arc<SchemaRegistryClient>,
70 ) -> Self {
71 Self {
72 schema_id: Arc::new(AtomicU32::new(schema_id)),
73 schema,
74 schema_registry: Some(registry),
75 }
76 }
77
78 #[must_use]
80 pub fn schema_id_handle(&self) -> Arc<AtomicU32> {
81 Arc::clone(&self.schema_id)
82 }
83
84 #[must_use]
86 pub fn schema_id(&self) -> u32 {
87 self.schema_id.load(Ordering::Relaxed)
88 }
89
90 #[must_use]
92 pub fn has_schema_registry(&self) -> bool {
93 self.schema_registry.is_some()
94 }
95
96 fn serialize_with_confluent_prefix(
105 &self,
106 batch: &RecordBatch,
107 ) -> Result<Vec<Vec<u8>>, SerdeError> {
108 let num_rows = batch.num_rows();
109 if num_rows == 0 {
110 return Ok(Vec::new());
111 }
112
113 let id = self.schema_id.load(Ordering::Relaxed);
114 let arrow_schema = (*self.schema).clone();
116 let mut records = Vec::with_capacity(num_rows);
117
118 for row_idx in 0..num_rows {
121 let mut buf = Vec::new();
122 let row_batch = batch.slice(row_idx, 1);
123
124 let mut writer = WriterBuilder::new(arrow_schema.clone())
125 .with_fingerprint_strategy(FingerprintStrategy::Id(id))
126 .build::<_, AvroSoeFormat>(&mut buf)
127 .map_err(|e| {
128 SerdeError::MalformedInput(format!("failed to build Avro writer: {e}"))
129 })?;
130
131 writer
132 .write(&row_batch)
133 .map_err(|e| SerdeError::MalformedInput(format!("Avro encode error: {e}")))?;
134
135 writer
136 .finish()
137 .map_err(|e| SerdeError::MalformedInput(format!("Avro flush error: {e}")))?;
138
139 records.push(buf);
140 }
141
142 Ok(records)
143 }
144}
145
146impl RecordSerializer for AvroSerializer {
147 fn serialize(&self, batch: &RecordBatch) -> Result<Vec<Vec<u8>>, SerdeError> {
148 self.serialize_with_confluent_prefix(batch)
149 }
150
151 fn serialize_batch(&self, batch: &RecordBatch) -> Result<Vec<u8>, SerdeError> {
152 if batch.num_rows() == 0 {
153 return Ok(Vec::new());
154 }
155
156 let mut buf = Vec::new();
157 let arrow_schema = (*self.schema).clone();
158 let id = self.schema_id.load(Ordering::Relaxed);
159
160 let mut writer = WriterBuilder::new(arrow_schema)
161 .with_fingerprint_strategy(FingerprintStrategy::Id(id))
162 .build::<_, AvroSoeFormat>(&mut buf)
163 .map_err(|e| SerdeError::MalformedInput(format!("failed to build Avro writer: {e}")))?;
164
165 writer
166 .write(batch)
167 .map_err(|e| SerdeError::MalformedInput(format!("Avro encode error: {e}")))?;
168
169 writer
170 .finish()
171 .map_err(|e| SerdeError::MalformedInput(format!("Avro flush error: {e}")))?;
172
173 Ok(buf)
174 }
175
176 fn format(&self) -> Format {
177 Format::Avro
178 }
179}
180
181impl std::fmt::Debug for AvroSerializer {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("AvroSerializer")
184 .field("schema_id", &self.schema_id.load(Ordering::Relaxed))
185 .field("has_registry", &self.schema_registry.is_some())
186 .finish_non_exhaustive()
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Minimum length the tests require of every encoded record: one magic
    /// byte followed by a four-byte schema ID.
    const CONFLUENT_HEADER_SIZE: usize = 5;

    /// First byte the tests expect at the start of every encoded record.
    const CONFLUENT_MAGIC: u8 = 0x00;

    /// Two non-nullable columns: `id` (Int64) and `name` (Utf8).
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds an `n`-row batch with ids `0..n` and names `name-0`, `name-1`, ...
    fn test_batch(n: usize) -> RecordBatch {
        let id_column = Int64Array::from((0..n as i64).collect::<Vec<i64>>());
        let name_column = StringArray::from(
            (0..n)
                .map(|i| format!("name-{i}"))
                .collect::<Vec<String>>(),
        );

        RecordBatch::try_new(
            test_schema(),
            vec![Arc::new(id_column), Arc::new(name_column)],
        )
        .unwrap()
    }

    #[test]
    fn test_new_serializer() {
        let serializer = AvroSerializer::new(test_schema(), 42);

        assert_eq!(serializer.schema_id(), 42);
        assert!(!serializer.has_schema_registry());
        assert_eq!(serializer.format(), Format::Avro);
    }

    #[test]
    fn test_shared_schema_id() {
        let serializer = AvroSerializer::new(test_schema(), 1);
        assert_eq!(serializer.schema_id(), 1);

        // A store through the handle must be visible through the serializer.
        serializer.schema_id_handle().store(99, Ordering::Relaxed);
        assert_eq!(serializer.schema_id(), 99);
    }

    #[test]
    fn test_serialize_empty_batch() {
        let serializer = AvroSerializer::new(test_schema(), 1);
        let empty = RecordBatch::new_empty(test_schema());

        let records = serializer.serialize(&empty).unwrap();
        assert!(records.is_empty());
    }

    #[test]
    fn test_serialize_batch_produces_records() {
        let serializer = AvroSerializer::new(test_schema(), 7);
        let records = serializer.serialize(&test_batch(3)).unwrap();
        assert_eq!(records.len(), 3);

        let expected_id = 7u32.to_be_bytes();
        for record in &records {
            // Every record: magic byte, then the schema ID in big-endian order.
            assert!(record.len() >= CONFLUENT_HEADER_SIZE);
            assert_eq!(record[0], CONFLUENT_MAGIC);
            assert_eq!(&record[1..5], &expected_id);
        }
    }

    #[test]
    fn test_serialize_batch_to_single_buffer() {
        let serializer = AvroSerializer::new(test_schema(), 1);
        let encoded = serializer.serialize_batch(&test_batch(2)).unwrap();

        assert!(!encoded.is_empty());
        assert_eq!(encoded[0], CONFLUENT_MAGIC);
    }

    #[test]
    fn test_debug_output() {
        let serializer = AvroSerializer::new(test_schema(), 42);
        let rendered = format!("{serializer:?}");

        assert!(rendered.contains("AvroSerializer"));
        assert!(rendered.contains("42"));
    }
}