1use std::time::{Duration, Instant};
8
9use foyer::{Cache, CacheBuilder};
10
11use arrow_schema::{DataType, SchemaRef};
12use reqwest::Client;
13use serde::{Deserialize, Serialize};
14
15use crate::error::{ConnectorError, SerdeError};
16use crate::kafka::config::{CompatibilityLevel, SrAuth};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq)]
20pub enum SchemaType {
21 Avro,
23 Protobuf,
25 Json,
27}
28
29impl std::str::FromStr for SchemaType {
30 type Err = ConnectorError;
31
32 fn from_str(s: &str) -> Result<Self, Self::Err> {
33 match s.to_uppercase().as_str() {
34 "AVRO" => Ok(SchemaType::Avro),
35 "PROTOBUF" => Ok(SchemaType::Protobuf),
36 "JSON" => Ok(SchemaType::Json),
37 other => Err(ConnectorError::ConfigurationError(format!(
38 "unknown schema type: '{other}'"
39 ))),
40 }
41 }
42}
43
44impl std::fmt::Display for SchemaType {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 SchemaType::Avro => write!(f, "AVRO"),
48 SchemaType::Protobuf => write!(f, "PROTOBUF"),
49 SchemaType::Json => write!(f, "JSON"),
50 }
51 }
52}
53
/// Sizing and expiry settings for the schema-id cache of a Schema Registry client.
#[derive(Debug, Clone)]
pub struct SchemaRegistryCacheConfig {
    /// Capacity handed to the cache builder.
    pub max_entries: usize,
    /// Age after which a cached schema is treated as stale; `None` disables expiry.
    pub ttl: Option<Duration>,
}

