// laminar_connectors/kafka/avro.rs
1//! Avro deserialization using `arrow-avro` with Confluent Schema Registry.
2//!
3//! [`AvroDeserializer`] implements [`RecordDeserializer`] by wrapping the
4//! `arrow-avro` push-based [`Decoder`], which
5//! natively supports the Confluent wire format (`0x00` + 4-byte BE schema ID
6//! + Avro payload).
7
8use std::collections::HashSet;
9use std::sync::Arc;
10
11use arrow_array::RecordBatch;
12use arrow_avro::reader::{Decoder, ReaderBuilder};
13use arrow_avro::schema::{AvroSchema, Fingerprint, FingerprintAlgorithm, SchemaStore};
14use arrow_schema::SchemaRef;
15use parking_lot::Mutex;
16
17use crate::error::SerdeError;
18use crate::kafka::schema_registry::SchemaRegistryClient;
19use crate::serde::{Format, RecordDeserializer};
20
/// Row capacity handed to `ReaderBuilder::with_batch_size` when building the
/// cached decoder; a full batch is flushed before decoding continues.
const DECODER_BATCH_CAPACITY: usize = 8192;

/// Confluent wire format magic byte.
const CONFLUENT_MAGIC: u8 = 0x00;

/// Size of the Confluent wire format header (1 magic + 4 schema ID).
const CONFLUENT_HEADER_SIZE: usize = 5;
28
/// Avro deserializer backed by `arrow-avro` with optional Schema Registry.
///
/// Supports both raw Avro and the Confluent wire format. When a Schema
/// Registry client is provided, unknown schema IDs are fetched and
/// registered automatically.
pub struct AvroDeserializer {
    /// Schema store shared with the Decoder; keyed by `Fingerprint::Id`
    /// (the Confluent schema ID) rather than a hash fingerprint.
    schema_store: SchemaStore,
    /// Optional Schema Registry client for resolving unknown IDs.
    schema_registry: Option<Arc<SchemaRegistryClient>>,
    /// Set of schema IDs already registered in the store; lets
    /// `ensure_schema_registered` skip repeat registry lookups.
    known_ids: HashSet<i32>,
    /// Lazily built decoder reused across batches; reset to `None` when
    /// `register_schema` runs so it rebuilds against the updated store.
    decoder: Mutex<Option<Decoder>>,
}
44
45impl AvroDeserializer {
46    /// Creates a new Avro deserializer without Schema Registry integration.
47    ///
48    /// The caller must register schemas manually via [`register_schema`](Self::register_schema).
49    #[must_use]
50    pub fn new() -> Self {
51        Self {
52            schema_store: SchemaStore::new_with_type(FingerprintAlgorithm::Id),
53            schema_registry: None,
54            known_ids: HashSet::new(),
55            decoder: Mutex::new(None),
56        }
57    }
58
59    /// Creates a new Avro deserializer with Schema Registry integration.
60    ///
61    /// Unknown schema IDs encountered in the Confluent wire format will
62    /// be fetched from the registry automatically.
63    #[must_use]
64    pub fn with_schema_registry(registry: Arc<SchemaRegistryClient>) -> Self {
65        Self {
66            schema_store: SchemaStore::new_with_type(FingerprintAlgorithm::Id),
67            schema_registry: Some(registry),
68            known_ids: HashSet::new(),
69            decoder: Mutex::new(None),
70        }
71    }
72
73    /// Registers an Avro schema with a Confluent schema ID.
74    ///
75    /// # Errors
76    ///
77    /// Returns `SerdeError` if the fingerprint cannot be set.
78    #[allow(clippy::cast_sign_loss)]
79    pub fn register_schema(
80        &mut self,
81        schema_id: i32,
82        avro_schema_json: &str,
83    ) -> Result<(), SerdeError> {
84        let avro_schema = AvroSchema::new(avro_schema_json.to_string());
85        // Use Fingerprint::Id directly — NOT load_fingerprint_id which
86        // applies from_be byte-swap meant for raw wire bytes.
87        let fp = Fingerprint::Id(schema_id as u32);
88        self.schema_store
89            .set(fp, avro_schema)
90            .map_err(|e| SerdeError::MalformedInput(format!("failed to register schema: {e}")))?;
91        self.known_ids.insert(schema_id);
92        // Schema store changed — drop the cached decoder so it rebuilds
93        // against the new store on the next deserialize_batch call.
94        *self.decoder.lock() = None;
95        Ok(())
96    }
97
98    /// Ensures a schema ID is registered, fetching from SR if needed.
99    ///
100    /// Called by the Kafka source connector when an unknown schema ID is
101    /// encountered in the Confluent wire format during poll.
102    ///
103    /// # Errors
104    ///
105    /// Returns `SerdeError` if the schema cannot be fetched or registered.
106    /// Returns `Ok(true)` if this was a newly registered schema ID,
107    /// `Ok(false)` if already known.
108    pub async fn ensure_schema_registered(&mut self, schema_id: i32) -> Result<bool, SerdeError> {
109        if self.known_ids.contains(&schema_id) {
110            return Ok(false);
111        }
112
113        let registry = self
114            .schema_registry
115            .as_ref()
116            .ok_or(SerdeError::SchemaNotFound { schema_id })?;
117
118        let cached = registry
119            .resolve_confluent_id(schema_id)
120            .await
121            .map_err(|_| SerdeError::SchemaNotFound { schema_id })?;
122
123        self.register_schema(schema_id, &cached.schema_str)?;
124        Ok(true)
125    }
126
127    /// Extracts the Confluent schema ID from a wire-format message.
128    ///
129    /// Returns `None` if the message is not in Confluent wire format.
130    #[must_use]
131    pub fn extract_confluent_id(data: &[u8]) -> Option<i32> {
132        if data.len() < CONFLUENT_HEADER_SIZE || data[0] != CONFLUENT_MAGIC {
133            return None;
134        }
135        let id = i32::from_be_bytes([data[1], data[2], data[3], data[4]]);
136        Some(id)
137    }
138}
139
impl Default for AvroDeserializer {
    /// Equivalent to [`AvroDeserializer::new`]: no registry, empty store.
    fn default() -> Self {
        Self::new()
    }
}
145
146impl RecordDeserializer for AvroDeserializer {
147    fn deserialize(&self, data: &[u8], schema: &SchemaRef) -> Result<RecordBatch, SerdeError> {
148        self.deserialize_batch(&[data], schema)
149    }
150
151    fn deserialize_batch(
152        &self,
153        records: &[&[u8]],
154        schema: &SchemaRef,
155    ) -> Result<RecordBatch, SerdeError> {
156        if records.is_empty() {
157            return Ok(RecordBatch::new_empty(schema.clone()));
158        }
159
160        let mut guard = self.decoder.lock();
161        let decoder = if let Some(d) = guard.as_mut() {
162            d
163        } else {
164            let d = ReaderBuilder::new()
165                .with_batch_size(DECODER_BATCH_CAPACITY)
166                .with_writer_schema_store(self.schema_store.clone())
167                .build_decoder()
168                .map_err(|e| SerdeError::MalformedInput(format!("failed to build decoder: {e}")))?;
169            guard.insert(d)
170        };
171
172        let mut partials: Vec<RecordBatch> = Vec::new();
173        for record in records {
174            let mut offset = 0;
175            while offset < record.len() {
176                let consumed = decoder
177                    .decode(&record[offset..])
178                    .map_err(|e| SerdeError::MalformedInput(format!("Avro decode error: {e}")))?;
179                if consumed == 0 {
180                    break;
181                }
182                offset += consumed;
183                if decoder.batch_is_full() {
184                    if let Some(b) = decoder
185                        .flush()
186                        .map_err(|e| SerdeError::MalformedInput(format!("Avro flush: {e}")))?
187                    {
188                        partials.push(b);
189                    }
190                }
191            }
192        }
193        if let Some(b) = decoder
194            .flush()
195            .map_err(|e| SerdeError::MalformedInput(format!("Avro flush: {e}")))?
196        {
197            partials.push(b);
198        }
199
200        match partials.len() {
201            0 => Err(SerdeError::MalformedInput("no records decoded".into())),
202            1 => Ok(partials.pop().unwrap()),
203            _ => arrow_select::concat::concat_batches(schema, &partials)
204                .map_err(|e| SerdeError::MalformedInput(format!("concat: {e}"))),
205        }
206    }
207
208    fn format(&self) -> Format {
209        Format::Avro
210    }
211
212    fn as_any_mut(&mut self) -> Option<&mut dyn std::any::Any> {
213        Some(self)
214    }
215}
216
impl std::fmt::Debug for AvroDeserializer {
    /// Manual impl: `schema_store` and the cached `decoder` are omitted
    /// (hence `finish_non_exhaustive`), and the registry is reported only
    /// as a presence flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AvroDeserializer")
            .field("known_ids", &self.known_ids)
            .field("has_registry", &self.schema_registry.is_some())
            .finish_non_exhaustive()
    }
}
225
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Minimal two-field Avro record schema used by registration tests.
    const TEST_AVRO_SCHEMA: &str = r#"{
        "type": "record",
        "name": "test",
        "fields": [
            {"name": "id", "type": "long"},
            {"name": "name", "type": "string"}
        ]
    }"#;

    /// Arrow schema matching [`TEST_AVRO_SCHEMA`].
    fn test_arrow_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]))
    }

    /// Avro `union<null, map<string, union<null, double>>>` must produce
    /// identical Arrow schemas from both `avro_to_arrow_schema` (SR path)
    /// and `AvroDeserializer` (wire decode path).
    #[test]
    fn test_nullable_map_nullable_double_full_path() {
        // Avro binary zigzag varint encoding of a long (per the Avro
        // spec's binary encoding rules).
        fn zigzag(val: i64) -> Vec<u8> {
            let mut z = ((val << 1) ^ (val >> 63)) as u64;
            let mut buf = Vec::new();
            loop {
                if z & !0x7F == 0 {
                    buf.push(z as u8);
                    break;
                }
                buf.push((z as u8 & 0x7F) | 0x80);
                z >>= 7;
            }
            buf
        }
        // Avro string encoding: zigzag length prefix + UTF-8 bytes.
        fn avro_string(s: &str) -> Vec<u8> {
            let mut b = zigzag(s.len() as i64);
            b.extend_from_slice(s.as_bytes());
            b
        }

        let avro_json = r#"{
            "type": "record",
            "name": "Metrics",
            "fields": [
                {"name": "sensor_id", "type": "string"},
                {
                    "name": "data",
                    "type": ["null", {"type": "map", "values": ["null", "double"]}]
                }
            ]
        }"#;

        // Path 1: what the schema registry infers
        let sr_schema = crate::kafka::schema_registry::avro_to_arrow_schema(avro_json)
            .expect("avro_to_arrow_schema should handle nullable map");

        // Path 2: what the decoder actually produces from wire bytes
        let mut deser = AvroDeserializer::new();
        deser.register_schema(42, avro_json).unwrap();

        // Encode { sensor_id: "s1", data: {"temp": 23.5} } in Avro binary.
        // Union branches are prefixed with a zigzag-encoded index.
        let mut payload = Vec::new();
        payload.extend_from_slice(&avro_string("s1"));
        payload.extend_from_slice(&zigzag(1)); // data: branch 1 (map, not null)
        payload.extend_from_slice(&zigzag(1)); // map block: 1 entry
        payload.extend_from_slice(&avro_string("temp"));
        payload.extend_from_slice(&zigzag(1)); // value: branch 1 (double, not null)
        payload.extend_from_slice(&23.5_f64.to_le_bytes());
        payload.extend_from_slice(&zigzag(0)); // end of map

        // Confluent wire frame: magic + schema_id + payload
        let mut msg = vec![0x00u8];
        msg.extend_from_slice(&42i32.to_be_bytes());
        msg.extend_from_slice(&payload);

        let batch = deser
            .deserialize_batch(&[msg.as_slice()], &sr_schema)
            .expect("decode should succeed");
        assert_eq!(batch.num_rows(), 1);

        // The two paths must agree on the Arrow type of the `data` field,
        // otherwise downstream concat/cast against the SR schema fails.
        let sr_data = sr_schema.field_with_name("data").unwrap();
        let decoded_schema = batch.schema();
        let dec_data = decoded_schema.field_with_name("data").unwrap();
        assert_eq!(
            sr_data.data_type(),
            dec_data.data_type(),
            "schema registry and arrow-avro must produce identical Map types"
        );
    }

    #[test]
    fn test_extract_confluent_id() {
        // Valid: 0x00 + 4-byte BE schema ID
        let data = [0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(1));

        let data = [0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x03];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(256));
    }

    #[test]
    fn test_extract_confluent_id_not_confluent() {
        // First byte is not the 0x00 magic.
        let data = [0x01, 0x00, 0x00, 0x00, 0x01];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), None);
    }

    #[test]
    fn test_extract_confluent_id_too_short() {
        // Shorter than the 5-byte Confluent header.
        let data = [0x00, 0x00];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), None);
    }

    #[test]
    fn test_new_deserializer() {
        let deser = AvroDeserializer::new();
        assert!(deser.schema_registry.is_none());
        assert!(deser.known_ids.is_empty());
    }

    #[test]
    fn test_register_schema() {
        let mut deser = AvroDeserializer::new();
        let result = deser.register_schema(1, TEST_AVRO_SCHEMA);
        assert!(result.is_ok());
        assert!(deser.known_ids.contains(&1));
    }

    #[test]
    fn test_format() {
        let deser = AvroDeserializer::new();
        assert_eq!(deser.format(), Format::Avro);
    }

    #[test]
    fn test_deserialize_empty_batch() {
        // Empty input short-circuits to an empty batch with the caller's schema.
        let deser = AvroDeserializer::new();
        let schema = test_arrow_schema();
        let result = deser.deserialize_batch(&[], &schema);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().num_rows(), 0);
    }

    #[test]
    fn test_extract_confluent_id_large() {
        // Schema ID 42; trailing payload byte must be ignored.
        let mut data = vec![0x00u8];
        data.extend_from_slice(&42i32.to_be_bytes());
        data.push(0xFF);
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(42));
    }

    #[test]
    fn test_extract_confluent_id_edge_cases() {
        // Empty slice
        assert_eq!(AvroDeserializer::extract_confluent_id(&[]), None);

        // Magic byte only (too short)
        assert_eq!(AvroDeserializer::extract_confluent_id(&[0x00]), None);

        // 4 bytes with magic (too short)
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&[0x00, 0x00, 0x00, 0x00]),
            None
        );

        // Exactly 5 bytes (boundary of CONFLUENT_HEADER_SIZE)
        let data = [0x00, 0x00, 0x00, 0x00, 0x05];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(5));

        // i32::MAX ID
        let mut data = vec![0x00];
        data.extend_from_slice(&i32::MAX.to_be_bytes());
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&data),
            Some(i32::MAX)
        );

        // i32::MIN ID — IDs are stored as raw big-endian i32, so negative
        // values round-trip even though real registries only issue positives.
        let mut data = vec![0x00];
        data.extend_from_slice(&i32::MIN.to_be_bytes());
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&data),
            Some(i32::MIN)
        );
    }
}
417}