Skip to main content

laminar_connectors/kafka/
avro.rs

//! Avro deserialization using `arrow-avro` with Confluent Schema Registry.
//!
//! [`AvroDeserializer`] implements [`RecordDeserializer`] by wrapping the
//! `arrow-avro` push-based [`Decoder`](arrow_avro::reader::Decoder), which
//! natively supports the Confluent wire format: a `0x00` magic byte, then a
//! 4-byte big-endian schema ID, then the Avro-encoded payload.
7
8use std::collections::HashSet;
9use std::sync::Arc;
10
11use arrow_array::RecordBatch;
12use arrow_avro::reader::ReaderBuilder;
13use arrow_avro::schema::{AvroSchema, Fingerprint, FingerprintAlgorithm, SchemaStore};
14use arrow_schema::SchemaRef;
15
16use crate::error::SerdeError;
17use crate::kafka::schema_registry::SchemaRegistryClient;
18use crate::serde::{Format, RecordDeserializer};
19
/// First byte of every Confluent wire-format message.
const CONFLUENT_MAGIC: u8 = 0x00;

/// Length of the Confluent framing header: the magic byte followed by a
/// big-endian `i32` schema ID (5 bytes total).
const CONFLUENT_HEADER_SIZE: usize = 1 + std::mem::size_of::<i32>();
25
26/// Avro deserializer backed by `arrow-avro` with optional Schema Registry.
27///
28/// Supports both raw Avro and the Confluent wire format. When a Schema
29/// Registry client is provided, unknown schema IDs are fetched and
30/// registered automatically.
31pub struct AvroDeserializer {
32    /// Schema store shared with the Decoder.
33    schema_store: SchemaStore,
34    /// Optional Schema Registry client for resolving unknown IDs.
35    schema_registry: Option<Arc<SchemaRegistryClient>>,
36    /// Set of schema IDs already registered in the store.
37    known_ids: HashSet<i32>,
38}
39
40impl AvroDeserializer {
41    /// Creates a new Avro deserializer without Schema Registry integration.
42    ///
43    /// The caller must register schemas manually via [`register_schema`](Self::register_schema).
44    #[must_use]
45    pub fn new() -> Self {
46        Self {
47            schema_store: SchemaStore::new_with_type(FingerprintAlgorithm::Id),
48            schema_registry: None,
49            known_ids: HashSet::new(),
50        }
51    }
52
53    /// Creates a new Avro deserializer with Schema Registry integration.
54    ///
55    /// Unknown schema IDs encountered in the Confluent wire format will
56    /// be fetched from the registry automatically.
57    #[must_use]
58    pub fn with_schema_registry(registry: Arc<SchemaRegistryClient>) -> Self {
59        Self {
60            schema_store: SchemaStore::new_with_type(FingerprintAlgorithm::Id),
61            schema_registry: Some(registry),
62            known_ids: HashSet::new(),
63        }
64    }
65
66    /// Registers an Avro schema with a Confluent schema ID.
67    ///
68    /// # Errors
69    ///
70    /// Returns `SerdeError` if the fingerprint cannot be set.
71    #[allow(clippy::cast_sign_loss)]
72    pub fn register_schema(
73        &mut self,
74        schema_id: i32,
75        avro_schema_json: &str,
76    ) -> Result<(), SerdeError> {
77        let avro_schema = AvroSchema::new(avro_schema_json.to_string());
78        // Use Fingerprint::Id directly — NOT load_fingerprint_id which
79        // applies from_be byte-swap meant for raw wire bytes.
80        let fp = Fingerprint::Id(schema_id as u32);
81        self.schema_store
82            .set(fp, avro_schema)
83            .map_err(|e| SerdeError::MalformedInput(format!("failed to register schema: {e}")))?;
84        self.known_ids.insert(schema_id);
85        Ok(())
86    }
87
88    /// Ensures a schema ID is registered, fetching from SR if needed.
89    ///
90    /// Called by the Kafka source connector when an unknown schema ID is
91    /// encountered in the Confluent wire format during poll.
92    ///
93    /// # Errors
94    ///
95    /// Returns `SerdeError` if the schema cannot be fetched or registered.
96    pub async fn ensure_schema_registered(&mut self, schema_id: i32) -> Result<(), SerdeError> {
97        if self.known_ids.contains(&schema_id) {
98            return Ok(());
99        }
100
101        let registry = self
102            .schema_registry
103            .as_ref()
104            .ok_or(SerdeError::SchemaNotFound { schema_id })?;
105
106        let cached = registry
107            .resolve_confluent_id(schema_id)
108            .await
109            .map_err(|_| SerdeError::SchemaNotFound { schema_id })?;
110
111        self.register_schema(schema_id, &cached.schema_str)?;
112        Ok(())
113    }
114
115    /// Extracts the Confluent schema ID from a wire-format message.
116    ///
117    /// Returns `None` if the message is not in Confluent wire format.
118    #[must_use]
119    pub fn extract_confluent_id(data: &[u8]) -> Option<i32> {
120        if data.len() < CONFLUENT_HEADER_SIZE || data[0] != CONFLUENT_MAGIC {
121            return None;
122        }
123        let id = i32::from_be_bytes([data[1], data[2], data[3], data[4]]);
124        Some(id)
125    }
126}
127
128impl Default for AvroDeserializer {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl RecordDeserializer for AvroDeserializer {
135    fn deserialize(&self, data: &[u8], schema: &SchemaRef) -> Result<RecordBatch, SerdeError> {
136        self.deserialize_batch(&[data], schema)
137    }
138
139    fn deserialize_batch(
140        &self,
141        records: &[&[u8]],
142        schema: &SchemaRef,
143    ) -> Result<RecordBatch, SerdeError> {
144        if records.is_empty() {
145            return Ok(RecordBatch::new_empty(schema.clone()));
146        }
147
148        let mut decoder = ReaderBuilder::new()
149            .with_batch_size(records.len())
150            .with_writer_schema_store(self.schema_store.clone())
151            .build_decoder()
152            .map_err(|e| SerdeError::MalformedInput(format!("failed to build decoder: {e}")))?;
153
154        for record in records {
155            let mut offset = 0;
156            while offset < record.len() {
157                let consumed = decoder
158                    .decode(&record[offset..])
159                    .map_err(|e| SerdeError::MalformedInput(format!("Avro decode error: {e}")))?;
160                if consumed == 0 {
161                    break;
162                }
163                offset += consumed;
164            }
165        }
166
167        decoder
168            .flush()
169            .map_err(|e| SerdeError::MalformedInput(format!("Avro flush error: {e}")))?
170            .ok_or_else(|| SerdeError::MalformedInput("no records decoded".into()))
171    }
172
173    fn format(&self) -> Format {
174        Format::Avro
175    }
176
177    fn as_any_mut(&mut self) -> Option<&mut dyn std::any::Any> {
178        Some(self)
179    }
180}
181
182impl std::fmt::Debug for AvroDeserializer {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        f.debug_struct("AvroDeserializer")
185            .field("known_ids", &self.known_ids)
186            .field("has_registry", &self.schema_registry.is_some())
187            .finish_non_exhaustive()
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::future::Future;
    use std::task::{Context, Poll, Wake, Waker};

    const TEST_AVRO_SCHEMA: &str = r#"{
        "type": "record",
        "name": "test",
        "fields": [
            {"name": "id", "type": "long"},
            {"name": "name", "type": "string"}
        ]
    }"#;

    fn test_arrow_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]))
    }

    /// Waker that does nothing; the futures polled in these tests never suspend.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `fut` exactly once without an async runtime.
    ///
    /// Only suitable for futures that complete without awaiting, such as the
    /// early-return paths of `ensure_schema_registered`.
    fn poll_once<F: Future>(fut: F) -> Poll<F::Output> {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        fut.as_mut().poll(&mut cx)
    }

    #[test]
    fn test_extract_confluent_id() {
        // Valid: 0x00 + 4-byte BE schema ID
        let data = [0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(1));

        let data = [0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x03];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(256));
    }

    #[test]
    fn test_extract_confluent_id_not_confluent() {
        let data = [0x01, 0x00, 0x00, 0x00, 0x01];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), None);
    }

    #[test]
    fn test_extract_confluent_id_too_short() {
        let data = [0x00, 0x00];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), None);
    }

    #[test]
    fn test_new_deserializer() {
        let deser = AvroDeserializer::new();
        assert!(deser.schema_registry.is_none());
        assert!(deser.known_ids.is_empty());
    }

    #[test]
    fn test_default_matches_new() {
        let deser = AvroDeserializer::default();
        assert!(deser.schema_registry.is_none());
        assert!(deser.known_ids.is_empty());
    }

    #[test]
    fn test_register_schema() {
        let mut deser = AvroDeserializer::new();
        let result = deser.register_schema(1, TEST_AVRO_SCHEMA);
        assert!(result.is_ok());
        assert!(deser.known_ids.contains(&1));
    }

    #[test]
    fn test_ensure_schema_registered_known_id_is_noop() {
        let mut deser = AvroDeserializer::new();
        assert!(deser.register_schema(1, TEST_AVRO_SCHEMA).is_ok());
        // Already known: must succeed immediately without touching a registry.
        let outcome = poll_once(deser.ensure_schema_registered(1));
        assert!(matches!(outcome, Poll::Ready(Ok(()))));
        assert_eq!(deser.known_ids.len(), 1);
    }

    #[test]
    fn test_ensure_schema_registered_without_registry() {
        let mut deser = AvroDeserializer::new();
        let outcome = poll_once(deser.ensure_schema_registered(7));
        assert!(matches!(
            outcome,
            Poll::Ready(Err(SerdeError::SchemaNotFound { schema_id: 7 }))
        ));
        assert!(!deser.known_ids.contains(&7));
    }

    #[test]
    fn test_format() {
        let deser = AvroDeserializer::new();
        assert_eq!(deser.format(), Format::Avro);
    }

    #[test]
    fn test_debug_output() {
        let deser = AvroDeserializer::new();
        let rendered = format!("{deser:?}");
        assert!(rendered.starts_with("AvroDeserializer"));
        assert!(rendered.contains("known_ids"));
        assert!(rendered.contains("has_registry: false"));
    }

    #[test]
    fn test_deserialize_empty_batch() {
        let deser = AvroDeserializer::new();
        let schema = test_arrow_schema();
        let result = deser.deserialize_batch(&[], &schema);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().num_rows(), 0);
    }

    #[test]
    fn test_extract_confluent_id_large() {
        // Schema ID 42 followed by one payload byte.
        let mut data = vec![0x00u8];
        data.extend_from_slice(&42i32.to_be_bytes());
        data.push(0xFF);
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(42));
    }

    #[test]
    fn test_extract_confluent_id_edge_cases() {
        // Empty slice
        assert_eq!(AvroDeserializer::extract_confluent_id(&[]), None);

        // Magic byte only (too short)
        assert_eq!(AvroDeserializer::extract_confluent_id(&[0x00]), None);

        // 4 bytes with magic (too short)
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&[0x00, 0x00, 0x00, 0x00]),
            None
        );

        // Exactly 5 bytes (boundary of CONFLUENT_HEADER_SIZE)
        let data = [0x00, 0x00, 0x00, 0x00, 0x05];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(5));

        // i32::MAX ID
        let mut data = vec![0x00];
        data.extend_from_slice(&i32::MAX.to_be_bytes());
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&data),
            Some(i32::MAX)
        );

        // i32::MIN ID
        let mut data = vec![0x00];
        data.extend_from_slice(&i32::MIN.to_be_bytes());
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&data),
            Some(i32::MIN)
        );
    }
}