// laminar_connectors/kafka/avro.rs

use std::collections::HashSet;
use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_avro::reader::{Decoder, ReaderBuilder};
use arrow_avro::schema::{AvroSchema, Fingerprint, FingerprintAlgorithm, SchemaStore};
use arrow_schema::SchemaRef;
use parking_lot::Mutex;

use crate::error::SerdeError;
use crate::kafka::schema_registry::SchemaRegistryClient;
use crate::serde::{Format, RecordDeserializer};
/// Row capacity of each decoded Arrow batch; the decoder flushes when full.
const DECODER_BATCH_CAPACITY: usize = 8192;

/// First byte of every Confluent wire-format message.
const CONFLUENT_MAGIC: u8 = 0x00;

/// Confluent wire-format header length: 1 magic byte + 4-byte big-endian schema id.
const CONFLUENT_HEADER_SIZE: usize = 5;
/// Deserializes Confluent-framed Avro payloads into Arrow `RecordBatch`es.
pub struct AvroDeserializer {
    /// Writer schemas keyed by Confluent schema id (fingerprint algorithm `Id`).
    schema_store: SchemaStore,
    /// Optional client used to resolve schema ids that have not been seen yet.
    schema_registry: Option<Arc<SchemaRegistryClient>>,
    /// Schema ids already registered in `schema_store`.
    known_ids: HashSet<i32>,
    /// Lazily-built decoder; reset to `None` whenever a new schema is registered.
    decoder: Mutex<Option<Decoder>>,
}
44
45impl AvroDeserializer {
46 #[must_use]
50 pub fn new() -> Self {
51 Self {
52 schema_store: SchemaStore::new_with_type(FingerprintAlgorithm::Id),
53 schema_registry: None,
54 known_ids: HashSet::new(),
55 decoder: Mutex::new(None),
56 }
57 }
58
59 #[must_use]
64 pub fn with_schema_registry(registry: Arc<SchemaRegistryClient>) -> Self {
65 Self {
66 schema_store: SchemaStore::new_with_type(FingerprintAlgorithm::Id),
67 schema_registry: Some(registry),
68 known_ids: HashSet::new(),
69 decoder: Mutex::new(None),
70 }
71 }
72
73 #[allow(clippy::cast_sign_loss)]
79 pub fn register_schema(
80 &mut self,
81 schema_id: i32,
82 avro_schema_json: &str,
83 ) -> Result<(), SerdeError> {
84 let avro_schema = AvroSchema::new(avro_schema_json.to_string());
85 let fp = Fingerprint::Id(schema_id as u32);
88 self.schema_store
89 .set(fp, avro_schema)
90 .map_err(|e| SerdeError::MalformedInput(format!("failed to register schema: {e}")))?;
91 self.known_ids.insert(schema_id);
92 *self.decoder.lock() = None;
95 Ok(())
96 }
97
98 pub async fn ensure_schema_registered(&mut self, schema_id: i32) -> Result<bool, SerdeError> {
109 if self.known_ids.contains(&schema_id) {
110 return Ok(false);
111 }
112
113 let registry = self
114 .schema_registry
115 .as_ref()
116 .ok_or(SerdeError::SchemaNotFound { schema_id })?;
117
118 let cached = registry
119 .resolve_confluent_id(schema_id)
120 .await
121 .map_err(|_| SerdeError::SchemaNotFound { schema_id })?;
122
123 self.register_schema(schema_id, &cached.schema_str)?;
124 Ok(true)
125 }
126
127 #[must_use]
131 pub fn extract_confluent_id(data: &[u8]) -> Option<i32> {
132 if data.len() < CONFLUENT_HEADER_SIZE || data[0] != CONFLUENT_MAGIC {
133 return None;
134 }
135 let id = i32::from_be_bytes([data[1], data[2], data[3], data[4]]);
136 Some(id)
137 }
138}
139
140impl Default for AvroDeserializer {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl RecordDeserializer for AvroDeserializer {
147 fn deserialize(&self, data: &[u8], schema: &SchemaRef) -> Result<RecordBatch, SerdeError> {
148 self.deserialize_batch(&[data], schema)
149 }
150
151 fn deserialize_batch(
152 &self,
153 records: &[&[u8]],
154 schema: &SchemaRef,
155 ) -> Result<RecordBatch, SerdeError> {
156 if records.is_empty() {
157 return Ok(RecordBatch::new_empty(schema.clone()));
158 }
159
160 let mut guard = self.decoder.lock();
161 let decoder = if let Some(d) = guard.as_mut() {
162 d
163 } else {
164 let d = ReaderBuilder::new()
165 .with_batch_size(DECODER_BATCH_CAPACITY)
166 .with_writer_schema_store(self.schema_store.clone())
167 .build_decoder()
168 .map_err(|e| SerdeError::MalformedInput(format!("failed to build decoder: {e}")))?;
169 guard.insert(d)
170 };
171
172 let mut partials: Vec<RecordBatch> = Vec::new();
173 for record in records {
174 let mut offset = 0;
175 while offset < record.len() {
176 let consumed = decoder
177 .decode(&record[offset..])
178 .map_err(|e| SerdeError::MalformedInput(format!("Avro decode error: {e}")))?;
179 if consumed == 0 {
180 break;
181 }
182 offset += consumed;
183 if decoder.batch_is_full() {
184 if let Some(b) = decoder
185 .flush()
186 .map_err(|e| SerdeError::MalformedInput(format!("Avro flush: {e}")))?
187 {
188 partials.push(b);
189 }
190 }
191 }
192 }
193 if let Some(b) = decoder
194 .flush()
195 .map_err(|e| SerdeError::MalformedInput(format!("Avro flush: {e}")))?
196 {
197 partials.push(b);
198 }
199
200 match partials.len() {
201 0 => Err(SerdeError::MalformedInput("no records decoded".into())),
202 1 => Ok(partials.pop().unwrap()),
203 _ => arrow_select::concat::concat_batches(schema, &partials)
204 .map_err(|e| SerdeError::MalformedInput(format!("concat: {e}"))),
205 }
206 }
207
208 fn format(&self) -> Format {
209 Format::Avro
210 }
211
212 fn as_any_mut(&mut self) -> Option<&mut dyn std::any::Any> {
213 Some(self)
214 }
215}
216
217impl std::fmt::Debug for AvroDeserializer {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 f.debug_struct("AvroDeserializer")
220 .field("known_ids", &self.known_ids)
221 .field("has_registry", &self.schema_registry.is_some())
222 .finish_non_exhaustive()
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    // Minimal two-field Avro record schema shared by the registration tests.
    const TEST_AVRO_SCHEMA: &str = r#"{
        "type": "record",
        "name": "test",
        "fields": [
            {"name": "id", "type": "long"},
            {"name": "name", "type": "string"}
        ]
    }"#;

    // Arrow schema matching `TEST_AVRO_SCHEMA`.
    fn test_arrow_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]))
    }

    // End-to-end decode of a record whose field is a nullable map of nullable
    // doubles, checking that the registry-derived Arrow schema agrees with the
    // schema arrow-avro produces for the decoded batch.
    #[test]
    fn test_nullable_map_nullable_double_full_path() {
        // Avro zig-zag varint encoding for longs.
        fn zigzag(val: i64) -> Vec<u8> {
            let mut z = ((val << 1) ^ (val >> 63)) as u64;
            let mut buf = Vec::new();
            loop {
                if z & !0x7F == 0 {
                    buf.push(z as u8);
                    break;
                }
                buf.push((z as u8 & 0x7F) | 0x80);
                z >>= 7;
            }
            buf
        }
        // Avro string: length varint followed by the UTF-8 bytes.
        fn avro_string(s: &str) -> Vec<u8> {
            let mut b = zigzag(s.len() as i64);
            b.extend_from_slice(s.as_bytes());
            b
        }

        let avro_json = r#"{
            "type": "record",
            "name": "Metrics",
            "fields": [
                {"name": "sensor_id", "type": "string"},
                {
                    "name": "data",
                    "type": ["null", {"type": "map", "values": ["null", "double"]}]
                }
            ]
        }"#;

        let sr_schema = crate::kafka::schema_registry::avro_to_arrow_schema(avro_json)
            .expect("avro_to_arrow_schema should handle nullable map");

        let mut deser = AvroDeserializer::new();
        deser.register_schema(42, avro_json).unwrap();

        // Hand-encode one record in Avro binary format.
        let mut payload = Vec::new();
        // sensor_id = "s1"
        payload.extend_from_slice(&avro_string("s1"));
        // data: union branch 1 (the map), then one map block holding one entry
        payload.extend_from_slice(&zigzag(1));
        payload.extend_from_slice(&zigzag(1));
        payload.extend_from_slice(&avro_string("temp"));
        // entry value: union branch 1 (double), then little-endian 23.5
        payload.extend_from_slice(&zigzag(1));
        payload.extend_from_slice(&23.5_f64.to_le_bytes());
        // a zero block count terminates the map
        payload.extend_from_slice(&zigzag(0));
        // Confluent framing: magic byte + big-endian schema id 42
        let mut msg = vec![0x00u8];
        msg.extend_from_slice(&42i32.to_be_bytes());
        msg.extend_from_slice(&payload);

        let batch = deser
            .deserialize_batch(&[msg.as_slice()], &sr_schema)
            .expect("decode should succeed");
        assert_eq!(batch.num_rows(), 1);

        let sr_data = sr_schema.field_with_name("data").unwrap();
        let decoded_schema = batch.schema();
        let dec_data = decoded_schema.field_with_name("data").unwrap();
        assert_eq!(
            sr_data.data_type(),
            dec_data.data_type(),
            "schema registry and arrow-avro must produce identical Map types"
        );
    }

    #[test]
    fn test_extract_confluent_id() {
        // id = 1, with trailing payload bytes.
        let data = [0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(1));

        // id = 256 (multi-byte big-endian value).
        let data = [0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x03];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(256));
    }

    #[test]
    fn test_extract_confluent_id_not_confluent() {
        // First byte is not the Confluent magic byte.
        let data = [0x01, 0x00, 0x00, 0x00, 0x01];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), None);
    }

    #[test]
    fn test_extract_confluent_id_too_short() {
        // Fewer than CONFLUENT_HEADER_SIZE bytes.
        let data = [0x00, 0x00];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), None);
    }

    #[test]
    fn test_new_deserializer() {
        let deser = AvroDeserializer::new();
        assert!(deser.schema_registry.is_none());
        assert!(deser.known_ids.is_empty());
    }

    #[test]
    fn test_register_schema() {
        let mut deser = AvroDeserializer::new();
        let result = deser.register_schema(1, TEST_AVRO_SCHEMA);
        assert!(result.is_ok());
        assert!(deser.known_ids.contains(&1));
    }

    #[test]
    fn test_format() {
        let deser = AvroDeserializer::new();
        assert_eq!(deser.format(), Format::Avro);
    }

    #[test]
    fn test_deserialize_empty_batch() {
        // Empty input short-circuits to an empty batch, no decoder needed.
        let deser = AvroDeserializer::new();
        let schema = test_arrow_schema();
        let result = deser.deserialize_batch(&[], &schema);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().num_rows(), 0);
    }

    #[test]
    fn test_extract_confluent_id_large() {
        // Trailing bytes after the header must not affect the extracted id.
        let mut data = vec![0x00u8];
        data.extend_from_slice(&42i32.to_be_bytes());
        data.push(0xFF);
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(42));
    }

    #[test]
    fn test_extract_confluent_id_edge_cases() {
        // Empty input.
        assert_eq!(AvroDeserializer::extract_confluent_id(&[]), None);

        // Magic byte only.
        assert_eq!(AvroDeserializer::extract_confluent_id(&[0x00]), None);

        // One byte short of a full header.
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&[0x00, 0x00, 0x00, 0x00]),
            None
        );

        // Exactly header-sized input.
        let data = [0x00, 0x00, 0x00, 0x00, 0x05];
        assert_eq!(AvroDeserializer::extract_confluent_id(&data), Some(5));

        // Extreme positive id.
        let mut data = vec![0x00];
        data.extend_from_slice(&i32::MAX.to_be_bytes());
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&data),
            Some(i32::MAX)
        );

        // Negative id (sign bit set) round-trips through from_be_bytes.
        let mut data = vec![0x00];
        data.extend_from_slice(&i32::MIN.to_be_bytes());
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&data),
            Some(i32::MIN)
        );
    }
}