1use std::time::{Duration, Instant};
8
9use quick_cache::sync::Cache;
10
11use arrow_schema::{DataType, SchemaRef};
12use reqwest::Client;
13use serde::{Deserialize, Serialize};
14
15use crate::error::{ConnectorError, SerdeError};
16use crate::kafka::config::{CompatibilityLevel, SrAuth};
17
/// Serialization format of a schema held in the schema registry.
///
/// Parsed from / rendered as the registry's `schemaType` string (`AVRO`, `PROTOBUF`,
/// `JSON`); responses without a `schemaType` are treated as Avro.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchemaType {
    /// Apache Avro schema.
    Avro,
    /// Protocol Buffers schema.
    Protobuf,
    /// JSON Schema.
    Json,
}
28
29impl std::str::FromStr for SchemaType {
30 type Err = ConnectorError;
31
32 fn from_str(s: &str) -> Result<Self, Self::Err> {
33 match s.to_uppercase().as_str() {
34 "AVRO" => Ok(SchemaType::Avro),
35 "PROTOBUF" => Ok(SchemaType::Protobuf),
36 "JSON" => Ok(SchemaType::Json),
37 other => Err(ConnectorError::ConfigurationError(format!(
38 "unknown schema type: '{other}'"
39 ))),
40 }
41 }
42}
43
44impl std::fmt::Display for SchemaType {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 SchemaType::Avro => write!(f, "AVRO"),
48 SchemaType::Protobuf => write!(f, "PROTOBUF"),
49 SchemaType::Json => write!(f, "JSON"),
50 }
51 }
52}
53
/// Sizing and expiry settings for the client's schema-ID cache.
#[derive(Debug, Clone)]
pub struct SchemaRegistryCacheConfig {
    /// Capacity passed to the underlying cache (number of schema IDs).
    pub max_entries: usize,
    /// Age after which a cached schema is treated as a miss and refetched;
    /// `None` disables expiry.
    pub ttl: Option<Duration>,
}
62
63impl Default for SchemaRegistryCacheConfig {
64 fn default() -> Self {
65 Self {
66 max_entries: 1000,
67 ttl: Some(Duration::from_secs(3600)),
68 }
69 }
70}
71
/// A registry schema together with its Arrow translation, as held in the client caches.
#[derive(Debug, Clone)]
pub struct CachedSchema {
    /// Registry-assigned global schema ID.
    pub id: i32,
    /// Subject version; `0` when unknown (lookups by ID and freshly registered
    /// schemas do not report one).
    pub version: i32,
    /// Serialization format of `schema_str`.
    pub schema_type: SchemaType,
    /// Schema text as returned by, or sent to, the registry.
    pub schema_str: String,
    /// Arrow schema derived from `schema_str`.
    pub arrow_schema: SchemaRef,
    /// When the entry was (re)inserted into the ID cache; drives TTL expiry.
    inserted_at: Instant,
}
88
/// Outcome of a compatibility check against a subject's latest registered version.
#[derive(Debug, Clone)]
pub struct CompatibilityResult {
    /// Whether the registry judged the candidate schema compatible.
    pub is_compatible: bool,
    /// Registry-supplied explanations; may be empty.
    pub messages: Vec<String>,
}
97
/// Async HTTP client for a Confluent-compatible schema registry.
///
/// Schemas looked up by ID are kept in a bounded cache with an optional TTL; the last
/// schema fetched as "latest" or registered for each subject is kept in a second cache
/// that lets `register_schema` skip re-registering identical text.
pub struct SchemaRegistryClient {
    /// HTTP client (plain, or configured with a custom CA / client identity).
    client: Client,
    /// Registry root URL, stored without a trailing `/`.
    base_url: String,
    /// Optional HTTP basic-auth credentials attached to every request.
    auth: Option<SrAuth>,
    /// Schema ID → decoded schema; entries older than `cache_config.ttl` are misses.
    cache: Cache<i32, CachedSchema>,
    /// Subject name → last schema fetched as latest or registered (no TTL applied).
    subject_cache: Cache<String, CachedSchema>,
    /// Settings `cache` was built with; `ttl` is consulted on every lookup.
    cache_config: SchemaRegistryCacheConfig,
}
113
114#[derive(Deserialize)]
117struct SchemaByIdResponse {
118 schema: String,
119 #[serde(default = "default_schema_type")]
120 #[serde(rename = "schemaType")]
121 schema_type: String,
122}
123
124#[derive(Deserialize)]
125struct SchemaVersionResponse {
126 id: i32,
127 version: i32,
128 schema: String,
129 #[serde(default = "default_schema_type")]
130 #[serde(rename = "schemaType")]
131 schema_type: String,
132}
133
/// Body returned by `POST /compatibility/subjects/{subject}/versions/latest`.
#[derive(Deserialize)]
struct CompatibilityResponse {
    // Deserialized from the snake_case key `is_compatible` (no rename attribute).
    is_compatible: bool,
    // Optional explanation list; absent → empty.
    #[serde(default)]
    messages: Vec<String>,
}
140
/// Body returned by `GET /config/{subject}`.
#[derive(Deserialize)]
struct ConfigResponse {
    // Raw level string; parsed into `CompatibilityLevel` by the caller.
    #[serde(rename = "compatibilityLevel")]
    compatibility_level: String,
}
146
/// Request body for a compatibility check: the candidate schema text and its type.
#[derive(Serialize)]
struct CompatibilityRequest {
    schema: String,
    // Registry spelling, e.g. "AVRO".
    #[serde(rename = "schemaType")]
    schema_type: String,
}
153
/// Request body for `PUT /config/{subject}`.
#[derive(Serialize)]
struct ConfigUpdateRequest {
    // Level string as produced by `CompatibilityLevel::as_str`.
    compatibility: String,
}
158
/// Request body for `POST /subjects/{subject}/versions`.
#[derive(Serialize)]
struct RegisterSchemaRequest {
    schema: String,
    // Registry spelling from `SchemaType`'s `Display` impl.
    #[serde(rename = "schemaType")]
    schema_type: String,
}
165
/// Response to a schema registration: the global ID assigned to the schema.
#[derive(Deserialize)]
struct RegisterSchemaResponse {
    id: i32,
}
170
/// Serde default for absent `schemaType` fields: the registry treats those as Avro.
fn default_schema_type() -> String {
    String::from("AVRO")
}
174
175impl SchemaRegistryClient {
176 #[must_use]
178 pub fn new(base_url: impl Into<String>, auth: Option<SrAuth>) -> Self {
179 Self::with_cache_config(base_url, auth, SchemaRegistryCacheConfig::default())
180 }
181
182 pub fn with_tls(
188 base_url: impl Into<String>,
189 auth: Option<SrAuth>,
190 ca_cert_path: &str,
191 ) -> Result<Self, ConnectorError> {
192 Self::with_tls_mtls(base_url, auth, ca_cert_path, None, None)
193 }
194
195 pub fn with_tls_mtls(
202 base_url: impl Into<String>,
203 auth: Option<SrAuth>,
204 ca_cert_path: &str,
205 client_cert_path: Option<&str>,
206 client_key_path: Option<&str>,
207 ) -> Result<Self, ConnectorError> {
208 let pem = std::fs::read(ca_cert_path).map_err(|e| {
209 ConnectorError::ConfigurationError(format!(
210 "failed to read SR CA cert at '{ca_cert_path}': {e}"
211 ))
212 })?;
213 let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
214 ConnectorError::ConfigurationError(format!(
215 "invalid PEM CA cert at '{ca_cert_path}': {e}"
216 ))
217 })?;
218
219 let mut builder = Client::builder().add_root_certificate(cert);
220
221 if client_cert_path.is_some() != client_key_path.is_some() {
222 return Err(ConnectorError::ConfigurationError(
223 "mTLS requires both client cert and key — only one was provided".into(),
224 ));
225 }
226 if let (Some(cert_path), Some(key_path)) = (client_cert_path, client_key_path) {
227 let mut identity_pem = std::fs::read(cert_path).map_err(|e| {
228 ConnectorError::ConfigurationError(format!(
229 "failed to read SR client cert at '{cert_path}': {e}"
230 ))
231 })?;
232 let key_pem = std::fs::read(key_path).map_err(|e| {
233 ConnectorError::ConfigurationError(format!(
234 "failed to read SR client key at '{key_path}': {e}"
235 ))
236 })?;
237 identity_pem.extend_from_slice(&key_pem);
239 let identity = reqwest::tls::Identity::from_pem(&identity_pem).map_err(|e| {
240 ConnectorError::ConfigurationError(format!("invalid client cert/key PEM: {e}"))
241 })?;
242 builder = builder.identity(identity);
243 }
244
245 let client = builder.build().map_err(|e| {
246 ConnectorError::ConfigurationError(format!("failed to build TLS client: {e}"))
247 })?;
248
249 let cache_config = SchemaRegistryCacheConfig::default();
250 let cache = Cache::new(cache_config.max_entries);
251 let subject_cache = Cache::new(256);
252 Ok(Self {
253 client,
254 base_url: base_url.into().trim_end_matches('/').to_string(),
255 auth,
256 cache,
257 subject_cache,
258 cache_config,
259 })
260 }
261
262 #[must_use]
264 pub fn with_cache_config(
265 base_url: impl Into<String>,
266 auth: Option<SrAuth>,
267 cache_config: SchemaRegistryCacheConfig,
268 ) -> Self {
269 let cache = Cache::new(cache_config.max_entries);
270 let subject_cache = Cache::new(256);
272 Self {
273 client: Client::new(),
274 base_url: base_url.into().trim_end_matches('/').to_string(),
275 auth,
276 cache,
277 subject_cache,
278 cache_config,
279 }
280 }
281
282 #[must_use]
284 pub fn base_url(&self) -> &str {
285 &self.base_url
286 }
287
288 #[must_use]
290 pub fn has_auth(&self) -> bool {
291 self.auth.is_some()
292 }
293
294 #[must_use]
296 pub fn cache_config(&self) -> &SchemaRegistryCacheConfig {
297 &self.cache_config
298 }
299
300 fn cache_insert(&self, id: i32, mut schema: CachedSchema) {
304 schema.inserted_at = Instant::now();
305 self.cache.insert(id, schema);
306 }
307
308 fn cache_get(&self, id: i32) -> Option<CachedSchema> {
313 let schema = self.cache.get(&id)?;
314 if let Some(ttl) = self.cache_config.ttl {
315 if schema.inserted_at.elapsed() > ttl {
316 self.cache.remove(&id);
317 return None;
318 }
319 }
320 Some(schema)
322 }
323
324 pub async fn get_schema_by_id(&self, id: i32) -> Result<CachedSchema, ConnectorError> {
333 if let Some(cached) = self.cache_get(id) {
334 return Ok(cached);
335 }
336
337 let url = format!("{}/schemas/ids/{}", self.base_url, id);
338 let resp: SchemaByIdResponse = self.get_json(&url).await?;
339
340 let schema_type: SchemaType = resp.schema_type.parse()?;
341 let arrow_schema = schema_to_arrow(schema_type, &resp.schema)?;
342
343 let cached = CachedSchema {
344 id,
345 version: 0, schema_type,
347 schema_str: resp.schema,
348 arrow_schema,
349 inserted_at: Instant::now(),
350 };
351 self.cache_insert(id, cached.clone());
352 Ok(cached)
353 }
354
355 pub async fn get_latest_schema(&self, subject: &str) -> Result<CachedSchema, ConnectorError> {
361 let url = format!("{}/subjects/{}/versions/latest", self.base_url, subject);
362 let resp: SchemaVersionResponse = self.get_json(&url).await?;
363
364 let schema_type: SchemaType = resp.schema_type.parse()?;
365 let arrow_schema = schema_to_arrow(schema_type, &resp.schema)?;
366
367 let cached = CachedSchema {
368 id: resp.id,
369 version: resp.version,
370 schema_type,
371 schema_str: resp.schema,
372 arrow_schema,
373 inserted_at: Instant::now(),
374 };
375
376 self.cache_insert(resp.id, cached.clone());
377 self.subject_cache
378 .insert(subject.to_string(), cached.clone());
379 Ok(cached)
380 }
381
382 pub async fn get_schema_version(
388 &self,
389 subject: &str,
390 version: i32,
391 ) -> Result<CachedSchema, ConnectorError> {
392 let url = format!(
393 "{}/subjects/{}/versions/{}",
394 self.base_url, subject, version
395 );
396 let resp: SchemaVersionResponse = self.get_json(&url).await?;
397
398 let schema_type: SchemaType = resp.schema_type.parse()?;
399 let arrow_schema = schema_to_arrow(schema_type, &resp.schema)?;
400
401 let cached = CachedSchema {
402 id: resp.id,
403 version: resp.version,
404 schema_type,
405 schema_str: resp.schema,
406 arrow_schema,
407 inserted_at: Instant::now(),
408 };
409 self.cache_insert(resp.id, cached.clone());
410 Ok(cached)
411 }
412
413 pub async fn check_compatibility(
419 &self,
420 subject: &str,
421 schema_str: &str,
422 ) -> Result<CompatibilityResult, ConnectorError> {
423 let url = format!(
424 "{}/compatibility/subjects/{}/versions/latest",
425 self.base_url, subject
426 );
427
428 let body = CompatibilityRequest {
429 schema: schema_str.to_string(),
430 schema_type: "AVRO".to_string(),
431 };
432
433 let mut req = self.client.post(&url).json(&body);
434 if let Some(ref auth) = self.auth {
435 req = req.basic_auth(&auth.username, Some(&auth.password));
436 }
437
438 let resp = req
439 .send()
440 .await
441 .map_err(|e| ConnectorError::ConnectionFailed(format!("schema registry: {e}")))?;
442
443 if !resp.status().is_success() {
444 let status = resp.status();
445 let text = resp.text().await.unwrap_or_default();
446 return Err(ConnectorError::ConnectionFailed(format!(
447 "schema registry compatibility check failed: {status} {text}"
448 )));
449 }
450
451 let result: CompatibilityResponse = resp.json().await.map_err(|e| {
452 ConnectorError::Internal(format!("failed to parse compatibility response: {e}"))
453 })?;
454
455 Ok(CompatibilityResult {
456 is_compatible: result.is_compatible,
457 messages: result.messages,
458 })
459 }
460
461 pub async fn get_compatibility_level(
467 &self,
468 subject: &str,
469 ) -> Result<CompatibilityLevel, ConnectorError> {
470 let url = format!("{}/config/{}", self.base_url, subject);
471 let resp: ConfigResponse = self.get_json(&url).await?;
472 resp.compatibility_level.parse()
473 }
474
475 pub async fn set_compatibility_level(
481 &self,
482 subject: &str,
483 level: CompatibilityLevel,
484 ) -> Result<(), ConnectorError> {
485 let url = format!("{}/config/{}", self.base_url, subject);
486 let body = ConfigUpdateRequest {
487 compatibility: level.as_str().to_string(),
488 };
489
490 let mut req = self.client.put(&url).json(&body);
491 if let Some(ref auth) = self.auth {
492 req = req.basic_auth(&auth.username, Some(&auth.password));
493 }
494
495 let resp = req
496 .send()
497 .await
498 .map_err(|e| ConnectorError::ConnectionFailed(format!("schema registry: {e}")))?;
499
500 if !resp.status().is_success() {
501 let status = resp.status();
502 let text = resp.text().await.unwrap_or_default();
503 return Err(ConnectorError::ConnectionFailed(format!(
504 "schema registry config update failed: {status} {text}"
505 )));
506 }
507
508 Ok(())
509 }
510
511 pub async fn resolve_confluent_id(&self, id: i32) -> Result<CachedSchema, ConnectorError> {
520 self.get_schema_by_id(id).await
521 }
522
523 pub async fn register_schema(
533 &self,
534 subject: &str,
535 schema_str: &str,
536 schema_type: SchemaType,
537 ) -> Result<i32, ConnectorError> {
538 if let Some(cached) = self.subject_cache.get(subject) {
540 if cached.schema_str == schema_str {
541 return Ok(cached.id);
542 }
543 }
544
545 let url = format!("{}/subjects/{}/versions", self.base_url, subject);
546 let body = RegisterSchemaRequest {
547 schema: schema_str.to_string(),
548 schema_type: schema_type.to_string(),
549 };
550
551 let mut req = self.client.post(&url).json(&body);
552 if let Some(ref auth) = self.auth {
553 req = req.basic_auth(&auth.username, Some(&auth.password));
554 }
555
556 let resp = req
557 .send()
558 .await
559 .map_err(|e| ConnectorError::ConnectionFailed(format!("schema registry: {e}")))?;
560
561 if !resp.status().is_success() {
562 let status = resp.status();
563 let text = resp.text().await.unwrap_or_default();
564 return Err(ConnectorError::ConnectionFailed(format!(
565 "schema registry register failed: {status} {text}"
566 )));
567 }
568
569 let result: RegisterSchemaResponse = resp.json().await.map_err(|e| {
570 ConnectorError::Internal(format!("failed to parse register schema response: {e}"))
571 })?;
572
573 let arrow_schema = avro_to_arrow_schema(schema_str)?;
574 let cached = CachedSchema {
575 id: result.id,
576 version: 0,
577 schema_type,
578 schema_str: schema_str.to_string(),
579 arrow_schema,
580 inserted_at: Instant::now(),
581 };
582 self.cache_insert(result.id, cached.clone());
583 self.subject_cache.insert(subject.to_string(), cached);
584
585 Ok(result.id)
586 }
587
588 pub async fn validate_and_register_schema(
599 &self,
600 subject: &str,
601 schema_str: &str,
602 schema_type: SchemaType,
603 ) -> Result<i32, ConnectorError> {
604 match self.check_compatibility(subject, schema_str).await {
606 Ok(result) => {
607 if !result.is_compatible {
608 let message = if result.messages.is_empty() {
609 "new schema is not compatible with existing version".to_string()
610 } else {
611 result.messages.join("; ")
612 };
613 return Err(ConnectorError::Serde(SerdeError::SchemaIncompatible {
614 subject: subject.to_string(),
615 message,
616 }));
617 }
618 }
619 Err(ConnectorError::ConnectionFailed(msg)) if msg.contains("404") => {
620 }
622 Err(e) => return Err(e),
623 }
624
625 self.register_schema(subject, schema_str, schema_type).await
626 }
627
628 #[must_use]
630 pub fn is_cached(&self, id: i32) -> bool {
631 self.cache.contains_key(&id)
632 }
633
634 #[must_use]
636 pub fn cache_size(&self) -> usize {
637 self.cache.len()
638 }
639
640 async fn get_json<T: serde::de::DeserializeOwned>(
645 &self,
646 url: &str,
647 ) -> Result<T, ConnectorError> {
648 let backoffs = [
649 std::time::Duration::from_millis(100),
650 std::time::Duration::from_millis(500),
651 ];
652 let mut last_err = None;
653
654 for (attempt, backoff) in std::iter::once(&std::time::Duration::ZERO)
655 .chain(backoffs.iter())
656 .enumerate()
657 {
658 if attempt > 0 {
659 tokio::time::sleep(*backoff).await;
660 }
661
662 let mut req = self.client.get(url);
663 if let Some(ref auth) = self.auth {
664 req = req.basic_auth(&auth.username, Some(&auth.password));
665 }
666
667 let resp = match req.send().await {
668 Ok(r) => r,
669 Err(e) => {
670 tracing::warn!(
671 attempt = attempt + 1,
672 error = %e,
673 "schema registry request failed, retrying"
674 );
675 last_err = Some(ConnectorError::ConnectionFailed(format!(
676 "schema registry: {e}"
677 )));
678 continue;
679 }
680 };
681
682 let status = resp.status();
683 if status.is_success() {
684 return resp.json::<T>().await.map_err(|e| {
685 ConnectorError::Internal(format!(
686 "failed to parse schema registry response: {e}"
687 ))
688 });
689 }
690
691 if status.is_client_error() {
693 let text = resp.text().await.unwrap_or_default();
694 return Err(ConnectorError::ConnectionFailed(format!(
695 "schema registry client error: {status} {text}"
696 )));
697 }
698
699 let text = resp.text().await.unwrap_or_default();
701 tracing::warn!(
702 attempt = attempt + 1,
703 status = %status,
704 "schema registry server error, retrying"
705 );
706 last_err = Some(ConnectorError::ConnectionFailed(format!(
707 "schema registry request failed: {status} {text}"
708 )));
709 }
710
711 Err(last_err.unwrap_or_else(|| {
712 ConnectorError::ConnectionFailed("schema registry: all retries exhausted".into())
713 }))
714 }
715}
716
717impl std::fmt::Debug for SchemaRegistryClient {
718 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
719 f.debug_struct("SchemaRegistryClient")
720 .field("base_url", &self.base_url)
721 .field("has_auth", &self.auth.is_some())
722 .field("cached_schemas", &self.cache.len())
723 .field("cached_subjects", &self.subject_cache.len())
724 .finish_non_exhaustive()
725 }
726}
727
728fn schema_to_arrow(schema_type: SchemaType, schema_str: &str) -> Result<SchemaRef, ConnectorError> {
732 let name = match schema_type {
733 SchemaType::Avro => return avro_to_arrow_schema(schema_str),
734 SchemaType::Json => "JSON Schema Registry",
735 SchemaType::Protobuf => "Protobuf Schema Registry",
736 };
737 Err(ConnectorError::SchemaMismatch(format!(
738 "{name} subjects are not yet supported for auto-discovery \
739 — declare columns explicitly or use an Avro subject"
740 )))
741}
742
743pub fn avro_to_arrow_schema(avro_schema_str: &str) -> Result<SchemaRef, ConnectorError> {
749 use arrow_avro::reader::ReaderBuilder;
750 use arrow_avro::schema::{AvroSchema, Fingerprint, FingerprintAlgorithm, SchemaStore};
751
752 let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id);
753 let avro_schema = AvroSchema::new(avro_schema_str.to_string());
754 let fp = Fingerprint::Id(0);
755 store
756 .set(fp, avro_schema)
757 .map_err(|e| ConnectorError::SchemaMismatch(format!("invalid Avro schema: {e}")))?;
758
759 let decoder = ReaderBuilder::new()
760 .with_writer_schema_store(store)
761 .with_active_fingerprint(fp)
762 .build_decoder()
763 .map_err(|e| ConnectorError::SchemaMismatch(format!("Avro→Arrow conversion: {e}")))?;
764
765 Ok(decoder.schema())
766}
767
768pub fn arrow_to_avro_schema(schema: &SchemaRef, record_name: &str) -> Result<String, SerdeError> {
777 let mut fields = Vec::with_capacity(schema.fields().len());
778
779 for field in schema.fields() {
780 let avro_type = arrow_to_avro_type(field.data_type())?;
781
782 let field_type = if field.is_nullable() {
783 serde_json::json!(["null", avro_type])
784 } else {
785 avro_type
786 };
787
788 fields.push(serde_json::json!({
789 "name": field.name(),
790 "type": field_type,
791 }));
792 }
793
794 let safe_name = record_name.replace('-', "_");
797
798 let schema = serde_json::json!({
799 "type": "record",
800 "name": safe_name,
801 "fields": fields,
802 });
803
804 serde_json::to_string(&schema)
805 .map_err(|e| SerdeError::MalformedInput(format!("failed to serialize Avro schema: {e}")))
806}
807
808fn arrow_to_avro_type(data_type: &DataType) -> Result<serde_json::Value, SerdeError> {
810 match data_type {
811 DataType::Null => Ok(serde_json::json!("null")),
812 DataType::Boolean => Ok(serde_json::json!("boolean")),
813 DataType::Int8
814 | DataType::Int16
815 | DataType::Int32
816 | DataType::UInt8
817 | DataType::UInt16
818 | DataType::UInt32 => Ok(serde_json::json!("int")),
819 DataType::Int64 | DataType::UInt64 => Ok(serde_json::json!("long")),
820 DataType::Float32 => Ok(serde_json::json!("float")),
821 DataType::Float64 => Ok(serde_json::json!("double")),
822 DataType::Utf8 | DataType::LargeUtf8 => Ok(serde_json::json!("string")),
823 DataType::Binary | DataType::LargeBinary => Ok(serde_json::json!("bytes")),
824 DataType::List(item_field) => {
825 let items = arrow_to_avro_type(item_field.data_type())?;
826 Ok(serde_json::json!({
827 "type": "array",
828 "items": items,
829 }))
830 }
831 DataType::Map(entries_field, _) => {
832 if let DataType::Struct(fields) = entries_field.data_type() {
834 let value_field = fields.iter().find(|f| f.name() == "value").ok_or_else(|| {
835 SerdeError::UnsupportedFormat(
836 "Arrow Map missing 'value' field in entries struct".into(),
837 )
838 })?;
839 let values = arrow_to_avro_type(value_field.data_type())?;
840 Ok(serde_json::json!({
841 "type": "map",
842 "values": values,
843 }))
844 } else {
845 Err(SerdeError::UnsupportedFormat(
846 "Arrow Map entries field is not a Struct".into(),
847 ))
848 }
849 }
850 DataType::Struct(fields) => {
851 let mut avro_fields = Vec::with_capacity(fields.len());
852 for field in fields {
853 let avro_type = arrow_to_avro_type(field.data_type())?;
854 let field_type = if field.is_nullable() {
855 serde_json::json!(["null", avro_type])
856 } else {
857 avro_type
858 };
859 avro_fields.push(serde_json::json!({
860 "name": field.name(),
861 "type": field_type,
862 }));
863 }
864 Ok(serde_json::json!({
865 "type": "record",
866 "name": "nested",
867 "fields": avro_fields,
868 }))
869 }
870 DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Utf8 => {
871 Ok(serde_json::json!({
872 "type": "enum",
873 "name": "enum_field",
874 "symbols": [],
875 }))
876 }
877 DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => {
878 Ok(serde_json::json!({"type": "long", "logicalType": "timestamp-millis"}))
879 }
880 DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _) => {
881 Ok(serde_json::json!({"type": "long", "logicalType": "timestamp-micros"}))
882 }
883 DataType::Date32 => Ok(serde_json::json!({"type": "int", "logicalType": "date"})),
884 DataType::Time32(arrow_schema::TimeUnit::Millisecond) => {
885 Ok(serde_json::json!({"type": "int", "logicalType": "time-millis"}))
886 }
887 DataType::Time64(arrow_schema::TimeUnit::Microsecond) => {
888 Ok(serde_json::json!({"type": "long", "logicalType": "time-micros"}))
889 }
890 DataType::FixedSizeBinary(size) => Ok(serde_json::json!({
891 "type": "fixed",
892 "name": "fixed_field",
893 "size": size,
894 })),
895 other => Err(SerdeError::UnsupportedFormat(format!(
896 "no Avro equivalent for Arrow type: {other}"
897 ))),
898 }
899}
900
901#[cfg(test)]
902mod tests {
903 use std::sync::Arc;
904
905 use super::*;
906 use arrow_schema::{Field, Fields, Schema};
907
908 #[test]
909 fn test_avro_to_arrow_simple_record() {
910 let avro = r#"{
911 "type": "record",
912 "name": "test",
913 "fields": [
914 {"name": "id", "type": "long"},
915 {"name": "name", "type": "string"},
916 {"name": "active", "type": "boolean"}
917 ]
918 }"#;
919
920 let schema = avro_to_arrow_schema(avro).unwrap();
921 assert_eq!(schema.fields().len(), 3);
922 assert_eq!(schema.field(0).name(), "id");
923 assert_eq!(schema.field(0).data_type(), &DataType::Int64);
924 assert!(!schema.field(0).is_nullable());
925 assert_eq!(schema.field(1).name(), "name");
926 assert_eq!(schema.field(1).data_type(), &DataType::Utf8);
927 assert_eq!(schema.field(2).name(), "active");
928 assert_eq!(schema.field(2).data_type(), &DataType::Boolean);
929 }
930
931 #[test]
932 fn test_avro_to_arrow_nullable_union() {
933 let avro = r#"{
934 "type": "record",
935 "name": "test",
936 "fields": [
937 {"name": "id", "type": "long"},
938 {"name": "email", "type": ["null", "string"]}
939 ]
940 }"#;
941
942 let schema = avro_to_arrow_schema(avro).unwrap();
943 assert_eq!(schema.fields().len(), 2);
944 assert!(!schema.field(0).is_nullable());
945 assert!(schema.field(1).is_nullable());
946 assert_eq!(schema.field(1).data_type(), &DataType::Utf8);
947 }
948
949 #[test]
950 fn test_avro_to_arrow_all_primitives() {
951 let avro = r#"{
952 "type": "record",
953 "name": "test",
954 "fields": [
955 {"name": "b", "type": "boolean"},
956 {"name": "i", "type": "int"},
957 {"name": "l", "type": "long"},
958 {"name": "f", "type": "float"},
959 {"name": "d", "type": "double"},
960 {"name": "s", "type": "string"},
961 {"name": "raw", "type": "bytes"}
962 ]
963 }"#;
964
965 let schema = avro_to_arrow_schema(avro).unwrap();
966 assert_eq!(schema.field(0).data_type(), &DataType::Boolean);
967 assert_eq!(schema.field(1).data_type(), &DataType::Int32);
968 assert_eq!(schema.field(2).data_type(), &DataType::Int64);
969 assert_eq!(schema.field(3).data_type(), &DataType::Float32);
970 assert_eq!(schema.field(4).data_type(), &DataType::Float64);
971 assert_eq!(schema.field(5).data_type(), &DataType::Utf8);
972 assert_eq!(schema.field(6).data_type(), &DataType::Binary);
973 }
974
975 #[test]
976 fn test_avro_to_arrow_invalid_json() {
977 assert!(avro_to_arrow_schema("not json").is_err());
978 }
979
980 #[test]
981 fn test_avro_to_arrow_missing_fields() {
982 let avro = r#"{"type": "record", "name": "test"}"#;
983 assert!(avro_to_arrow_schema(avro).is_err());
984 }
985
986 #[test]
987 fn schema_to_arrow_avro_works() {
988 let avro = r#"{"type":"record","name":"t","fields":[{"name":"x","type":"long"}]}"#;
989 let schema = schema_to_arrow(SchemaType::Avro, avro).unwrap();
990 assert_eq!(schema.field(0).name(), "x");
991 }
992
993 #[test]
994 fn schema_to_arrow_json_returns_actionable_error() {
995 let err = schema_to_arrow(SchemaType::Json, "{}").unwrap_err();
996 assert!(
997 err.to_string().contains("JSON Schema Registry"),
998 "error should name the subject type, got: {err}"
999 );
1000 }
1001
1002 #[test]
1003 fn schema_to_arrow_protobuf_returns_actionable_error() {
1004 let err = schema_to_arrow(SchemaType::Protobuf, "").unwrap_err();
1005 assert!(
1006 err.to_string().contains("Protobuf"),
1007 "error should name the subject type, got: {err}"
1008 );
1009 }
1010
1011 #[test]
1012 fn test_schema_type_parsing() {
1013 assert_eq!("AVRO".parse::<SchemaType>().unwrap(), SchemaType::Avro);
1014 assert_eq!(
1015 "PROTOBUF".parse::<SchemaType>().unwrap(),
1016 SchemaType::Protobuf
1017 );
1018 assert_eq!("JSON".parse::<SchemaType>().unwrap(), SchemaType::Json);
1019 assert!("UNKNOWN".parse::<SchemaType>().is_err());
1020 }
1021
1022 #[test]
1023 fn test_schema_type_display() {
1024 assert_eq!(SchemaType::Avro.to_string(), "AVRO");
1025 assert_eq!(SchemaType::Protobuf.to_string(), "PROTOBUF");
1026 assert_eq!(SchemaType::Json.to_string(), "JSON");
1027 }
1028
1029 #[test]
1030 fn test_client_creation() {
1031 let client = SchemaRegistryClient::new("http://localhost:8081", None);
1032 assert_eq!(client.base_url(), "http://localhost:8081");
1033 assert!(!client.has_auth());
1034 assert_eq!(client.cache_size(), 0);
1035 }
1036
1037 #[test]
1038 fn test_client_with_auth() {
1039 let auth = SrAuth {
1040 username: "user".into(),
1041 password: "pass".into(),
1042 };
1043 let client = SchemaRegistryClient::new("http://localhost:8081", Some(auth));
1044 assert!(client.has_auth());
1045 }
1046
1047 #[test]
1048 fn test_client_trailing_slash_stripped() {
1049 let client = SchemaRegistryClient::new("http://localhost:8081/", None);
1050 assert_eq!(client.base_url(), "http://localhost:8081");
1051 }
1052
1053 #[test]
1054 fn test_arrow_to_avro_schema_simple() {
1055 let schema = Arc::new(Schema::new(vec![
1056 Field::new("id", DataType::Int64, false),
1057 Field::new("name", DataType::Utf8, false),
1058 ]));
1059
1060 let avro_str = arrow_to_avro_schema(&schema, "test_record").unwrap();
1061 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1062
1063 assert_eq!(avro["type"], "record");
1064 assert_eq!(avro["name"], "test_record");
1065
1066 let fields = avro["fields"].as_array().unwrap();
1067 assert_eq!(fields.len(), 2);
1068 assert_eq!(fields[0]["name"], "id");
1069 assert_eq!(fields[0]["type"], "long");
1070 assert_eq!(fields[1]["name"], "name");
1071 assert_eq!(fields[1]["type"], "string");
1072 }
1073
1074 #[test]
1075 fn test_arrow_to_avro_schema_sanitizes_hyphens() {
1076 let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
1077
1078 let avro_str = arrow_to_avro_schema(&schema, "trades-avro-output").unwrap();
1079 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1080 assert_eq!(avro["name"], "trades_avro_output");
1081 }
1082
1083 #[test]
1084 fn test_arrow_to_avro_schema_nullable() {
1085 let schema = Arc::new(Schema::new(vec![
1086 Field::new("id", DataType::Int64, false),
1087 Field::new("email", DataType::Utf8, true),
1088 ]));
1089
1090 let avro_str = arrow_to_avro_schema(&schema, "record").unwrap();
1091 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1092
1093 let fields = avro["fields"].as_array().unwrap();
1094 assert_eq!(fields[0]["type"], "long");
1096 let union = fields[1]["type"].as_array().unwrap();
1098 assert_eq!(union.len(), 2);
1099 assert_eq!(union[0], "null");
1100 assert_eq!(union[1], "string");
1101 }
1102
1103 #[test]
1104 fn test_arrow_to_avro_all_primitives() {
1105 let schema = Arc::new(Schema::new(vec![
1106 Field::new("b", DataType::Boolean, false),
1107 Field::new("i32", DataType::Int32, false),
1108 Field::new("i64", DataType::Int64, false),
1109 Field::new("f32", DataType::Float32, false),
1110 Field::new("f64", DataType::Float64, false),
1111 Field::new("s", DataType::Utf8, false),
1112 Field::new("bin", DataType::Binary, false),
1113 ]));
1114
1115 let avro_str = arrow_to_avro_schema(&schema, "all_types").unwrap();
1116 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1117 let fields = avro["fields"].as_array().unwrap();
1118
1119 assert_eq!(fields[0]["type"], "boolean");
1120 assert_eq!(fields[1]["type"], "int");
1121 assert_eq!(fields[2]["type"], "long");
1122 assert_eq!(fields[3]["type"], "float");
1123 assert_eq!(fields[4]["type"], "double");
1124 assert_eq!(fields[5]["type"], "string");
1125 assert_eq!(fields[6]["type"], "bytes");
1126 }
1127
1128 #[test]
1129 fn test_arrow_to_avro_roundtrip() {
1130 let original = Arc::new(Schema::new(vec![
1131 Field::new("id", DataType::Int64, false),
1132 Field::new("name", DataType::Utf8, true),
1133 Field::new("active", DataType::Boolean, false),
1134 ]));
1135
1136 let avro_str = arrow_to_avro_schema(&original, "roundtrip").unwrap();
1137 let recovered = avro_to_arrow_schema(&avro_str).unwrap();
1138
1139 assert_eq!(recovered.fields().len(), 3);
1140 assert_eq!(recovered.field(0).data_type(), &DataType::Int64);
1141 assert!(!recovered.field(0).is_nullable());
1142 assert_eq!(recovered.field(1).data_type(), &DataType::Utf8);
1143 assert!(recovered.field(1).is_nullable());
1144 assert_eq!(recovered.field(2).data_type(), &DataType::Boolean);
1145 }
1146
1147 #[test]
1150 fn test_avro_to_arrow_array_type() {
1151 let avro = r#"{
1152 "type": "record",
1153 "name": "test",
1154 "fields": [
1155 {"name": "tags", "type": {"type": "array", "items": "string"}}
1156 ]
1157 }"#;
1158
1159 let schema = avro_to_arrow_schema(avro).unwrap();
1160 assert_eq!(schema.fields().len(), 1);
1161 match schema.field(0).data_type() {
1162 DataType::List(item) => {
1163 assert_eq!(item.data_type(), &DataType::Utf8);
1164 }
1165 other => panic!("expected List, got {other:?}"),
1166 }
1167 }
1168
1169 #[test]
1170 fn test_avro_to_arrow_map_type() {
1171 let avro = r#"{
1172 "type": "record",
1173 "name": "test",
1174 "fields": [
1175 {"name": "metadata", "type": {"type": "map", "values": "long"}}
1176 ]
1177 }"#;
1178
1179 let schema = avro_to_arrow_schema(avro).unwrap();
1180 assert_eq!(schema.fields().len(), 1);
1181 match schema.field(0).data_type() {
1182 DataType::Map(entries, _) => {
1183 if let DataType::Struct(fields) = entries.data_type() {
1184 assert_eq!(fields.len(), 2);
1185 assert_eq!(fields[0].name(), "key");
1186 assert_eq!(fields[0].data_type(), &DataType::Utf8);
1187 assert_eq!(fields[1].name(), "value");
1188 assert_eq!(fields[1].data_type(), &DataType::Int64);
1189 } else {
1190 panic!("expected Struct entries");
1191 }
1192 }
1193 other => panic!("expected Map, got {other:?}"),
1194 }
1195 }
1196
1197 #[test]
1198 fn test_avro_to_arrow_nested_record() {
1199 let avro = r#"{
1200 "type": "record",
1201 "name": "test",
1202 "fields": [
1203 {
1204 "name": "address",
1205 "type": {
1206 "type": "record",
1207 "name": "Address",
1208 "fields": [
1209 {"name": "street", "type": "string"},
1210 {"name": "zip", "type": "int"}
1211 ]
1212 }
1213 }
1214 ]
1215 }"#;
1216
1217 let schema = avro_to_arrow_schema(avro).unwrap();
1218 assert_eq!(schema.fields().len(), 1);
1219 match schema.field(0).data_type() {
1220 DataType::Struct(fields) => {
1221 assert_eq!(fields.len(), 2);
1222 assert_eq!(fields[0].name(), "street");
1223 assert_eq!(fields[0].data_type(), &DataType::Utf8);
1224 assert_eq!(fields[1].name(), "zip");
1225 assert_eq!(fields[1].data_type(), &DataType::Int32);
1226 }
1227 other => panic!("expected Struct, got {other:?}"),
1228 }
1229 }
1230
1231 #[test]
1232 fn test_avro_to_arrow_enum_type() {
1233 let avro = r#"{
1234 "type": "record",
1235 "name": "test",
1236 "fields": [
1237 {
1238 "name": "status",
1239 "type": {
1240 "type": "enum",
1241 "name": "Status",
1242 "symbols": ["ACTIVE", "INACTIVE", "PENDING"]
1243 }
1244 }
1245 ]
1246 }"#;
1247
1248 let schema = avro_to_arrow_schema(avro).unwrap();
1249 assert_eq!(schema.fields().len(), 1);
1250 match schema.field(0).data_type() {
1251 DataType::Dictionary(key, value) => {
1252 assert_eq!(key.as_ref(), &DataType::Int32);
1253 assert_eq!(value.as_ref(), &DataType::Utf8);
1254 }
1255 other => panic!("expected Dictionary, got {other:?}"),
1256 }
1257 }
1258
1259 #[test]
1260 fn test_avro_to_arrow_fixed_type() {
1261 let avro = r#"{
1262 "type": "record",
1263 "name": "test",
1264 "fields": [
1265 {
1266 "name": "uuid",
1267 "type": {"type": "fixed", "name": "uuid", "size": 16}
1268 }
1269 ]
1270 }"#;
1271
1272 let schema = avro_to_arrow_schema(avro).unwrap();
1273 assert_eq!(schema.fields().len(), 1);
1274 assert_eq!(schema.field(0).data_type(), &DataType::FixedSizeBinary(16));
1275 }
1276
1277 #[test]
1278 fn test_avro_to_arrow_nullable_complex_in_union() {
1279 let avro = r#"{
1280 "type": "record",
1281 "name": "test",
1282 "fields": [
1283 {
1284 "name": "tags",
1285 "type": ["null", {"type": "array", "items": "string"}]
1286 }
1287 ]
1288 }"#;
1289
1290 let schema = avro_to_arrow_schema(avro).unwrap();
1291 assert!(schema.field(0).is_nullable());
1292 assert!(matches!(schema.field(0).data_type(), DataType::List(_)));
1293 }
1294
1295 #[test]
1296 fn test_avro_array_missing_items() {
1297 let avro = r#"{
1298 "type": "record",
1299 "name": "test",
1300 "fields": [
1301 {"name": "bad", "type": {"type": "array"}}
1302 ]
1303 }"#;
1304 assert!(avro_to_arrow_schema(avro).is_err());
1305 }
1306
1307 #[test]
1308 fn test_avro_map_missing_values() {
1309 let avro = r#"{
1310 "type": "record",
1311 "name": "test",
1312 "fields": [
1313 {"name": "bad", "type": {"type": "map"}}
1314 ]
1315 }"#;
1316 assert!(avro_to_arrow_schema(avro).is_err());
1317 }
1318
1319 #[test]
1320 fn test_arrow_to_avro_array_type() {
1321 let schema = Arc::new(Schema::new(vec![Field::new(
1322 "tags",
1323 DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
1324 false,
1325 )]));
1326
1327 let avro_str = arrow_to_avro_schema(&schema, "test").unwrap();
1328 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1329 let field = &avro["fields"][0];
1330 assert_eq!(field["type"]["type"], "array");
1331 assert_eq!(field["type"]["items"], "string");
1332 }
1333
1334 #[test]
1335 fn test_arrow_to_avro_map_type() {
1336 let schema = Arc::new(Schema::new(vec![Field::new(
1337 "metadata",
1338 DataType::Map(
1339 Arc::new(Field::new(
1340 "entries",
1341 DataType::Struct(Fields::from(vec![
1342 Field::new("key", DataType::Utf8, false),
1343 Field::new("value", DataType::Int64, true),
1344 ])),
1345 false,
1346 )),
1347 false,
1348 ),
1349 false,
1350 )]));
1351
1352 let avro_str = arrow_to_avro_schema(&schema, "test").unwrap();
1353 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1354 let field = &avro["fields"][0];
1355 assert_eq!(field["type"]["type"], "map");
1356 assert_eq!(field["type"]["values"], "long");
1357 }
1358
1359 #[test]
1360 fn test_arrow_to_avro_struct_type() {
1361 let schema = Arc::new(Schema::new(vec![Field::new(
1362 "address",
1363 DataType::Struct(Fields::from(vec![
1364 Field::new("street", DataType::Utf8, false),
1365 Field::new("zip", DataType::Int32, false),
1366 ])),
1367 false,
1368 )]));
1369
1370 let avro_str = arrow_to_avro_schema(&schema, "test").unwrap();
1371 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1372 let field = &avro["fields"][0];
1373 assert_eq!(field["type"]["type"], "record");
1374 let nested = field["type"]["fields"].as_array().unwrap();
1375 assert_eq!(nested.len(), 2);
1376 assert_eq!(nested[0]["name"], "street");
1377 assert_eq!(nested[0]["type"], "string");
1378 assert_eq!(nested[1]["name"], "zip");
1379 assert_eq!(nested[1]["type"], "int");
1380 }
1381
1382 #[test]
1383 fn test_arrow_to_avro_fixed_type() {
1384 let schema = Arc::new(Schema::new(vec![Field::new(
1385 "uuid",
1386 DataType::FixedSizeBinary(16),
1387 false,
1388 )]));
1389
1390 let avro_str = arrow_to_avro_schema(&schema, "test").unwrap();
1391 let avro: serde_json::Value = serde_json::from_str(&avro_str).unwrap();
1392 let field = &avro["fields"][0];
1393 assert_eq!(field["type"]["type"], "fixed");
1394 assert_eq!(field["type"]["size"], 16);
1395 }
1396
1397 fn make_cached_schema(id: i32) -> CachedSchema {
1400 CachedSchema {
1401 id,
1402 version: 1,
1403 schema_type: SchemaType::Avro,
1404 schema_str: format!(
1405 r#"{{"type":"record","name":"t{id}","fields":[{{"name":"x","type":"int"}}]}}"#
1406 ),
1407 arrow_schema: Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])),
1408 inserted_at: Instant::now(),
1409 }
1410 }
1411
1412 #[test]
1413 fn test_cache_config_defaults() {
1414 let config = SchemaRegistryCacheConfig::default();
1415 assert_eq!(config.max_entries, 1000);
1416 assert_eq!(config.ttl, Some(Duration::from_secs(3600)));
1417 }
1418
1419 #[test]
1420 fn test_cache_lru_eviction() {
1421 let config = SchemaRegistryCacheConfig {
1422 max_entries: 3,
1423 ttl: None,
1424 };
1425 let client = SchemaRegistryClient::with_cache_config("http://localhost:8081", None, config);
1426
1427 client.cache_insert(1, make_cached_schema(1));
1429 client.cache_insert(2, make_cached_schema(2));
1430 client.cache_insert(3, make_cached_schema(3));
1431 assert_eq!(client.cache_size(), 3);
1432
1433 client.cache_insert(4, make_cached_schema(4));
1435 assert!(client.cache_size() <= 3);
1436 assert!(client.cache_get(4).is_some());
1438 }
1439
1440 #[test]
1441 fn test_cache_ttl_expiration() {
1442 let config = SchemaRegistryCacheConfig {
1446 max_entries: 100,
1447 ttl: Some(Duration::from_millis(1000)),
1448 };
1449 let client = SchemaRegistryClient::with_cache_config("http://localhost:8081", None, config);
1450
1451 client.cache_insert(1, make_cached_schema(1));
1452 assert!(client.cache_get(1).is_some());
1453
1454 std::thread::sleep(Duration::from_millis(1200));
1456 assert!(client.cache_get(1).is_none());
1458 }
1459
1460 #[test]
1461 fn test_cache_no_ttl() {
1462 let config = SchemaRegistryCacheConfig {
1463 max_entries: 100,
1464 ttl: None,
1465 };
1466 let client = SchemaRegistryClient::with_cache_config("http://localhost:8081", None, config);
1467
1468 client.cache_insert(1, make_cached_schema(1));
1469 assert!(client.cache_get(1).is_some());
1471 }
1472
1473 #[test]
1474 fn test_cache_replace_existing_id() {
1475 let config = SchemaRegistryCacheConfig {
1476 max_entries: 10,
1477 ttl: None,
1478 };
1479 let client = SchemaRegistryClient::with_cache_config("http://localhost:8081", None, config);
1480
1481 client.cache_insert(1, make_cached_schema(1));
1482 client.cache_insert(2, make_cached_schema(2));
1483 assert_eq!(client.cache_size(), 2);
1484
1485 client.cache_insert(1, make_cached_schema(1));
1487 assert_eq!(client.cache_size(), 2);
1488 }
1489
1490 #[test]
1491 fn test_schema_incompatible_error_via_serde() {
1492 let err = SerdeError::SchemaIncompatible {
1493 subject: "orders-value".into(),
1494 message: "READER_FIELD_MISSING_DEFAULT_VALUE: field 'new_field'".into(),
1495 };
1496 let conn_err: ConnectorError = err.into();
1497 assert!(matches!(
1498 conn_err,
1499 ConnectorError::Serde(SerdeError::SchemaIncompatible { .. })
1500 ));
1501 assert!(conn_err.to_string().contains("orders-value"));
1502 }
1503
1504 #[test]
1505 fn test_validate_and_register_method_exists() {
1506 let client = SchemaRegistryClient::new("http://localhost:8081", None);
1508 let _ = &client;
1510 }
1511
1512 #[test]
1513 fn test_complex_type_roundtrip() {
1514 let avro = r#"{
1515 "type": "record",
1516 "name": "test",
1517 "fields": [
1518 {"name": "tags", "type": {"type": "array", "items": "string"}},
1519 {"name": "metadata", "type": {"type": "map", "values": "long"}}
1520 ]
1521 }"#;
1522
1523 let arrow_schema = avro_to_arrow_schema(avro).unwrap();
1524 assert!(matches!(
1525 arrow_schema.field(0).data_type(),
1526 DataType::List(_)
1527 ));
1528 assert!(matches!(
1529 arrow_schema.field(1).data_type(),
1530 DataType::Map(_, _)
1531 ));
1532
1533 let avro_str = arrow_to_avro_schema(&arrow_schema, "test").unwrap();
1535 let recovered = avro_to_arrow_schema(&avro_str).unwrap();
1536
1537 assert!(matches!(recovered.field(0).data_type(), DataType::List(_)));
1538 assert!(matches!(
1539 recovered.field(1).data_type(),
1540 DataType::Map(_, _)
1541 ));
1542 }
1543}