use std::collections::VecDeque;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use arrow_schema::SchemaRef;
use async_nats::jetstream::{self, consumer::pull};
use async_trait::async_trait;
use bytes::Bytes;
use crossfire::{mpsc, AsyncRx, MAsyncTx, TryRecvError};
use futures_util::StreamExt;
use rustc_hash::FxHashMap;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use tracing::{debug, warn};

use super::config::{build_connect_options, AckPolicy, DeliverPolicy, Mode, NatsSourceConfig};
use super::metrics::NatsSourceMetrics;
use crate::checkpoint::SourceCheckpoint;
use crate::config::ConnectorConfig;
use crate::connector::{PartitionInfo, SourceBatch, SourceConnector};
use crate::error::ConnectorError;
use crate::health::HealthStatus;
use crate::metrics::ConnectorMetrics;
use crate::serde::{self, RecordDeserializer};

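/// One message handed from a background reader task to `poll_batch`.
///
/// In JetStream mode `stream_seq` and `ack` are populated so the source can
/// track offsets and ack after the owning epoch commits; core NATS messages
/// carry neither.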
struct Incoming {
    subject: String,
    payload: Bytes,
    stream_seq: Option<u64>,
    ack: Option<jetstream::Message>,
}

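/// State held while a subscription is active: the deserializer, the receive
/// side of the reader channel, and the handles needed to stop the reader.
/// `consecutive_errors` is only populated in JetStream mode.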
struct Running {
    deserializer: Box<dyn RecordDeserializer>,
    rx: AsyncRx<mpsc::Array<Incoming>>,
    shutdown: Arc<Notify>,
    consecutive_errors: Option<Arc<AtomicU32>>,
    handle: JoinHandle<()>,
}

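/// NATS source connector supporting both core NATS subscriptions and
/// JetStream pull consumers.
///
/// Ack bookkeeping is two-phase: freshly polled JetStream messages accumulate
/// in `pending`, `checkpoint()` seals them into `sealed`, and
/// `notify_epoch_committed()` acks the sealed batches once the epoch is
/// durable. `offsets` tracks the highest stream sequence seen per subject.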
pub struct NatsSource {
    schema: SchemaRef,
    config: Option<NatsSourceConfig>,
    data_ready: Arc<Notify>,
    metrics: NatsSourceMetrics,
    running: Option<Running>,

    pending: parking_lot::Mutex<Vec<jetstream::Message>>,
    sealed: parking_lot::Mutex<VecDeque<Vec<jetstream::Message>>>,
    offsets: parking_lot::Mutex<FxHashMap<String, u64>>,
}

impl NatsSource {
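    /// Creates an unopened source for `schema`, optionally registering metrics
    /// with the given Prometheus registry.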
    #[must_use]
    pub fn new(schema: SchemaRef, registry: Option<&prometheus::Registry>) -> Self {
        Self {
            schema,
            config: None,
            data_ready: Arc::new(Notify::new()),
            metrics: NatsSourceMetrics::new(registry),
            running: None,
            pending: parking_lot::Mutex::new(Vec::new()),
            sealed: parking_lot::Mutex::new(VecDeque::new()),
            offsets: parking_lot::Mutex::new(FxHashMap::default()),
        }
    }

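    /// Refreshes the pending-acks gauge from both the unsealed and sealed
    /// message buffers.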
    fn update_pending_gauge(&self) {
        let pending_len = self.pending.lock().len();
        let sealed_total: usize = self.sealed.lock().iter().map(Vec::len).sum();
        self.metrics.set_pending_acks(pending_len + sealed_total);
    }

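    /// Returns the parsed configuration, or `None` if the source has not been
    /// opened yet.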
    #[must_use]
    pub fn config(&self) -> Option<&NatsSourceConfig> {
        self.config.as_ref()
    }

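    /// Connects to NATS, creates the durable pull consumer on the configured
    /// stream (surfacing a drift error if an incompatible consumer already
    /// exists), and spawns the `JsReader` task that feeds `poll_batch` through
    /// a bounded channel.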
    async fn open_jetstream(
        &mut self,
        cfg: &NatsSourceConfig,
        deserializer: Box<dyn RecordDeserializer>,
    ) -> Result<(), ConnectorError> {
        let client = connect(cfg).await?;
        let js = jetstream::new(client);

        let stream_name = cfg
            .stream
            .as_deref()
            .ok_or_else(|| err("stream name missing after validation"))?;
        let consumer_name = cfg
            .consumer
            .as_deref()
            .ok_or_else(|| err("consumer name missing after validation"))?;

        let pull_cfg = build_pull_config(cfg, consumer_name)?;
        let stream = js
            .get_stream(stream_name)
            .await
            .map_err(|e| err(&format!("get_stream('{stream_name}') failed: {e}")))?;
        let consumer = stream
            .create_consumer(pull_cfg)
            .await
            .map_err(|e| classify_create_consumer_error(&e, consumer_name))?;

        let (tx, rx) = mpsc::bounded_async::<Incoming>(cfg.fetch_batch * 2);
        let shutdown = Arc::new(Notify::new());
        let consecutive_errors = Arc::new(AtomicU32::new(0));

        let reader = JsReader {
            consumer,
            tx,
            shutdown: Arc::clone(&shutdown),
            consecutive_errors: Arc::clone(&consecutive_errors),
            data_ready: Arc::clone(&self.data_ready),
            metrics: self.metrics.clone(),
            batch_size: cfg.fetch_batch,
            max_wait: cfg.fetch_max_wait,
            lag_poll_interval: cfg.lag_poll_interval,
        };
        let handle = tokio::spawn(reader.run());

        self.running = Some(Running {
            deserializer,
            rx,
            shutdown,
            consecutive_errors: Some(consecutive_errors),
            handle,
        });
        Ok(())
    }

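    /// Connects to NATS and subscribes to the configured subject, optionally
    /// as part of a queue group, then spawns the `CoreReader` task. Core mode
    /// has no acks or sequence numbers, so replay is unsupported.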
    async fn open_core(
        &mut self,
        cfg: &NatsSourceConfig,
        deserializer: Box<dyn RecordDeserializer>,
    ) -> Result<(), ConnectorError> {
        let client = connect(cfg).await?;
        let subject = cfg
            .subject
            .clone()
            .ok_or_else(|| err("subject missing after validation"))?;
        let subscriber = if let Some(group) = cfg.queue_group.as_deref() {
            client
                .queue_subscribe(subject, group.to_string())
                .await
                .map_err(|e| err(&format!("queue_subscribe: {e}")))?
        } else {
            client
                .subscribe(subject)
                .await
                .map_err(|e| err(&format!("subscribe: {e}")))?
        };

        let (tx, rx) = mpsc::bounded_async::<Incoming>(cfg.fetch_batch * 2);
        let shutdown = Arc::new(Notify::new());

        let reader = CoreReader {
            subscriber,
            tx,
            shutdown: Arc::clone(&shutdown),
            data_ready: Arc::clone(&self.data_ready),
        };
        let handle = tokio::spawn(reader.run());

        self.running = Some(Running {
            deserializer,
            rx,
            shutdown,
            consecutive_errors: None,
            handle,
        });
        Ok(())
    }
}

#[async_trait]
impl SourceConnector for NatsSource {
    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
        let cfg = NatsSourceConfig::from_config(config)?;
        if let Some(schema) = config.arrow_schema() {
            self.schema = schema;
        }
        let deserializer = serde::create_deserializer(cfg.format)
            .map_err(|e| err(&format!("deserializer for format {:?}: {e}", cfg.format)))?;
        match cfg.mode {
            Mode::JetStream => self.open_jetstream(&cfg, deserializer).await?,
            Mode::Core => self.open_core(&cfg, deserializer).await?,
        }
        self.config = Some(cfg);
        Ok(())
    }

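    // Drains whatever the reader task has already buffered (up to
    // `max_records`) without blocking, deserializes it into a single Arrow
    // batch, and stages offsets and ack handles for the next checkpoint.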
    async fn poll_batch(
        &mut self,
        max_records: usize,
    ) -> Result<Option<SourceBatch>, ConnectorError> {
        let Some(running) = self.running.as_mut() else {
            return Ok(None);
        };

        let mut payloads: Vec<Bytes> = Vec::new();
        let mut partition: Option<PartitionInfo> = None;
        let mut new_acks: Vec<jetstream::Message> = Vec::new();
        let mut offset_updates: Vec<(String, u64)> = Vec::new();

        while payloads.len() < max_records {
            let incoming = match running.rx.try_recv() {
                Ok(m) => m,
                Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
            };
            if let Some(seq) = incoming.stream_seq {
                offset_updates.push((incoming.subject.clone(), seq));
                if partition.is_none() {
                    partition = Some(PartitionInfo::new(&incoming.subject, seq.to_string()));
                }
            }
            payloads.push(incoming.payload);
            if let Some(msg) = incoming.ack {
                new_acks.push(msg);
            }
        }

        if payloads.is_empty() {
            return Ok(None);
        }

        let records: Vec<&[u8]> = payloads.iter().map(Bytes::as_ref).collect();
        let bytes_total: u64 = records.iter().map(|r| r.len() as u64).sum();
        let batch = running
            .deserializer
            .deserialize_batch(&records, &self.schema)
            .map_err(|e| err(&format!("deserialize batch: {e}")))?;

        if !offset_updates.is_empty() {
            let mut offsets = self.offsets.lock();
            for (subject, seq) in offset_updates {
                offsets
                    .entry(subject)
                    .and_modify(|s| *s = (*s).max(seq))
                    .or_insert(seq);
            }
        }
        if !new_acks.is_empty() {
            self.pending.lock().extend(new_acks);
        }

        self.metrics
            .record_poll(batch.num_rows() as u64, bytes_total);
        self.update_pending_gauge();

        Ok(Some(SourceBatch {
            records: batch,
            partition,
        }))
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

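    // Seals the currently pending ack handles into a batch tied to this
    // checkpoint; they are acked only after the epoch commits. The checkpoint
    // records the highest stream sequence seen per subject.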
    fn checkpoint(&self) -> SourceCheckpoint {
        {
            let mut pending = self.pending.lock();
            if !pending.is_empty() {
                let batch = std::mem::take(&mut *pending);
                self.sealed.lock().push_back(batch);
            }
        }
        let mut cp = SourceCheckpoint::default();
        for (subject, seq) in &*self.offsets.lock() {
            cp.set_offset(subject.as_str(), seq.to_string());
        }
        cp
    }

    async fn restore(&mut self, _checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
        Ok(())
    }

    async fn close(&mut self) -> Result<(), ConnectorError> {
        let Some(running) = self.running.take() else {
            return Ok(());
        };
        running.shutdown.notify_one();
        let _ = tokio::time::timeout(Duration::from_secs(5), running.handle).await;
        Ok(())
    }

    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
        Some(Arc::clone(&self.data_ready))
    }

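    // Core NATS has no persistence, so only JetStream mode can replay.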
    fn supports_replay(&self) -> bool {
        !matches!(self.config.as_ref().map(|c| c.mode), Some(Mode::Core))
    }

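    // Acks every sealed batch now that the epoch is durable. Individual ack
    // failures are logged and left to broker redelivery instead of failing
    // the commit.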
    async fn notify_epoch_committed(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
        loop {
            let Some(batch) = self.sealed.lock().pop_front() else {
                break;
            };
            for msg in batch {
                match msg.ack().await {
                    Ok(()) => self.metrics.record_ack(),
                    Err(e) => {
                        self.metrics.record_ack_error();
                        warn!(error = %e, "JetStream ack failed; broker will redeliver");
                    }
                }
            }
        }
        self.update_pending_gauge();
        Ok(())
    }

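    // Unhealthy once consecutive fetch errors reach the configured threshold;
    // degraded once pending acks reach half of `max_ack_pending`.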
    fn health_check(&self) -> HealthStatus {
        let Some(cfg) = self.config.as_ref() else {
            return HealthStatus::Unknown;
        };

        if cfg.fetch_error_threshold > 0 {
            if let Some(errors) = self
                .running
                .as_ref()
                .and_then(|r| r.consecutive_errors.as_ref())
            {
                let errs = errors.load(Ordering::Acquire);
                if errs >= cfg.fetch_error_threshold {
                    return HealthStatus::Unhealthy(format!(
                        "{errs} consecutive fetch errors (threshold {})",
                        cfg.fetch_error_threshold
                    ));
                }
            }
        }

        if cfg.max_ack_pending > 0 {
            #[allow(clippy::cast_sign_loss)]
            let cap = cfg.max_ack_pending as u64;
            #[allow(clippy::cast_sign_loss)]
            let pending = self.metrics.pending_acks.get().max(0) as u64;
            if pending * 2 >= cap {
                return HealthStatus::Degraded(format!(
                    "pending acks {pending}/{cap} — broker may throttle delivery"
                ));
            }
        }
        HealthStatus::Healthy
    }

    fn metrics(&self) -> ConnectorMetrics {
        self.metrics.to_connector_metrics()
    }
}

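/// Shorthand for building a `ConnectorError::ConfigurationError`.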
fn err(msg: &str) -> ConnectorError {
    ConnectorError::ConfigurationError(msg.to_string())
}

async fn connect(cfg: &NatsSourceConfig) -> Result<async_nats::Client, ConnectorError> {
    build_connect_options(&cfg.auth, &cfg.tls)?
        .connect(&cfg.servers)
        .await
        .map_err(|e| err(&format!("nats connect({:?}): {e}", cfg.servers)))
}

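/// Builds the durable pull-consumer config from the source settings, falling
/// back to the single `subject` when no explicit subject filters are given.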
fn build_pull_config(
    cfg: &NatsSourceConfig,
    consumer_name: &str,
) -> Result<pull::Config, ConnectorError> {
    let filter_subjects = if cfg.subject_filters.is_empty() {
        cfg.subject.iter().cloned().collect()
    } else {
        cfg.subject_filters.clone()
    };

    Ok(pull::Config {
        durable_name: Some(consumer_name.to_string()),
        filter_subjects,
        deliver_policy: map_deliver_policy(cfg)?,
        ack_policy: map_ack_policy(cfg.ack_policy),
        ack_wait: cfg.ack_wait,
        max_deliver: cfg.max_deliver,
        max_ack_pending: cfg.max_ack_pending,
        ..Default::default()
    })
}

fn map_deliver_policy(
    cfg: &NatsSourceConfig,
) -> Result<async_nats::jetstream::consumer::DeliverPolicy, ConnectorError> {
    use async_nats::jetstream::consumer::DeliverPolicy as Nats;
    Ok(match cfg.deliver_policy {
        DeliverPolicy::All => Nats::All,
        DeliverPolicy::New => Nats::New,
        DeliverPolicy::ByStartSequence => Nats::ByStartSequence {
            start_sequence: cfg.start_sequence.unwrap_or(1),
        },
        DeliverPolicy::ByStartTime => {
            let raw = cfg
                .start_time
                .as_deref()
                .ok_or_else(|| err("deliver.policy=by_start_time requires 'start.time'"))?;
            let start_time =
                time::OffsetDateTime::parse(raw, &time::format_description::well_known::Rfc3339)
                    .map_err(|e| err(&format!("start.time '{raw}' is not valid RFC3339: {e}")))?;
            Nats::ByStartTime { start_time }
        }
    })
}

fn map_ack_policy(p: AckPolicy) -> async_nats::jetstream::consumer::AckPolicy {
    use async_nats::jetstream::consumer::AckPolicy as Nats;
    match p {
        AckPolicy::Explicit => Nats::Explicit,
        AckPolicy::None => Nats::None,
    }
}

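/// Exponential backoff base for consecutive fetch errors: 500ms, 1s, 2s, 4s,
/// then capped at 5s.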
fn fetch_backoff_base(consecutive_errors: u32) -> Duration {
    let exp = consecutive_errors.saturating_sub(1).min(4);
    let ms = 500u64.saturating_mul(1u64 << exp);
    Duration::from_millis(ms.min(5000))
}

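/// Spreads `base` by up to roughly ±20% using the caller-supplied entropy
/// value.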
fn with_jitter(base: Duration, entropy: u64) -> Duration {
    let base_ms = u64::try_from(base.as_millis()).unwrap_or(u64::MAX);
    let range = (base_ms / 5).max(1);
    let window = range * 2 + 1;
    let offset = entropy % window;
    let jittered = base_ms.saturating_add(offset).saturating_sub(range);
    Duration::from_millis(jittered)
}

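/// Full backoff for a fetch-error streak: the capped exponential base plus
/// jitter.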
fn fetch_backoff(consecutive_errors: u32, entropy: u64) -> Duration {
    with_jitter(fetch_backoff_base(consecutive_errors), entropy)
}

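/// Maps "consumer already exists" server errors to a dedicated
/// configuration-drift message (LDB-5070); everything else is reported
/// verbatim.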
fn classify_create_consumer_error(
    e: &async_nats::jetstream::stream::ConsumerError,
    consumer_name: &str,
) -> ConnectorError {
    use async_nats::jetstream::stream::ConsumerErrorKind;
    use async_nats::jetstream::ErrorCode;

    let drift_code = match e.kind() {
        ConsumerErrorKind::JetStream(server_err) => matches!(
            server_err.error_code(),
            ErrorCode::CONSUMER_ALREADY_EXISTS | ErrorCode::CONSUMER_NAME_EXIST
        ),
        _ => false,
    };
    if drift_code {
        err(&format!(
            "[LDB-5070] consumer '{consumer_name}' exists with incompatible config; \
             rotate the durable name or delete the consumer out-of-band. \
             Server said: {e}"
        ))
    } else {
        err(&format!("create_consumer('{consumer_name}') failed: {e}"))
    }
}

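/// Cheap entropy from the current wall-clock nanoseconds; only used to
/// de-synchronize retry timing, not as a source of randomness.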
#[allow(clippy::cast_possible_truncation)]
fn entropy_now() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64
}

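/// Background task that pulls JetStream batches and forwards them to the
/// source over a bounded channel, backing off with jitter on fetch errors and
/// periodically polling consumer lag.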
struct JsReader {
    consumer: jetstream::consumer::Consumer<pull::Config>,
    tx: MAsyncTx<mpsc::Array<Incoming>>,
    shutdown: Arc<Notify>,
    consecutive_errors: Arc<AtomicU32>,
    data_ready: Arc<Notify>,
    metrics: NatsSourceMetrics,
    batch_size: usize,
    max_wait: Duration,
    lag_poll_interval: Duration,
}

impl JsReader {
    async fn run(self) {
        let Self {
            mut consumer,
            tx,
            shutdown,
            consecutive_errors,
            data_ready,
            metrics,
            batch_size,
            max_wait,
            lag_poll_interval,
        } = self;

        let mut last_lag_poll = Instant::now();
        let lag_poll_enabled = !lag_poll_interval.is_zero();

        loop {
            let fetch_result = tokio::select! {
                biased;
                () = shutdown.notified() => break,
                r = consumer.fetch().max_messages(batch_size).expires(max_wait).messages() => r,
            };

            let mut stream = match fetch_result {
                Ok(s) => s,
                Err(e) => {
                    let errs = consecutive_errors.fetch_add(1, Ordering::AcqRel) + 1;
                    metrics.record_fetch_error();
                    warn!(
                        error = %e,
                        consecutive_errors = errs,
                        "nats fetch() errored; backing off",
                    );
                    let backoff = fetch_backoff(errs, entropy_now());
                    tokio::select! {
                        biased;
                        () = shutdown.notified() => break,
                        () = tokio::time::sleep(backoff) => {}
                    }
                    continue;
                }
            };

            let mut forwarded = 0usize;
            let mut stream_errors = 0usize;
            loop {
                let msg_result = tokio::select! {
                    biased;
                    () = shutdown.notified() => return,
                    r = stream.next() => match r {
                        Some(r) => r,
                        None => break,
                    },
                };
                let msg = match msg_result {
                    Ok(m) => m,
                    Err(e) => {
                        metrics.record_fetch_error();
                        stream_errors += 1;
                        warn!(error = %e, "nats message error");
                        continue;
                    }
                };
                let incoming = Incoming {
                    subject: msg.subject.to_string(),
                    payload: msg.payload.clone(),
                    stream_seq: msg.info().ok().map(|i| i.stream_sequence),
                    ack: Some(msg),
                };
                if tx.send(incoming).await.is_err() {
                    debug!("nats reader: downstream channel closed");
                    return;
                }
                forwarded += 1;
            }

            if forwarded > 0 {
                consecutive_errors.store(0, Ordering::Release);
                data_ready.notify_one();
            } else if stream_errors > 0 {
                let errs = consecutive_errors.fetch_add(1, Ordering::AcqRel) + 1;
                let backoff = fetch_backoff(errs, entropy_now());
                tokio::select! {
                    biased;
                    () = shutdown.notified() => break,
                    () = tokio::time::sleep(backoff) => {}
                }
            }

            if lag_poll_enabled && last_lag_poll.elapsed() >= lag_poll_interval {
                last_lag_poll = Instant::now();
                match consumer.info().await {
                    Ok(info) => metrics.set_consumer_lag(info.num_pending),
                    Err(e) => warn!(error = %e, "consumer.info() failed; skipping lag update"),
                }
            }
        }
    }
}

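/// Background task for core NATS: forwards every subscribed message to the
/// source and signals `data_ready`; there is no ack or error tracking.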
struct CoreReader {
    subscriber: async_nats::Subscriber,
    tx: MAsyncTx<mpsc::Array<Incoming>>,
    shutdown: Arc<Notify>,
    data_ready: Arc<Notify>,
}

impl CoreReader {
    async fn run(self) {
        let Self {
            mut subscriber,
            tx,
            shutdown,
            data_ready,
        } = self;

        loop {
            let msg = tokio::select! {
                biased;
                () = shutdown.notified() => break,
                m = subscriber.next() => match m {
                    Some(m) => m,
                    None => break,
                },
            };
            let incoming = Incoming {
                subject: msg.subject.to_string(),
                payload: msg.payload,
                stream_seq: None,
                ack: None,
            };
            if tx.send(incoming).await.is_err() {
                return;
            }
            data_ready.notify_one();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::Schema;

    #[test]
    fn checkpoint_empty_pending_is_noop() {
        let src = NatsSource::new(Arc::new(Schema::empty()), None);
        let _ = src.checkpoint();
        assert!(src.sealed.lock().is_empty());
    }

    #[test]
    fn backoff_base_grows_then_caps_at_5s() {
        assert_eq!(fetch_backoff_base(1), Duration::from_millis(500));
        assert_eq!(fetch_backoff_base(2), Duration::from_millis(1000));
        assert_eq!(fetch_backoff_base(3), Duration::from_millis(2000));
        assert_eq!(fetch_backoff_base(4), Duration::from_millis(4000));
        assert_eq!(fetch_backoff_base(5), Duration::from_millis(5000));
        assert_eq!(fetch_backoff_base(100), Duration::from_millis(5000));
    }

    #[test]
    fn jitter_stays_within_plus_minus_20_percent() {
        let base = Duration::from_millis(1000);
        for entropy in [0u64, 1, 99, 12345, u64::MAX] {
            let j = with_jitter(base, entropy);
            assert!(
                j >= Duration::from_millis(800) && j <= Duration::from_millis(1200),
                "entropy {entropy}: jittered = {j:?} outside ±20% of {base:?}"
            );
        }
    }
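
    // Added checks (not part of the original suite): exercise the composed
    // fetch_backoff() helper and the ack-policy mapping defined above.
    #[test]
    fn backoff_with_jitter_respects_cap() {
        // At 100 consecutive errors the base is capped at 5s, so the jittered
        // result must stay within [4s, 6s].
        for entropy in [0u64, 1, 42, 9999, u64::MAX] {
            let d = fetch_backoff(100, entropy);
            assert!(
                d >= Duration::from_millis(4000) && d <= Duration::from_millis(6000),
                "entropy {entropy}: backoff {d:?} outside ±20% of the 5s cap"
            );
        }
    }

    #[test]
    fn ack_policy_maps_one_to_one() {
        use async_nats::jetstream::consumer::AckPolicy as Nats;
        assert!(matches!(map_ack_policy(AckPolicy::Explicit), Nats::Explicit));
        assert!(matches!(map_ack_policy(AckPolicy::None), Nats::None));
    }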
}