// laminar_connectors/kafka/avro_serializer.rs

//! Avro serialization using `arrow-avro` with Confluent Schema Registry.
//!
//! `AvroSerializer` implements `RecordSerializer` by wrapping the
//! `arrow-avro` `Writer` with SOE format, producing per-record payloads
//! with the Confluent wire format prefix (`0x00` + 4-byte BE schema ID
//! + Avro body).

8use std::sync::atomic::{AtomicU32, Ordering};
9use std::sync::Arc;
10
11use arrow_array::RecordBatch;
12use arrow_avro::schema::FingerprintStrategy;
13use arrow_avro::writer::format::AvroSoeFormat;
14use arrow_avro::writer::WriterBuilder;
15use arrow_schema::SchemaRef;
16
17use crate::error::SerdeError;
18use crate::kafka::schema_registry::SchemaRegistryClient;
19use crate::serde::{Format, RecordSerializer};
20
21/// Avro serializer backed by `arrow-avro` with optional Schema Registry.
22///
23/// Produces per-row byte payloads in the Confluent wire format suitable
24/// for individual Kafka producer messages.
/// Avro serializer backed by `arrow-avro` with optional Schema Registry.
///
/// Produces per-row byte payloads in the Confluent wire format suitable
/// for individual Kafka producer messages.
pub struct AvroSerializer {
    /// Schema ID shared with `KafkaSink` for post-registration updates.
    /// Loaded with `Ordering::Relaxed` on every serialize call, so an
    /// external store through the shared handle takes effect immediately.
    schema_id: Arc<AtomicU32>,
    /// Arrow schema for the records being serialized.
    schema: SchemaRef,
    /// Optional Schema Registry client for schema registration.
    schema_registry: Option<Arc<SchemaRegistryClient>>,
}
33
34impl AvroSerializer {
35    /// Creates a new Avro serializer with a known schema ID.
36    ///
37    /// Each serialized record is prefixed with `0x00` + `schema_id` (4-byte BE).
38    #[must_use]
39    pub fn new(schema: SchemaRef, schema_id: u32) -> Self {
40        Self {
41            schema_id: Arc::new(AtomicU32::new(schema_id)),
42            schema,
43            schema_registry: None,
44        }
45    }
46
47    /// Creates a new Avro serializer with a shared schema ID handle.
48    ///
49    /// The `KafkaSink` retains a clone of the `Arc<AtomicU32>` so it can
50    /// update the schema ID after registration without downcasting.
51    #[must_use]
52    pub fn with_shared_schema_id(
53        schema: SchemaRef,
54        schema_id: Arc<AtomicU32>,
55        registry: Option<Arc<SchemaRegistryClient>>,
56    ) -> Self {
57        Self {
58            schema_id,
59            schema,
60            schema_registry: registry,
61        }
62    }
63
64    /// Creates a new Avro serializer with Schema Registry integration.
65    #[must_use]
66    pub fn with_schema_registry(
67        schema: SchemaRef,
68        schema_id: u32,
69        registry: Arc<SchemaRegistryClient>,
70    ) -> Self {
71        Self {
72            schema_id: Arc::new(AtomicU32::new(schema_id)),
73            schema,
74            schema_registry: Some(registry),
75        }
76    }
77
78    /// Returns a shared handle to the schema ID for external updates.
79    #[must_use]
80    pub fn schema_id_handle(&self) -> Arc<AtomicU32> {
81        Arc::clone(&self.schema_id)
82    }
83
84    /// Returns the current schema ID.
85    #[must_use]
86    pub fn schema_id(&self) -> u32 {
87        self.schema_id.load(Ordering::Relaxed)
88    }
89
90    /// Returns whether a Schema Registry client is configured.
91    #[must_use]
92    pub fn has_schema_registry(&self) -> bool {
93        self.schema_registry.is_some()
94    }
95
96    /// Serializes a `RecordBatch` into per-row Avro payloads with
97    /// Confluent wire format prefix.
98    ///
99    /// Each output `Vec<u8>` is: `0x00` | `schema_id` (4-byte BE) | Avro body.
100    ///
101    /// Serializes one row at a time to produce exact record boundaries.
102    /// This avoids the unsound byte-scanning approach where Avro data values
103    /// could contain the magic byte + schema ID pattern.
104    fn serialize_with_confluent_prefix(
105        &self,
106        batch: &RecordBatch,
107    ) -> Result<Vec<Vec<u8>>, SerdeError> {
108        let num_rows = batch.num_rows();
109        if num_rows == 0 {
110            return Ok(Vec::new());
111        }
112
113        let id = self.schema_id.load(Ordering::Relaxed);
114        // Clone schema once, outside the loop.
115        let arrow_schema = (*self.schema).clone();
116        let mut records = Vec::with_capacity(num_rows);
117
118        // Serialize each row individually to get exact record boundaries.
119        // batch.slice() is zero-copy (Arc offset adjustment only).
120        for row_idx in 0..num_rows {
121            let mut buf = Vec::new();
122            let row_batch = batch.slice(row_idx, 1);
123
124            let mut writer = WriterBuilder::new(arrow_schema.clone())
125                .with_fingerprint_strategy(FingerprintStrategy::Id(id))
126                .build::<_, AvroSoeFormat>(&mut buf)
127                .map_err(|e| {
128                    SerdeError::MalformedInput(format!("failed to build Avro writer: {e}"))
129                })?;
130
131            writer
132                .write(&row_batch)
133                .map_err(|e| SerdeError::MalformedInput(format!("Avro encode error: {e}")))?;
134
135            writer
136                .finish()
137                .map_err(|e| SerdeError::MalformedInput(format!("Avro flush error: {e}")))?;
138
139            records.push(buf);
140        }
141
142        Ok(records)
143    }
144}
145
146impl RecordSerializer for AvroSerializer {
147    fn serialize(&self, batch: &RecordBatch) -> Result<Vec<Vec<u8>>, SerdeError> {
148        self.serialize_with_confluent_prefix(batch)
149    }
150
151    fn serialize_batch(&self, batch: &RecordBatch) -> Result<Vec<u8>, SerdeError> {
152        if batch.num_rows() == 0 {
153            return Ok(Vec::new());
154        }
155
156        let mut buf = Vec::new();
157        let arrow_schema = (*self.schema).clone();
158        let id = self.schema_id.load(Ordering::Relaxed);
159
160        let mut writer = WriterBuilder::new(arrow_schema)
161            .with_fingerprint_strategy(FingerprintStrategy::Id(id))
162            .build::<_, AvroSoeFormat>(&mut buf)
163            .map_err(|e| SerdeError::MalformedInput(format!("failed to build Avro writer: {e}")))?;
164
165        writer
166            .write(batch)
167            .map_err(|e| SerdeError::MalformedInput(format!("Avro encode error: {e}")))?;
168
169        writer
170            .finish()
171            .map_err(|e| SerdeError::MalformedInput(format!("Avro flush error: {e}")))?;
172
173        Ok(buf)
174    }
175
176    fn format(&self) -> Format {
177        Format::Avro
178    }
179}
180
181impl std::fmt::Debug for AvroSerializer {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.debug_struct("AvroSerializer")
184            .field("schema_id", &self.schema_id.load(Ordering::Relaxed))
185            .field("has_registry", &self.schema_registry.is_some())
186            .finish_non_exhaustive()
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Confluent wire format header size (1 magic + 4 schema ID).
    const CONFLUENT_HEADER_SIZE: usize = 5;

    /// Confluent wire format magic byte.
    const CONFLUENT_MAGIC: u8 = 0x00;

    /// Two-column schema (`id: i64`, `name: utf8`) used by all tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]))
    }

    /// Builds an `n`-row batch with ids `0..n` and names `name-0..name-{n-1}`.
    fn test_batch(n: usize) -> RecordBatch {
        let id_col: Int64Array = (0..n as i64).collect::<Vec<_>>().into();
        let name_col: StringArray = (0..n)
            .map(|i| format!("name-{i}"))
            .collect::<Vec<_>>()
            .into();
        RecordBatch::try_new(test_schema(), vec![Arc::new(id_col), Arc::new(name_col)])
            .unwrap()
    }

    #[test]
    fn test_new_serializer() {
        let ser = AvroSerializer::new(test_schema(), 42);
        assert_eq!(ser.schema_id(), 42);
        assert!(!ser.has_schema_registry());
        assert_eq!(ser.format(), Format::Avro);
    }

    #[test]
    fn test_shared_schema_id() {
        let ser = AvroSerializer::new(test_schema(), 1);
        assert_eq!(ser.schema_id(), 1);

        // A store through the shared handle must be visible to the serializer.
        let handle = ser.schema_id_handle();
        handle.store(99, std::sync::atomic::Ordering::Relaxed);
        assert_eq!(ser.schema_id(), 99);
    }

    #[test]
    fn test_serialize_empty_batch() {
        let ser = AvroSerializer::new(test_schema(), 1);
        let empty = RecordBatch::new_empty(test_schema());
        assert!(ser.serialize(&empty).unwrap().is_empty());
    }

    #[test]
    fn test_serialize_batch_produces_records() {
        let ser = AvroSerializer::new(test_schema(), 7);
        let records = ser.serialize(&test_batch(3)).unwrap();
        assert_eq!(records.len(), 3);

        // Every record carries the Confluent header: magic + schema ID 7 (BE).
        for record in &records {
            assert!(record.len() >= CONFLUENT_HEADER_SIZE);
            assert_eq!(record[0], CONFLUENT_MAGIC);
            assert_eq!(&record[1..5], &7u32.to_be_bytes());
        }
    }

    #[test]
    fn test_serialize_batch_to_single_buffer() {
        let ser = AvroSerializer::new(test_schema(), 1);
        let buf = ser.serialize_batch(&test_batch(2)).unwrap();
        assert!(!buf.is_empty());
        // Single-buffer output still starts with the Confluent magic byte.
        assert_eq!(buf[0], CONFLUENT_MAGIC);
    }

    #[test]
    fn test_debug_output() {
        let ser = AvroSerializer::new(test_schema(), 42);
        let rendered = format!("{ser:?}");
        assert!(rendered.contains("AvroSerializer"));
        assert!(rendered.contains("42"));
    }
}