impl Default for SchemaRegistryCacheConfig {
    /// 1000 entries, one-hour TTL.
    fn default() -> Self {
        let one_hour = Duration::from_secs(60 * 60);
        Self {
            max_entries: 1_000,
            ttl: Some(one_hour),
        }
    }
}
71
72#[derive(Debug, Clone)]
74pub struct CachedSchema {
75 pub id: i32,
77 pub version: i32,
79 pub schema_type: SchemaType,
81 pub schema_str: String,
83 pub arrow_schema: SchemaRef,
85 inserted_at: Instant,
87}
88
/// Outcome of asking the registry whether a candidate schema is compatible
/// with a subject's latest version.
#[derive(Debug, Clone)]
pub struct CompatibilityResult {
    /// `true` when the registry accepted the candidate.
    pub is_compatible: bool,
    /// Explanations returned by the registry; may be empty.
    pub messages: Vec<String>,
}
97
98pub struct SchemaRegistryClient {
103 client: Client,
104 base_url: String,
105 auth: Option<SrAuth>,
106 cache: Cache<i32, CachedSchema>,
108 subject_cache: Cache<String, CachedSchema>,
110 cache_config: SchemaRegistryCacheConfig,
112}
113
114#[derive(Deserialize)]
117struct SchemaByIdResponse {
118 schema: String,
119 #[serde(default = "default_schema_type")]
120 #[serde(rename = "schemaType")]
121 schema_type: String,
122}
123
124#[derive(Deserialize)]
125struct SchemaVersionResponse {
126 id: i32,
127 version: i32,
128 schema: String,
129 #[serde(default = "default_schema_type")]
130 #[serde(rename = "schemaType")]
131 schema_type: String,
132}
133
134#[derive(Deserialize)]
135struct CompatibilityResponse {
136 is_compatible: bool,
137 #[serde(default)]
138 messages: Vec<String>,
139}
140
141#[derive(Deserialize)]
142struct ConfigResponse {
143 #[serde(rename = "compatibilityLevel")]
144 compatibility_level: String,
145}
146
147#[derive(Serialize)]
148struct CompatibilityRequest {
149 schema: String,
150 #[serde(rename = "schemaType")]
151 schema_type: String,
152}
153
154#[derive(Serialize)]
155struct ConfigUpdateRequest {
156 compatibility: String,
157}
158
159#[derive(Serialize)]
160struct RegisterSchemaRequest {
161 schema: String,
162 #[serde(rename = "schemaType")]
163 schema_type: String,
164}
165
166#[derive(Deserialize)]
167struct RegisterSchemaResponse {
168 id: i32,
169}
170
171fn default_schema_type() -> String {
172 "AVRO".to_string()
173}
174
175impl SchemaRegistryClient {
176 #[must_use]
178 pub fn new(base_url: impl Into<String>, auth: Option<SrAuth>) -> Self {
179 Self::with_cache_config(base_url, auth, SchemaRegistryCacheConfig::default())
180 }
181
182 pub fn with_tls(
188 base_url: impl Into<String>,
189 auth: Option<SrAuth>,
190 ca_cert_path: &str,
191 ) -> Result<Self, ConnectorError> {
192 Self::with_tls_mtls(base_url, auth, ca_cert_path, None, None)
193 }
194
195 pub fn with_tls_mtls(
202 base_url: impl Into<String>,
203 auth: Option<SrAuth>,
204 ca_cert_path: &str,
205 client_cert_path: Option<&str>,
206 client_key_path: Option<&str>,
207 ) -> Result<Self, ConnectorError> {
208 let pem = std::fs::read(ca_cert_path).map_err(|e| {
209 ConnectorError::ConfigurationError(format!(
210 "failed to read SR CA cert at '{ca_cert_path}': {e}"
211 ))
212 })?;
213 let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
214 ConnectorError::ConfigurationError(format!(
215 "invalid PEM CA cert at '{ca_cert_path}': {e}"
216 ))
217 })?;
218
219 let mut builder = Client::builder().add_root_certificate(cert);
220
221 if client_cert_path.is_some() != client_key_path.is_some() {
222 return Err(ConnectorError::ConfigurationError(
223 "mTLS requires both client cert and key — only one was provided".into(),
224 ));
225 }
226 if let (Some(cert_path), Some(key_path)) = (client_cert_path, client_key_path) {
227 let mut identity_pem = std::fs::read(cert_path).map_err(|e| {
228 ConnectorError::ConfigurationError(format!(
229 "failed to read SR client cert at '{cert_path}': {e}"
230 ))
231 })?;
232 let key_pem = std::fs::read(key_path).map_err(|e| {
233 ConnectorError::ConfigurationError(format!(
234 "failed to read SR client key at '{key_path}': {e}"
235 ))
236 })?;
237 identity_pem.extend_from_slice(&key_pem);
239 let identity = reqwest::tls::Identity::from_pem(&identity_pem).map_err(|e| {
240 ConnectorError::ConfigurationError(format!("invalid client cert/key PEM: {e}"))
241 })?;
242 builder = builder.identity(identity);
243 }
244
245 let client = builder.build().map_err(|e| {
246 ConnectorError::ConfigurationError(format!("failed to build TLS client: {e}"))
247 })?;
248
249 let cache_config = SchemaRegistryCacheConfig::default();
250 let cache = CacheBuilder::new(cache_config.max_entries)
251 .with_shards(4)
252 .build();
253 let subject_cache = CacheBuilder::new(256).with_shards(4).build();
254 Ok(Self {
255 client,
256 base_url: base_url.into().trim_end_matches('/').to_string(),
257 auth,
258 cache,
259 subject_cache,
260 cache_config,
261 })
262 }
263
264 #[must_use]
266 pub fn with_cache_config(
267 base_url: impl Into<String>,
268 auth: Option<SrAuth>,
269 cache_config: SchemaRegistryCacheConfig,
270 ) -> Self {
271 let cache = CacheBuilder::new(cache_config.max_entries)
272 .with_shards(4)
273 .build();
274 let subject_cache = CacheBuilder::new(256).with_shards(4).build();
276 Self {
277 client: Client::new(),
278 base_url: base_url.into().trim_end_matches('/').to_string(),
279 auth,
280 cache,
281 subject_cache,
282 cache_config,
283 }
284 }
285
286 #[must_use]
288 pub fn base_url(&self) -> &str {
289 &self.base_url
290 }
291
292 #[must_use]
294 pub fn has_auth(&self) -> bool {
295 self.auth.is_some()
296 }
297
298 #[must_use]
300 pub fn cache_config(&self) -> &SchemaRegistryCacheConfig {
301 &self.cache_config
302 }
303
304 fn cache_insert(&self, id: i32, mut schema: CachedSchema) {
308 schema.inserted_at = Instant::now();
309 self.cache.insert(id, schema);
310 }
311
312 fn cache_get(&self, id: i32) -> Option<CachedSchema> {
317 let entry = self.cache.get(&id)?;
318 let schema = entry.value();
319 if let Some(ttl) = self.cache_config.ttl {
320 if schema.inserted_at.elapsed() > ttl {
321 drop(entry);
322 self.cache.remove(&id);
323 return None;
324 }
325 }
326 Some(schema.clone())
328 }
329
330 pub async fn get_schema_by_id(&self, id: i32) -> Result<CachedSchema, ConnectorError> {
339 if let Some(cached) = self.cache_get(id) {
340 return Ok(cached);
341 }
342
343 let url = format!("{}/schemas/ids/{}", self.base_url, id);
344 let resp: SchemaByIdResponse = self.get_json(&url).await?;
345
346 let schema_type: SchemaType = resp.schema_type.parse()?;
347 let arrow_schema = schema_to_arrow(schema_type, &resp.schema)?;
348
349 let cached = CachedSchema {
350 id,
351 version: 0, schema_type,
353 schema_str: resp.schema,
354 arrow_schema,
355 inserted_at: Instant::now(),
356 };
357 self.cache_insert(id, cached.clone());
358 Ok(cached)
359 }
360
361 pub async fn get_latest_schema(&self, subject: &str) -> Result<CachedSchema, ConnectorError> {
367 let url = format!("{}/subjects/{}/versions/latest", self.base_url, subject);
368 let resp: SchemaVersionResponse = self.get_json(&url).await?;
369
370 let schema_type: SchemaType = resp.schema_type.parse()?;
371 let arrow_schema = schema_to_arrow(schema_type, &resp.schema)?;
372
373 let cached = CachedSchema {
374 id: resp.id,
375 version: resp.version,
376 schema_type,
377 schema_str: resp.schema,
378 arrow_schema,
379 inserted_at: Instant::now(),
380 };
381
382 self.cache_insert(resp.id, cached.clone());
383 self.subject_cache
384 .insert(subject.to_string(), cached.clone());
385 Ok(cached)
386 }
387
388 pub async fn get_schema_version(
394 &self,
395 subject: &str,
396 version: i32,
397 ) -> Result<CachedSchema, ConnectorError> {
398 let url = format!(
399 "{}/subjects/{}/versions/{}",
400 self.base_url, subject, version
401 );
402 let resp: SchemaVersionResponse = self.get_json(&url).await?;
403
404 let schema_type: SchemaType = resp.schema_type.parse()?;
405 let arrow_schema = schema_to_arrow(schema_type, &resp.schema)?;
406
407 let cached = CachedSchema {
408 id: resp.id,
409 version: resp.version,
410 schema_type,
411 schema_str: resp.schema,
412 arrow_schema,
413 inserted_at: Instant::now(),
414 };
415 self.cache_insert(resp.id, cached.clone());
416 Ok(cached)
417 }
418
419 pub async fn check_compatibility(
425 &self,
426 subject: &str,
427 schema_str: &str,
428 ) -> Result<CompatibilityResult, ConnectorError> {
429 let url = format!(
430 "{}/compatibility/subjects/{}/versions/latest",
431 self.base_url, subject
432 );
433
434 let body = CompatibilityRequest {
435 schema: schema_str.to_string(),
436 schema_type: "AVRO".to_string(),
437 };
438
439 let mut req = self.client.post(&url).json(&body);
440 if let Some(ref auth) = self.auth {
441 req = req.basic_auth(&auth.username, Some(&auth.password));
442 }
443
444 let resp = req
445 .send()
446 .await
447 .map_err(|e| ConnectorError::ConnectionFailed(format!("schema registry: {e}")))?;
448
449 if !resp.status().is_success() {
450 let status = resp.status();
451 let text = resp.text().await.unwrap_or_default();
452 return Err(ConnectorError::ConnectionFailed(format!(
453 "schema registry compatibility check failed: {status} {text}"
454 )));
455 }
456
457 let result: CompatibilityResponse = resp.json().await.map_err(|e| {
458 ConnectorError::Internal(format!("failed to parse compatibility response: {e}"))
459 })?;
460
461 Ok(CompatibilityResult {
462 is_compatible: result.is_compatible,
463 messages: result.messages,
464 })
465 }
466
467 pub async fn get_compatibility_level(
473 &self,
474 subject: &str,
475 ) -> Result<CompatibilityLevel, ConnectorError> {
476 let url = format!("{}/config/{}", self.base_url, subject);
477 let resp: ConfigResponse = self.get_json(&url).await?;
478 resp.compatibility_level.parse()
479 }
480
481 pub async fn set_compatibility_level(
487 &self,
488 subject: &str,
489 level: CompatibilityLevel,
490 ) -> Result<(), ConnectorError> {
491 let url = format!("{}/config/{}", self.base_url, subject);
492 let body = ConfigUpdateRequest {
493 compatibility: level.as_str().to_string(),
494 };
495
496 let mut req = self.client.put(&url).json(&body);
497 if let Some(ref auth) = self.auth {
498 req = req.basic_auth(&auth.username, Some(&auth.password));
499 }
500
501 let resp = req
502 .send()
503 .await
504 .map_err(|e| ConnectorError::ConnectionFailed(format!("schema registry: {e}")))?;
505
506 if !resp.status().is_success() {
507 let status = resp.status();
508 let text = resp.text().await.unwrap_or_default();
509 return Err(ConnectorError::ConnectionFailed(format!(
510 "schema registry config update failed: {status} {text}"
511 )));
512 }
513
514 Ok(())
515 }
516
517 pub async fn resolve_confluent_id(&self, id: i32) -> Result<CachedSchema, ConnectorError> {
526 self.get_schema_by_id(id).await
527 }
528
529 pub async fn register_schema(
539 &self,
540 subject: &str,
541 schema_str: &str,
542 schema_type: SchemaType,
543 ) -> Result<i32, ConnectorError> {
544 if let Some(entry) = self.subject_cache.get(subject) {
546 if entry.value().schema_str == schema_str {
547 return Ok(entry.value().id);
548 }
549 }
550
551 let url = format!("{}/subjects/{}/versions", self.base_url, subject);
552 let body = RegisterSchemaRequest {
553 schema: schema_str.to_string(),
554 schema_type: schema_type.to_string(),
555 };
556
557 let mut req = self.client.post(&url).json(&body);
558 if let Some(ref auth) = self.auth {
559 req = req.basic_auth(&auth.username, Some(&auth.password));
560 }
561
562 let resp = req
563 .send()
564 .await
565 .map_err(|e| ConnectorError::ConnectionFailed(format!("schema registry: {e}")))?;
566
567 if !resp.status().is_success() {
568 let status = resp.status();
569 let text = resp.text().await.unwrap_or_default();
570 return Err(ConnectorError::ConnectionFailed(format!(
571 "schema registry register failed: {status} {text}"
572 )));
573 }
574
575 let result: RegisterSchemaResponse = resp.json().await.map_err(|e| {
576 ConnectorError::Internal(format!("failed to parse register schema response: {e}"))
577 })?;
578
579 let arrow_schema = avro_to_arrow_schema(schema_str)?;
580 let cached = CachedSchema {
581 id: result.id,
582 version: 0,
583 schema_type,
584 schema_str: schema_str.to_string(),
585 arrow_schema,
586 inserted_at: Instant::now(),
587 };
588 self.cache_insert(result.id, cached.clone());
589 self.subject_cache.insert(subject.to_string(), cached);
590
591 Ok(result.id)
592 }
593
594 pub async fn validate_and_register_schema(
605 &self,
606 subject: &str,
607 schema_str: &str,
608 schema_type: SchemaType,
609 ) -> Result<i32, ConnectorError> {
610 match self.check_compatibility(subject, schema_str).await {
612 Ok(result) => {
613 if !result.is_compatible {
614 let message = if result.messages.is_empty() {
615 "new schema is not compatible with existing version".to_string()
616 } else {
617 result.messages.join("; ")
618 };
619 return Err(ConnectorError::Serde(SerdeError::SchemaIncompatible {
620 subject: subject.to_string(),
621 message,
622 }));
623 }
624 }
625 Err(ConnectorError::ConnectionFailed(msg)) if msg.contains("404") => {
626 }
628 Err(e) => return Err(e),
629 }
630
631 self.register_schema(subject, schema_str, schema_type).await
632 }
633
634 #[must_use]
636 pub fn is_cached(&self, id: i32) -> bool {
637 self.cache.contains(&id)
638 }
639
640 #[must_use]
642 pub fn cache_size(&self) -> usize {
643 self.cache.usage()
644 }
645
646 async fn get_json<T: serde::de::DeserializeOwned>(
651 &self,
652 url: &str,
653 ) -> Result<T, ConnectorError> {
654 let backoffs = [
655 std::time::Duration::from_millis(100),
656 std::time::Duration::from_millis(500),
657 ];
658 let mut last_err = None;
659
660 for (attempt, backoff) in std::iter::once(&std::time::Duration::ZERO)
661 .chain(backoffs.iter())
662 .enumerate()
663 {
664 if attempt > 0 {
665 tokio::time::sleep(*backoff).await;
666 }
667
668 let mut req = self.client.get(url);
669 if let Some(ref auth) = self.auth {
670 req = req.basic_auth(&auth.username, Some(&auth.password));
671 }
672
673 let resp = match req.send().await {
674 Ok(r) => r,
675 Err(e) => {
676 tracing::warn!(
677 attempt = attempt + 1,
678 error = %e,
679 "schema registry request failed, retrying"
680 );
681 last_err = Some(ConnectorError::ConnectionFailed(format!(
682 "schema registry: {e}"
683 )));
684 continue;
685 }
686 };
687
688 let status = resp.status();
689 if status.is_success() {
690 return resp.json::<T>().await.map_err(|e| {
691 ConnectorError::Internal(format!(
692 "failed to parse schema registry response: {e}"
693 ))
694 });
695 }
696
697 if status.is_client_error() {
699 let text = resp.text().await.unwrap_or_default();
700 return Err(ConnectorError::ConnectionFailed(format!(
701 "schema registry client error: {status} {text}"
702 )));
703 }
704
705 let text = resp.text().await.unwrap_or_default();
707 tracing::warn!(
708 attempt = attempt + 1,
709 status = %status,
710 "schema registry server error, retrying"
711 );
712 last_err = Some(ConnectorError::ConnectionFailed(format!(
713 "schema registry request failed: {status} {text}"
714 )));
715 }
716
717 Err(last_err.unwrap_or_else(|| {
718 ConnectorError::ConnectionFailed("schema registry: all retries exhausted".into())
719 }))
720 }
721}
722
723impl std::fmt::Debug for SchemaRegistryClient {
724 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
725 f.debug_struct("SchemaRegistryClient")
726 .field("base_url", &self.base_url)
727 .field("has_auth", &self.auth.is_some())
728 .field("cached_schemas", &self.cache.usage())
729 .field("cached_subjects", &self.subject_cache.usage())
730 .finish_non_exhaustive()
731 }
732}
733
734fn schema_to_arrow(schema_type: SchemaType, schema_str: &str) -> Result<SchemaRef, ConnectorError> {
738 let name = match schema_type {
739 SchemaType::Avro => return avro_to_arrow_schema(schema_str),
740 SchemaType::Json => "JSON Schema Registry",
741 SchemaType::Protobuf => "Protobuf Schema Registry",
742 };
743 Err(ConnectorError::SchemaMismatch(format!(
744 "{name} subjects are not yet supported for auto-discovery \
745 — declare columns explicitly or use an Avro subject"
746 )))
747}
748
749pub fn avro_to_arrow_schema(avro_schema_str: &str) -> Result<SchemaRef, ConnectorError> {
755 use arrow_avro::reader::ReaderBuilder;
756 use arrow_avro::schema::{AvroSchema, Fingerprint, FingerprintAlgorithm, SchemaStore};
757
758 let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id);
759 let avro_schema = AvroSchema::new(avro_schema_str.to_string());
760 let fp = Fingerprint::Id(0);
761 store
762 .set(fp, avro_schema)
763 .map_err(|e| ConnectorError::SchemaMismatch(format!("invalid Avro schema: {e}")))?;
764
765 let decoder = ReaderBuilder::new()
766 .with_writer_schema_store(store)
767 .with_active_fingerprint(fp)
768 .build_decoder()
769 .map_err(|e| ConnectorError::SchemaMismatch(format!("Avro→Arrow conversion: {e}")))?;
770
771 Ok(decoder.schema())
772}
773
774pub fn arrow_to_avro_schema(schema: &SchemaRef, record_name: &str) -> Result<String, SerdeError> {
783 let mut fields = Vec::with_capacity(schema.fields().len());
784
785 for field in schema.fields() {
786 let avro_type = arrow_to_avro_type(field.data_type())?;
787
788 let field_type = if field.is_nullable() {
789 serde_json::json!(["null", avro_type])
790 } else {
791 avro_type
792 };
793
794 fields.push(serde_json::json!({
795 "name": field.name(),
796 "type": field_type,
797 }));
798 }
799
800 let safe_name = record_name.replace('-', "_");
803
804 let schema = serde_json::json!({
805 "type": "record",
806 "name": safe_name,
807 "fields": fields,
808 });
809
810 serde_json::to_string(&schema)
811 .map_err(|e| SerdeError::MalformedInput(format!("failed to serialize Avro schema: {e}")))
812}
813
814fn arrow_to_avro_type(data_type: &DataType) -> Result<serde_json::Value, SerdeError> {
816 match data_type {
817 DataType::Null => Ok(serde_json::json!("null")),
818 DataType::Boolean => Ok(serde_json::json!("boolean")),
819 DataType::Int8
820 | DataType::Int16
821 | DataType::Int32
822 | DataType::UInt8
823 | DataType::UInt16
824 | DataType::UInt32 => Ok(serde_json::json!("int")),
825 DataType::Int64 | DataType::UInt64 => Ok(serde_json::json!("long")),
826 DataType::Float32 => Ok(serde_json::json!("float")),
827 DataType::Float64 => Ok(serde_json::json!("double")),
828 DataType::Utf8 | DataType::LargeUtf8 => Ok(serde_json::json!("string")),
829 DataType::Binary | DataType::LargeBinary => Ok(serde_json::json!("bytes")),
830 DataType::List(item_field) => {
831 let items = arrow_to_avro_type(item_field.data_type())?;
832 Ok(serde_json::json!({
833 "type": "array",
834 "items": items,
835 }))
836 }
837 DataType::Map(entries_field, _) => {
838 if let DataType::Struct(fields) = entries_field.data_type() {
840 let value_field = fields.iter().find(|f| f.name() == "value").ok_or_else(|| {
841 SerdeError::UnsupportedFormat(
842 "Arrow Map missing 'value' field in entries struct".into(),
843 )
844 })?;
845 let values = arrow_to_avro_type(value_field.data_type())?;
846 Ok(serde_json::json!({
847 "type": "map",
848 "values": values,
849 }))
850 } else {
851 Err(SerdeError::UnsupportedFormat(
852 "Arrow Map entries field is not a Struct".into(),
853 ))
854 }
855 }
856 DataType::Struct(fields) => {
857 let mut avro_fields = Vec::with_capacity(fields.len());
858 for field in fields {
859 let avro_type = arrow_to_avro_type(field.data_type())?;
860 let field_type = if field.is_nullable() {
861 serde_json::json!(["null", avro_type])
862 } else {
863 avro_type
864 };
865 avro_fields.push(serde_json::json!({
866 "name": field.name(),
867 "type": field_type,
868 }));
869 }
870 Ok(serde_json::json!({
871 "type": "record",
872 "name": "nested",
873 "fields": avro_fields,
874 }))
875 }
876 DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Utf8 => {
877 Ok(serde_json::json!({
878 "type": "enum",
879 "name": "enum_field",
880 "symbols": [],
881 }))
882 }
883 DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => {
884 Ok(serde_json::json!({"type": "long", "logicalType": "timestamp-millis"}))
885 }
886 DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _) => {
887 Ok(serde_json::json!({"type": "long", "logicalType": "timestamp-micros"}))
888 }
889 DataType::Date32 => Ok(serde_json::json!({"type": "int", "logicalType": "date"})),
890 DataType::Time32(arrow_schema::TimeUnit::Millisecond) => {
891 Ok(serde_json::json!({"type": "int", "logicalType": "time-millis"}))
892 }
893 DataType::Time64(arrow_schema::TimeUnit::Microsecond) => {
894 Ok(serde_json::json!({"type": "long", "logicalType": "time-micros"}))
895 }
896 DataType::FixedSizeBinary(size) => Ok(serde_json::json!({
897 "type": "fixed",
898 "name": "fixed_field",
899 "size": size,
900 })),
901 other => Err(SerdeError::UnsupportedFormat(format!(
902 "no Avro equivalent for Arrow type: {other}"
903 ))),
904 }
905}
906
907#[cfg(test)]
908mod tests {
909 use std::sync::Arc;
910
911 use super::*;
912 use arrow_schema::{Field, Fields, Schema};
913
914 #[test]
915 fn test_avro_to_arrow_simple_record() {
916 let avro = r#"{
917 "type": "record",
918 "name": "test",
919 "fields": [
920 {"name": "id", "type": "long"},
921 {"name": "name", "type": "string"},
922 {"name": "active", "type": "boolean"}
923 ]
924 }"#;
925
926 let schema = avro_to_arrow_schema(avro).unwrap();
927 assert_eq!(schema.fields().len(), 3);
928 assert_eq!(schema.field(0).name(), "id");
929 assert_eq!(schema.field(0).data_type(), &DataType::Int64);
930 assert!(!schema.field(0).is_nullable());
931 assert_eq!(schema.field(1).name(), "name");
932 assert_eq!(schema.field(1).data_type(), &DataType::Utf8);
933 assert_eq!(schema.field(2).name(), "active");
934 assert_eq!(schema.field(2).data_type(), &DataType::Boolean);
935 }
936
937 #[test]
938 fn test_avro_to_arrow_nullable_union() {
939 let avro = r#"{
940 "type": "record",
941 "name": "test",
942 "fields": [
943 {"name": "id", "type": "long"},
944 {"name": "email", "type": ["null", "string"]}
945 ]
946 }"#;
947
948 let schema = avro_to_arrow_schema(avro).unwrap();
949 assert_eq!(schema.fields().len(), 2);
950 assert!(!schema.field(0).is_nullable());
951 assert!(schema.field(1).is_nullable());
952 assert_eq!(schema.field(1).data_type(), &DataType::Utf8);
953 }
954
955 #[test]
956 fn test_avro_to_arrow_all_primitives() {
957 let avro = r#"{
958 "type": "record",
959 "name": "test",
960 "fields": [
961 {"name": "b", "type": "boolean"},
962 {"name": "i", "type": "int"},
963 {"name": "l", "type": "long"},
964 {"name": "f", "type": "float"},
965 {"name": "d", "type": "double"},
966 {"name": "s", "type": "string"},
967 {"name": "raw", "type": "bytes"}
968 ]
969 }"#;
970
971 let schema = avro_to_arrow_schema(avro).unwrap();
972 assert_eq!(schema.field(0).data_type(), &DataType::Boolean);
973 assert_eq!(schema.field(1).data_type(), &DataType::Int32);
974 assert_eq!(schema.field(2).data_type(), &DataType::Int64);
975 assert_eq!(schema.field(3).data_type(), &DataType::Float32);
976 assert_eq!(schema.field(4).data_type(), &DataType::Float64);
977 assert_eq!(schema.field(5).data_type(), &DataType::Utf8);
978 assert_eq!(schema.field(6).data_type(), &DataType::Binary);
979 }
980
981 #[test]
982 fn test_avro_to_arrow_invalid_json() {
983 assert!(avro_to_arrow_schema("not json").is_err());
984 }
985
986 #[test]
987 fn test_avro_to_arrow_missing_fields() {
988 let avro = r#"{"type": "record", "name": "test"}"#;
989 assert!(avro_to_arrow_schema(avro).is_err());
990 }
991
992 #[test]
993 fn schema_to_arrow_avro_works() {
994 let avro = r#"{"type":"record","name":"t","fields":[{"name":"x","type":"long"}]}"#;
995 let schema = schema_to_arrow(SchemaType::Avro, avro).unwrap();
996 assert_eq!(schema.field(0).name(), "x");
997 }
998
999 #[test]
1000 fn schema_to_arrow_json_returns_actionable_error() {
1001 let err = schema_to_arrow(SchemaType::Json, "{}").unwrap_err();
1002 assert!(
1003 err.to_string().contains("JSON Schema Registry"),
1004 "error should name the subject type, got: {err}"
1005 );
1006 }
1007
1008 #[test]
1009 fn schema_to_arrow_protobuf_returns_actionable_error() {
1010 let err = schema_to_arrow(SchemaType::Protobuf, "").unwrap_err();
1011 assert!(
1012 err.to_string().contains("Protobuf"),
1013 "error should name the subject type, got: {err}"
1014 );
1015 }
1016
1017 #[test]
1018 fn test_schema_type_parsing() {
1019 assert_eq!("AVRO".parse::<SchemaType>().unwrap(), SchemaType::Avro);
1020 assert_eq!(
1021 "PROTOBUF".parse::<SchemaType>().unwrap(),
1022 SchemaType::Protobuf
1023 );
1024 assert_eq!("JSON".parse::<SchemaType>().unwrap(), SchemaType::Json);
1025 assert!("UNKNOWN".parse::<SchemaType>().is_err());
1026 }
1027
1028 #[test]
1029 fn test_schema_type_display() {
1030 assert_eq!(SchemaType::Avro.to_string(), "AVRO");
1031 assert_eq!(SchemaType::Protobuf.to_string(), "PROTOBUF");
1032 assert_eq!(SchemaType::Json.to_string(), "JSON");
1033 }
1034
1035 #[test]
1036 fn test_client_creation() {
1037 let client = SchemaRegistryClient::new("http://localhost:8081", None);
1038 assert_eq!(client.base_url(), "http://localhost:8081");
1039 assert!(!client.has_auth());
1040 assert_eq!(client.cache_size(), 0);
1041 }
1042
1043 #[test]
1044 fn test_client_with_auth() {
1045 let auth = SrAuth {
1046 username: "user".into(),
1047 password: "pass".into(),
1048 };
1049 let client = SchemaRegistryClient::new("http://localhost:8081", Some(auth));
1050 assert!(client.has_auth());
1051 }
1052
1053 #[test]
1054 fn test_client_trailing_slash_stripped() {
1055 let client = SchemaRegistryClient::new("http://localhost:8081/", None);
1056 assert_eq!(client.base_url(), "http://localhost:8081");
1057 }
1058
1059 #[test]
1060 fn test_arrow_to_avro_schema_simple() {
1061 let schema = Arc::new(Schema::new(vec![
1062 Field::new("id", DataType::Int64, false),
1063 Field::new("name", DataType::Utf8, false),
1064 ]));
1065
1066 let avro_str = arrow_to_avro_schema(&schema, "test_record").unwrap();
1067 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1068
1069 assert_eq!(avro["type"], "record");
1070 assert_eq!(avro["name"], "test_record");
1071
1072 let fields = avro["fields"].as_array().unwrap();
1073 assert_eq!(fields.len(), 2);
1074 assert_eq!(fields[0]["name"], "id");
1075 assert_eq!(fields[0]["type"], "long");
1076 assert_eq!(fields[1]["name"], "name");
1077 assert_eq!(fields[1]["type"], "string");
1078 }
1079
1080 #[test]
1081 fn test_arrow_to_avro_schema_sanitizes_hyphens() {
1082 let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
1083
1084 let avro_str = arrow_to_avro_schema(&schema, "trades-avro-output").unwrap();
1085 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1086 assert_eq!(avro["name"], "trades_avro_output");
1087 }
1088
1089 #[test]
1090 fn test_arrow_to_avro_schema_nullable() {
1091 let schema = Arc::new(Schema::new(vec![
1092 Field::new("id", DataType::Int64, false),
1093 Field::new("email", DataType::Utf8, true),
1094 ]));
1095
1096 let avro_str = arrow_to_avro_schema(&schema, "record").unwrap();
1097 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1098
1099 let fields = avro["fields"].as_array().unwrap();
1100 assert_eq!(fields[0]["type"], "long");
1102 let union = fields[1]["type"].as_array().unwrap();
1104 assert_eq!(union.len(), 2);
1105 assert_eq!(union[0], "null");
1106 assert_eq!(union[1], "string");
1107 }
1108
1109 #[test]
1110 fn test_arrow_to_avro_all_primitives() {
1111 let schema = Arc::new(Schema::new(vec![
1112 Field::new("b", DataType::Boolean, false),
1113 Field::new("i32", DataType::Int32, false),
1114 Field::new("i64", DataType::Int64, false),
1115 Field::new("f32", DataType::Float32, false),
1116 Field::new("f64", DataType::Float64, false),
1117 Field::new("s", DataType::Utf8, false),
1118 Field::new("bin", DataType::Binary, false),
1119 ]));
1120
1121 let avro_str = arrow_to_avro_schema(&schema, "all_types").unwrap();
1122 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1123 let fields = avro["fields"].as_array().unwrap();
1124
1125 assert_eq!(fields[0]["type"], "boolean");
1126 assert_eq!(fields[1]["type"], "int");
1127 assert_eq!(fields[2]["type"], "long");
1128 assert_eq!(fields[3]["type"], "float");
1129 assert_eq!(fields[4]["type"], "double");
1130 assert_eq!(fields[5]["type"], "string");
1131 assert_eq!(fields[6]["type"], "bytes");
1132 }
1133
1134 #[test]
1135 fn test_arrow_to_avro_roundtrip() {
1136 let original = Arc::new(Schema::new(vec![
1137 Field::new("id", DataType::Int64, false),
1138 Field::new("name", DataType::Utf8, true),
1139 Field::new("active", DataType::Boolean, false),
1140 ]));
1141
1142 let avro_str = arrow_to_avro_schema(&original, "roundtrip").unwrap();
1143 let recovered = avro_to_arrow_schema(&avro_str).unwrap();
1144
1145 assert_eq!(recovered.fields().len(), 3);
1146 assert_eq!(recovered.field(0).data_type(), &DataType::Int64);
1147 assert!(!recovered.field(0).is_nullable());
1148 assert_eq!(recovered.field(1).data_type(), &DataType::Utf8);
1149 assert!(recovered.field(1).is_nullable());
1150 assert_eq!(recovered.field(2).data_type(), &DataType::Boolean);
1151 }
1152
1153 #[test]
1156 fn test_avro_to_arrow_array_type() {
1157 let avro = r#"{
1158 "type": "record",
1159 "name": "test",
1160 "fields": [
1161 {"name": "tags", "type": {"type": "array", "items": "string"}}
1162 ]
1163 }"#;
1164
1165 let schema = avro_to_arrow_schema(avro).unwrap();
1166 assert_eq!(schema.fields().len(), 1);
1167 match schema.field(0).data_type() {
1168 DataType::List(item) => {
1169 assert_eq!(item.data_type(), &DataType::Utf8);
1170 }
1171 other => panic!("expected List, got {other:?}"),
1172 }
1173 }
1174
1175 #[test]
1176 fn test_avro_to_arrow_map_type() {
1177 let avro = r#"{
1178 "type": "record",
1179 "name": "test",
1180 "fields": [
1181 {"name": "metadata", "type": {"type": "map", "values": "long"}}
1182 ]
1183 }"#;
1184
1185 let schema = avro_to_arrow_schema(avro).unwrap();
1186 assert_eq!(schema.fields().len(), 1);
1187 match schema.field(0).data_type() {
1188 DataType::Map(entries, _) => {
1189 if let DataType::Struct(fields) = entries.data_type() {
1190 assert_eq!(fields.len(), 2);
1191 assert_eq!(fields[0].name(), "key");
1192 assert_eq!(fields[0].data_type(), &DataType::Utf8);
1193 assert_eq!(fields[1].name(), "value");
1194 assert_eq!(fields[1].data_type(), &DataType::Int64);
1195 } else {
1196 panic!("expected Struct entries");
1197 }
1198 }
1199 other => panic!("expected Map, got {other:?}"),
1200 }
1201 }
1202
1203 #[test]
1204 fn test_avro_to_arrow_nested_record() {
1205 let avro = r#"{
1206 "type": "record",
1207 "name": "test",
1208 "fields": [
1209 {
1210 "name": "address",
1211 "type": {
1212 "type": "record",
1213 "name": "Address",
1214 "fields": [
1215 {"name": "street", "type": "string"},
1216 {"name": "zip", "type": "int"}
1217 ]
1218 }
1219 }
1220 ]
1221 }"#;
1222
1223 let schema = avro_to_arrow_schema(avro).unwrap();
1224 assert_eq!(schema.fields().len(), 1);
1225 match schema.field(0).data_type() {
1226 DataType::Struct(fields) => {
1227 assert_eq!(fields.len(), 2);
1228 assert_eq!(fields[0].name(), "street");
1229 assert_eq!(fields[0].data_type(), &DataType::Utf8);
1230 assert_eq!(fields[1].name(), "zip");
1231 assert_eq!(fields[1].data_type(), &DataType::Int32);
1232 }
1233 other => panic!("expected Struct, got {other:?}"),
1234 }
1235 }
1236
1237 #[test]
1238 fn test_avro_to_arrow_enum_type() {
1239 let avro = r#"{
1240 "type": "record",
1241 "name": "test",
1242 "fields": [
1243 {
1244 "name": "status",
1245 "type": {
1246 "type": "enum",
1247 "name": "Status",
1248 "symbols": ["ACTIVE", "INACTIVE", "PENDING"]
1249 }
1250 }
1251 ]
1252 }"#;
1253
1254 let schema = avro_to_arrow_schema(avro).unwrap();
1255 assert_eq!(schema.fields().len(), 1);
1256 match schema.field(0).data_type() {
1257 DataType::Dictionary(key, value) => {
1258 assert_eq!(key.as_ref(), &DataType::Int32);
1259 assert_eq!(value.as_ref(), &DataType::Utf8);
1260 }
1261 other => panic!("expected Dictionary, got {other:?}"),
1262 }
1263 }
1264
1265 #[test]
1266 fn test_avro_to_arrow_fixed_type() {
1267 let avro = r#"{
1268 "type": "record",
1269 "name": "test",
1270 "fields": [
1271 {
1272 "name": "uuid",
1273 "type": {"type": "fixed", "name": "uuid", "size": 16}
1274 }
1275 ]
1276 }"#;
1277
1278 let schema = avro_to_arrow_schema(avro).unwrap();
1279 assert_eq!(schema.fields().len(), 1);
1280 assert_eq!(schema.field(0).data_type(), &DataType::FixedSizeBinary(16));
1281 }
1282
1283 #[test]
1284 fn test_avro_to_arrow_nullable_complex_in_union() {
1285 let avro = r#"{
1286 "type": "record",
1287 "name": "test",
1288 "fields": [
1289 {
1290 "name": "tags",
1291 "type": ["null", {"type": "array", "items": "string"}]
1292 }
1293 ]
1294 }"#;
1295
1296 let schema = avro_to_arrow_schema(avro).unwrap();
1297 assert!(schema.field(0).is_nullable());
1298 assert!(matches!(schema.field(0).data_type(), DataType::List(_)));
1299 }
1300
1301 #[test]
1302 fn test_avro_array_missing_items() {
1303 let avro = r#"{
1304 "type": "record",
1305 "name": "test",
1306 "fields": [
1307 {"name": "bad", "type": {"type": "array"}}
1308 ]
1309 }"#;
1310 assert!(avro_to_arrow_schema(avro).is_err());
1311 }
1312
1313 #[test]
1314 fn test_avro_map_missing_values() {
1315 let avro = r#"{
1316 "type": "record",
1317 "name": "test",
1318 "fields": [
1319 {"name": "bad", "type": {"type": "map"}}
1320 ]
1321 }"#;
1322 assert!(avro_to_arrow_schema(avro).is_err());
1323 }
1324
1325 #[test]
1326 fn test_arrow_to_avro_array_type() {
1327 let schema = Arc::new(Schema::new(vec![Field::new(
1328 "tags",
1329 DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
1330 false,
1331 )]));
1332
1333 let avro_str = arrow_to_avro_schema(&schema, "test").unwrap();
1334 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1335 let field = &avro["fields"][0];
1336 assert_eq!(field["type"]["type"], "array");
1337 assert_eq!(field["type"]["items"], "string");
1338 }
1339
1340 #[test]
1341 fn test_arrow_to_avro_map_type() {
1342 let schema = Arc::new(Schema::new(vec![Field::new(
1343 "metadata",
1344 DataType::Map(
1345 Arc::new(Field::new(
1346 "entries",
1347 DataType::Struct(Fields::from(vec![
1348 Field::new("key", DataType::Utf8, false),
1349 Field::new("value", DataType::Int64, true),
1350 ])),
1351 false,
1352 )),
1353 false,
1354 ),
1355 false,
1356 )]));
1357
1358 let avro_str = arrow_to_avro_schema(&schema, "test").unwrap();
1359 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1360 let field = &avro["fields"][0];
1361 assert_eq!(field["type"]["type"], "map");
1362 assert_eq!(field["type"]["values"], "long");
1363 }
1364
1365 #[test]
1366 fn test_arrow_to_avro_struct_type() {
1367 let schema = Arc::new(Schema::new(vec![Field::new(
1368 "address",
1369 DataType::Struct(Fields::from(vec![
1370 Field::new("street", DataType::Utf8, false),
1371 Field::new("zip", DataType::Int32, false),
1372 ])),
1373 false,
1374 )]));
1375
1376 let avro_str = arrow_to_avro_schema(&schema, "test").unwrap();
1377 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1378 let field = &avro["fields"][0];
1379 assert_eq!(field["type"]["type"], "record");
1380 let nested = field["type"]["fields"].as_array().unwrap();
1381 assert_eq!(nested.len(), 2);
1382 assert_eq!(nested[0]["name"], "street");
1383 assert_eq!(nested[0]["type"], "string");
1384 assert_eq!(nested[1]["name"], "zip");
1385 assert_eq!(nested[1]["type"], "int");
1386 }
1387
1388 #[test]
1389 fn test_arrow_to_avro_fixed_type() {
1390 let schema = Arc::new(Schema::new(vec![Field::new(
1391 "uuid",
1392 DataType::FixedSizeBinary(16),
1393 false,
1394 )]));
1395
1396 let avro_str = arrow_to_avro_schema(&schema, "test").unwrap();
1397 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1398 let field = &avro["fields"][0];
1399 assert_eq!(field["type"]["type"], "fixed");
1400 assert_eq!(field["type"]["size"], 16);
1401 }
1402
1403 fn make_cached_schema(id: i32) -> CachedSchema {
1406 CachedSchema {
1407 id,
1408 version: 1,
1409 schema_type: SchemaType::Avro,
1410 schema_str: format!(
1411 r#"{{"type":"record","name":"t{id}","fields":[{{"name":"x","type":"int"}}]}}"#
1412 ),
1413 arrow_schema: Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])),
1414 inserted_at: Instant::now(),
1415 }
1416 }
1417
1418 #[test]
1419 fn test_cache_config_defaults() {
1420 let config = SchemaRegistryCacheConfig::default();
1421 assert_eq!(config.max_entries, 1000);
1422 assert_eq!(config.ttl, Some(Duration::from_secs(3600)));
1423 }
1424
1425 #[test]
1426 fn test_cache_lru_eviction() {
1427 let config = SchemaRegistryCacheConfig {
1428 max_entries: 3,
1429 ttl: None,
1430 };
1431 let client = SchemaRegistryClient::with_cache_config("http://localhost:8081", None, config);
1432
1433 client.cache_insert(1, make_cached_schema(1));
1435 client.cache_insert(2, make_cached_schema(2));
1436 client.cache_insert(3, make_cached_schema(3));
1437 assert_eq!(client.cache_size(), 3);
1438
1439 client.cache_insert(4, make_cached_schema(4));
1441 assert!(client.cache_size() <= 3);
1442 assert!(client.cache_get(4).is_some());
1444 }
1445
1446 #[test]
1447 fn test_cache_ttl_expiration() {
1448 let config = SchemaRegistryCacheConfig {
1449 max_entries: 100,
1450 ttl: Some(Duration::from_millis(50)),
1451 };
1452 let client = SchemaRegistryClient::with_cache_config("http://localhost:8081", None, config);
1453
1454 client.cache_insert(1, make_cached_schema(1));
1455 assert!(client.cache_get(1).is_some());
1456
1457 std::thread::sleep(Duration::from_millis(60));
1459 assert!(client.cache_get(1).is_none());
1461 }
1462
1463 #[test]
1464 fn test_cache_no_ttl() {
1465 let config = SchemaRegistryCacheConfig {
1466 max_entries: 100,
1467 ttl: None,
1468 };
1469 let client = SchemaRegistryClient::with_cache_config("http://localhost:8081", None, config);
1470
1471 client.cache_insert(1, make_cached_schema(1));
1472 assert!(client.cache_get(1).is_some());
1474 }
1475
1476 #[test]
1477 fn test_cache_replace_existing_id() {
1478 let config = SchemaRegistryCacheConfig {
1479 max_entries: 10,
1480 ttl: None,
1481 };
1482 let client = SchemaRegistryClient::with_cache_config("http://localhost:8081", None, config);
1483
1484 client.cache_insert(1, make_cached_schema(1));
1485 client.cache_insert(2, make_cached_schema(2));
1486 assert_eq!(client.cache_size(), 2);
1487
1488 client.cache_insert(1, make_cached_schema(1));
1490 assert_eq!(client.cache_size(), 2);
1491 }
1492
1493 #[test]
1494 fn test_schema_incompatible_error_via_serde() {
1495 let err = SerdeError::SchemaIncompatible {
1496 subject: "orders-value".into(),
1497 message: "READER_FIELD_MISSING_DEFAULT_VALUE: field 'new_field'".into(),
1498 };
1499 let conn_err: ConnectorError = err.into();
1500 assert!(matches!(
1501 conn_err,
1502 ConnectorError::Serde(SerdeError::SchemaIncompatible { .. })
1503 ));
1504 assert!(conn_err.to_string().contains("orders-value"));
1505 }
1506
1507 #[test]
1508 fn test_validate_and_register_method_exists() {
1509 let client = SchemaRegistryClient::new("http://localhost:8081", None);
1511 let _ = &client;
1513 }
1514
1515 #[test]
1516 fn test_complex_type_roundtrip() {
1517 let avro = r#"{
1518 "type": "record",
1519 "name": "test",
1520 "fields": [
1521 {"name": "tags", "type": {"type": "array", "items": "string"}},
1522 {"name": "metadata", "type": {"type": "map", "values": "long"}}
1523 ]
1524 }"#;
1525
1526 let arrow_schema = avro_to_arrow_schema(avro).unwrap();
1527 assert!(matches!(
1528 arrow_schema.field(0).data_type(),
1529 DataType::List(_)
1530 ));
1531 assert!(matches!(
1532 arrow_schema.field(1).data_type(),
1533 DataType::Map(_, _)
1534 ));
1535
1536 let avro_str = arrow_to_avro_schema(&arrow_schema, "test").unwrap();
1538 let recovered = avro_to_arrow_schema(&avro_str).unwrap();
1539
1540 assert!(matches!(recovered.field(0).data_type(), DataType::List(_)));
1541 assert!(matches!(
1542 recovered.field(1).data_type(),
1543 DataType::Map(_, _)
1544 ));
1545 }
1546}