// laminar_connectors/kafka/avro.rs

use std::collections::HashSet;
use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_avro::reader::ReaderBuilder;
use arrow_avro::schema::{AvroSchema, Fingerprint, FingerprintAlgorithm, SchemaStore};
use arrow_schema::SchemaRef;

use crate::error::SerdeError;
use crate::kafka::schema_registry::SchemaRegistryClient;
use crate::serde::{Format, RecordDeserializer};
/// Leading magic byte of the Confluent wire format: every Schema
/// Registry-framed payload starts with `0x00`.
const CONFLUENT_MAGIC: u8 = 0x00;

/// Total Confluent wire-format header length: 1 magic byte followed by a
/// 4-byte big-endian schema id.
const CONFLUENT_HEADER_SIZE: usize = 5;
/// Deserializes Confluent wire-format Avro payloads into Arrow record batches.
///
/// Writer schemas are looked up by Confluent schema id from an internal
/// [`SchemaStore`]; unknown ids can be fetched lazily from a Schema Registry
/// when a client is configured.
pub struct AvroDeserializer {
    // Writer schemas keyed by Confluent schema id (FingerprintAlgorithm::Id).
    schema_store: SchemaStore,
    // Optional client used to resolve schema ids not yet in the store.
    schema_registry: Option<Arc<SchemaRegistryClient>>,
    // Ids already registered in `schema_store`, to skip redundant lookups.
    known_ids: HashSet<i32>,
}
39
40impl AvroDeserializer {
41 #[must_use]
45 pub fn new() -> Self {
46 Self {
47 schema_store: SchemaStore::new_with_type(FingerprintAlgorithm::Id),
48 schema_registry: None,
49 known_ids: HashSet::new(),
50 }
51 }
52
53 #[must_use]
58 pub fn with_schema_registry(registry: Arc<SchemaRegistryClient>) -> Self {
59 Self {
60 schema_store: SchemaStore::new_with_type(FingerprintAlgorithm::Id),
61 schema_registry: Some(registry),
62 known_ids: HashSet::new(),
63 }
64 }
65
66 #[allow(clippy::cast_sign_loss)]
72 pub fn register_schema(
73 &mut self,
74 schema_id: i32,
75 avro_schema_json: &str,
76 ) -> Result<(), SerdeError> {
77 let avro_schema = AvroSchema::new(avro_schema_json.to_string());
78 let fp = Fingerprint::Id(schema_id as u32);
81 self.schema_store
82 .set(fp, avro_schema)
83 .map_err(|e| SerdeError::MalformedInput(format!("failed to register schema: {e}")))?;
84 self.known_ids.insert(schema_id);
85 Ok(())
86 }
87
88 pub async fn ensure_schema_registered(&mut self, schema_id: i32) -> Result<(), SerdeError> {
97 if self.known_ids.contains(&schema_id) {
98 return Ok(());
99 }
100
101 let registry = self
102 .schema_registry
103 .as_ref()
104 .ok_or(SerdeError::SchemaNotFound { schema_id })?;
105
106 let cached = registry
107 .resolve_confluent_id(schema_id)
108 .await
109 .map_err(|_| SerdeError::SchemaNotFound { schema_id })?;
110
111 self.register_schema(schema_id, &cached.schema_str)?;
112 Ok(())
113 }
114
115 #[must_use]
119 pub fn extract_confluent_id(data: &[u8]) -> Option<i32> {
120 if data.len() < CONFLUENT_HEADER_SIZE || data[0] != CONFLUENT_MAGIC {
121 return None;
122 }
123 let id = i32::from_be_bytes([data[1], data[2], data[3], data[4]]);
124 Some(id)
125 }
126}
127
128impl Default for AvroDeserializer {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl RecordDeserializer for AvroDeserializer {
135 fn deserialize(&self, data: &[u8], schema: &SchemaRef) -> Result<RecordBatch, SerdeError> {
136 self.deserialize_batch(&[data], schema)
137 }
138
139 fn deserialize_batch(
140 &self,
141 records: &[&[u8]],
142 schema: &SchemaRef,
143 ) -> Result<RecordBatch, SerdeError> {
144 if records.is_empty() {
145 return Ok(RecordBatch::new_empty(schema.clone()));
146 }
147
148 let mut decoder = ReaderBuilder::new()
149 .with_batch_size(records.len())
150 .with_writer_schema_store(self.schema_store.clone())
151 .build_decoder()
152 .map_err(|e| SerdeError::MalformedInput(format!("failed to build decoder: {e}")))?;
153
154 for record in records {
155 let mut offset = 0;
156 while offset < record.len() {
157 let consumed = decoder
158 .decode(&record[offset..])
159 .map_err(|e| SerdeError::MalformedInput(format!("Avro decode error: {e}")))?;
160 if consumed == 0 {
161 break;
162 }
163 offset += consumed;
164 }
165 }
166
167 decoder
168 .flush()
169 .map_err(|e| SerdeError::MalformedInput(format!("Avro flush error: {e}")))?
170 .ok_or_else(|| SerdeError::MalformedInput("no records decoded".into()))
171 }
172
173 fn format(&self) -> Format {
174 Format::Avro
175 }
176
177 fn as_any_mut(&mut self) -> Option<&mut dyn std::any::Any> {
178 Some(self)
179 }
180}
181
182impl std::fmt::Debug for AvroDeserializer {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 f.debug_struct("AvroDeserializer")
185 .field("known_ids", &self.known_ids)
186 .field("has_registry", &self.schema_registry.is_some())
187 .finish_non_exhaustive()
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    #[allow(dead_code)]
    const TEST_AVRO_SCHEMA: &str = r#"{
        "type": "record",
        "name": "test",
        "fields": [
            {"name": "id", "type": "long"},
            {"name": "name", "type": "string"}
        ]
    }"#;

    /// Arrow schema matching `TEST_AVRO_SCHEMA`.
    fn test_arrow_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    #[test]
    fn test_extract_confluent_id() {
        let framed = [0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03];
        assert_eq!(AvroDeserializer::extract_confluent_id(&framed), Some(1));

        let framed = [0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x03];
        assert_eq!(AvroDeserializer::extract_confluent_id(&framed), Some(256));
    }

    #[test]
    fn test_extract_confluent_id_not_confluent() {
        // Wrong magic byte: must not be treated as Confluent-framed.
        let framed = [0x01, 0x00, 0x00, 0x00, 0x01];
        assert_eq!(AvroDeserializer::extract_confluent_id(&framed), None);
    }

    #[test]
    fn test_extract_confluent_id_too_short() {
        assert_eq!(AvroDeserializer::extract_confluent_id(&[0x00, 0x00]), None);
    }

    #[test]
    fn test_new_deserializer() {
        let deser = AvroDeserializer::new();
        assert!(deser.schema_registry.is_none());
        assert!(deser.known_ids.is_empty());
    }

    #[test]
    fn test_register_schema() {
        let mut deser = AvroDeserializer::new();
        assert!(deser.register_schema(1, TEST_AVRO_SCHEMA).is_ok());
        assert!(deser.known_ids.contains(&1));
    }

    #[test]
    fn test_format() {
        assert_eq!(AvroDeserializer::new().format(), Format::Avro);
    }

    #[test]
    fn test_deserialize_empty_batch() {
        let deser = AvroDeserializer::new();
        let batch = deser
            .deserialize_batch(&[], &test_arrow_schema())
            .expect("empty input must produce an empty batch");
        assert_eq!(batch.num_rows(), 0);
    }

    #[test]
    fn test_extract_confluent_id_large() {
        // Magic byte + big-endian id + trailing payload byte.
        let mut framed = vec![0x00u8];
        framed.extend_from_slice(&42i32.to_be_bytes());
        framed.push(0xFF);
        assert_eq!(AvroDeserializer::extract_confluent_id(&framed), Some(42));
    }

    #[test]
    fn test_extract_confluent_id_edge_cases() {
        // Anything shorter than the 5-byte header is rejected.
        assert_eq!(AvroDeserializer::extract_confluent_id(&[]), None);
        assert_eq!(AvroDeserializer::extract_confluent_id(&[0x00]), None);
        assert_eq!(
            AvroDeserializer::extract_confluent_id(&[0x00, 0x00, 0x00, 0x00]),
            None
        );

        // Exactly header-sized payloads are accepted.
        let framed = [0x00, 0x00, 0x00, 0x00, 0x05];
        assert_eq!(AvroDeserializer::extract_confluent_id(&framed), Some(5));

        // Extreme id values round-trip through the big-endian encoding.
        for id in [i32::MAX, i32::MIN] {
            let mut framed = vec![0x00];
            framed.extend_from_slice(&id.to_be_bytes());
            assert_eq!(AvroDeserializer::extract_confluent_id(&framed), Some(id));
        }
    }
}