1use std::collections::VecDeque;
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use arrow_schema::SchemaRef;
12use async_nats::jetstream::{self, consumer::pull};
13use async_trait::async_trait;
14use bytes::Bytes;
15use crossfire::{mpsc, AsyncRx, MAsyncTx, TryRecvError};
16use futures_util::StreamExt;
17use rustc_hash::FxHashMap;
18use tokio::sync::Notify;
19use tokio::task::JoinHandle;
20use tracing::{debug, warn};
21
22use super::config::{build_connect_options, AckPolicy, DeliverPolicy, Mode, NatsSourceConfig};
23use super::metrics::NatsSourceMetrics;
24use crate::checkpoint::SourceCheckpoint;
25use crate::config::ConnectorConfig;
26use crate::connector::{PartitionInfo, SourceBatch, SourceConnector};
27use crate::error::ConnectorError;
28use crate::serde::{self, RecordDeserializer};
29
/// One message handed from a reader task to `poll_batch` over the bounded channel.
struct Incoming {
    /// Subject the message was received on.
    subject: String,
    /// Raw message body; passed to the record deserializer unchanged.
    payload: Bytes,
    /// JetStream stream sequence of the message. `None` for core NATS messages,
    /// or when the JetStream reader could not read the message info.
    stream_seq: Option<u64>,
    /// The JetStream message itself, retained so it can be acked later.
    /// `None` for core NATS messages.
    ack: Option<jetstream::Message>,
}
37
/// State that exists only while a reader task is active (between `open` and `close`).
struct Running {
    /// Turns the raw payloads collected by `poll_batch` into a record batch.
    deserializer: Box<dyn RecordDeserializer>,
    /// Receiving end of the bounded channel fed by the reader task.
    rx: AsyncRx<mpsc::Array<Incoming>>,
    /// Signalled by `close` to ask the reader task to stop.
    shutdown: Arc<Notify>,
    /// Handle of the spawned reader task; awaited (with a timeout) by `close`.
    handle: JoinHandle<()>,
}
44
/// NATS source connector supporting JetStream (pull consumer) and core NATS
/// (plain subscription) modes.
pub struct NatsSource {
    /// Arrow schema handed to the deserializer; replaced in `open` when the
    /// connector config carries its own schema.
    schema: SchemaRef,
    /// Parsed source configuration; set once `open` has succeeded.
    config: Option<NatsSourceConfig>,
    /// Signalled by the reader task after it has forwarded messages; exposed to
    /// callers through `data_ready_notify`.
    data_ready: Arc<Notify>,
    /// Metrics handle shared (by clone) with the JetStream reader task.
    metrics: NatsSourceMetrics,
    /// Active reader state; `None` before `open` and after `close`.
    running: Option<Running>,

    /// JetStream messages returned by `poll_batch` since the last `checkpoint`
    /// call and not yet acked.
    pending: parking_lot::Mutex<Vec<jetstream::Message>>,
    /// Batches moved out of `pending` by `checkpoint`, oldest first; acked and
    /// drained by `notify_epoch_committed`.
    sealed: parking_lot::Mutex<VecDeque<Vec<jetstream::Message>>>,
    /// Highest stream sequence seen so far, keyed by subject.
    offsets: parking_lot::Mutex<FxHashMap<String, u64>>,
}
59
60impl NatsSource {
61 #[must_use]
63 pub fn new(schema: SchemaRef, registry: Option<&prometheus::Registry>) -> Self {
64 Self {
65 schema,
66 config: None,
67 data_ready: Arc::new(Notify::new()),
68 metrics: NatsSourceMetrics::new(registry),
69 running: None,
70 pending: parking_lot::Mutex::new(Vec::new()),
71 sealed: parking_lot::Mutex::new(VecDeque::new()),
72 offsets: parking_lot::Mutex::new(FxHashMap::default()),
73 }
74 }
75
76 fn update_pending_gauge(&self) {
77 let pending_len = self.pending.lock().len();
78 let sealed_total: usize = self.sealed.lock().iter().map(Vec::len).sum();
79 self.metrics.set_pending_acks(pending_len + sealed_total);
80 }
81
    /// Returns the parsed source configuration, or `None` if `open` has not
    /// completed successfully yet.
    #[must_use]
    pub fn config(&self) -> Option<&NatsSourceConfig> {
        self.config.as_ref()
    }
87
    /// Returns a reference to this source's metrics handle.
    #[must_use]
    pub fn metrics_handle(&self) -> &NatsSourceMetrics {
        &self.metrics
    }
93
94 async fn open_jetstream(
95 &mut self,
96 cfg: &NatsSourceConfig,
97 deserializer: Box<dyn RecordDeserializer>,
98 ) -> Result<(), ConnectorError> {
99 let client = connect(cfg).await?;
100 let js = jetstream::new(client);
101
102 let stream_name = cfg
103 .stream
104 .as_deref()
105 .ok_or_else(|| err("stream name missing after validation"))?;
106 let consumer_name = cfg
107 .consumer
108 .as_deref()
109 .ok_or_else(|| err("consumer name missing after validation"))?;
110
111 let pull_cfg = build_pull_config(cfg, consumer_name)?;
112 let stream = js
113 .get_stream(stream_name)
114 .await
115 .map_err(|e| err(&format!("get_stream('{stream_name}') failed: {e}")))?;
116 let consumer = stream
117 .create_consumer(pull_cfg)
118 .await
119 .map_err(|e| classify_create_consumer_error(&e, consumer_name))?;
120
121 let (tx, rx) = mpsc::bounded_async::<Incoming>(cfg.fetch_batch * 2);
122 let shutdown = Arc::new(Notify::new());
123
124 let reader = JsReader {
125 consumer,
126 tx,
127 shutdown: Arc::clone(&shutdown),
128 consecutive_errors: Arc::new(AtomicU32::new(0)),
129 data_ready: Arc::clone(&self.data_ready),
130 metrics: self.metrics.clone(),
131 batch_size: cfg.fetch_batch,
132 max_wait: cfg.fetch_max_wait,
133 lag_poll_interval: cfg.lag_poll_interval,
134 };
135 let handle = tokio::spawn(reader.run());
136
137 self.running = Some(Running {
138 deserializer,
139 rx,
140 shutdown,
141 handle,
142 });
143 Ok(())
144 }
145
146 async fn open_core(
147 &mut self,
148 cfg: &NatsSourceConfig,
149 deserializer: Box<dyn RecordDeserializer>,
150 ) -> Result<(), ConnectorError> {
151 let client = connect(cfg).await?;
152 let subject = cfg
153 .subject
154 .clone()
155 .ok_or_else(|| err("subject missing after validation"))?;
156 let subscriber = if let Some(group) = cfg.queue_group.as_deref() {
157 client
158 .queue_subscribe(subject, group.to_string())
159 .await
160 .map_err(|e| err(&format!("queue_subscribe: {e}")))?
161 } else {
162 client
163 .subscribe(subject)
164 .await
165 .map_err(|e| err(&format!("subscribe: {e}")))?
166 };
167
168 let (tx, rx) = mpsc::bounded_async::<Incoming>(cfg.fetch_batch * 2);
169 let shutdown = Arc::new(Notify::new());
170
171 let reader = CoreReader {
172 subscriber,
173 tx,
174 shutdown: Arc::clone(&shutdown),
175 data_ready: Arc::clone(&self.data_ready),
176 };
177 let handle = tokio::spawn(reader.run());
178
179 self.running = Some(Running {
180 deserializer,
181 rx,
182 shutdown,
183 handle,
184 });
185 Ok(())
186 }
187}
188
189#[async_trait]
190impl SourceConnector for NatsSource {
191 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
192 let cfg = NatsSourceConfig::from_config(config)?;
193 if let Some(schema) = config.arrow_schema() {
195 self.schema = schema;
196 }
197 let deserializer = serde::create_deserializer(cfg.format)
198 .map_err(|e| err(&format!("deserializer for format {:?}: {e}", cfg.format)))?;
199 match cfg.mode {
200 Mode::JetStream => self.open_jetstream(&cfg, deserializer).await?,
201 Mode::Core => self.open_core(&cfg, deserializer).await?,
202 }
203 self.config = Some(cfg);
204 Ok(())
205 }
206
207 async fn poll_batch(
208 &mut self,
209 max_records: usize,
210 ) -> Result<Option<SourceBatch>, ConnectorError> {
211 let Some(running) = self.running.as_mut() else {
212 return Ok(None);
213 };
214
215 let mut payloads: Vec<Bytes> = Vec::new();
216 let mut partition: Option<PartitionInfo> = None;
217 let mut new_acks: Vec<jetstream::Message> = Vec::new();
218 let mut offset_updates: Vec<(String, u64)> = Vec::new();
219
220 while payloads.len() < max_records {
221 let incoming = match running.rx.try_recv() {
222 Ok(m) => m,
223 Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
224 };
225 if let Some(seq) = incoming.stream_seq {
226 offset_updates.push((incoming.subject.clone(), seq));
227 if partition.is_none() {
228 partition = Some(PartitionInfo::new(&incoming.subject, seq.to_string()));
229 }
230 }
231 payloads.push(incoming.payload);
232 if let Some(msg) = incoming.ack {
233 new_acks.push(msg);
234 }
235 }
236
237 if payloads.is_empty() {
238 return Ok(None);
239 }
240
241 let records: Vec<&[u8]> = payloads.iter().map(Bytes::as_ref).collect();
242 let bytes_total: u64 = records.iter().map(|r| r.len() as u64).sum();
243 let batch = running
246 .deserializer
247 .deserialize_batch(&records, &self.schema)
248 .map_err(|e| err(&format!("deserialize batch: {e}")))?;
249
250 if !offset_updates.is_empty() {
251 let mut offsets = self.offsets.lock();
252 for (subject, seq) in offset_updates {
253 offsets
254 .entry(subject)
255 .and_modify(|s| *s = (*s).max(seq))
256 .or_insert(seq);
257 }
258 }
259 if !new_acks.is_empty() {
260 self.pending.lock().extend(new_acks);
261 }
262
263 self.metrics
264 .record_poll(batch.num_rows() as u64, bytes_total);
265 self.update_pending_gauge();
266
267 Ok(Some(SourceBatch {
268 records: batch,
269 partition,
270 }))
271 }
272
273 fn schema(&self) -> SchemaRef {
274 self.schema.clone()
275 }
276
277 fn checkpoint(&self) -> SourceCheckpoint {
278 {
282 let mut pending = self.pending.lock();
283 if !pending.is_empty() {
284 let batch = std::mem::take(&mut *pending);
285 self.sealed.lock().push_back(batch);
286 }
287 }
288 let mut cp = SourceCheckpoint::default();
289 for (subject, seq) in &*self.offsets.lock() {
290 cp.set_offset(subject.as_str(), seq.to_string());
291 }
292 cp
293 }
294
    /// No-op: the supplied checkpoint is ignored and `Ok(())` is always returned.
    // NOTE(review): presumably the durable JetStream consumer tracks delivery
    // position server-side, which would make a client-side seek unnecessary — confirm.
    async fn restore(&mut self, _checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
        Ok(())
    }
299
300 async fn close(&mut self) -> Result<(), ConnectorError> {
301 let Some(running) = self.running.take() else {
302 return Ok(());
303 };
304 running.shutdown.notify_one();
305 let _ = tokio::time::timeout(Duration::from_secs(5), running.handle).await;
306 Ok(())
307 }
308
    /// Returns a handle to the `Notify` that reader tasks signal after
    /// forwarding messages. Always `Some`.
    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
        Some(Arc::clone(&self.data_ready))
    }
312
313 fn supports_replay(&self) -> bool {
314 !matches!(self.config.as_ref().map(|c| c.mode), Some(Mode::Core))
316 }
317
318 async fn notify_epoch_committed(
319 &mut self,
320 _epoch: u64,
321 _checkpoint: &crate::checkpoint::SourceCheckpoint,
322 ) -> Result<(), ConnectorError> {
323 loop {
326 let Some(batch) = self.sealed.lock().pop_front() else {
327 break;
328 };
329 for msg in batch {
330 match msg.ack().await {
331 Ok(()) => self.metrics.record_ack(),
332 Err(e) => {
333 self.metrics.record_ack_error();
334 warn!(error = %e, "JetStream ack failed; broker will redeliver");
335 }
336 }
337 }
338 }
339 self.update_pending_gauge();
340 Ok(())
341 }
342}
343
344fn err(msg: &str) -> ConnectorError {
347 ConnectorError::ConfigurationError(msg.to_string())
348}
349
350async fn connect(cfg: &NatsSourceConfig) -> Result<async_nats::Client, ConnectorError> {
351 build_connect_options(&cfg.auth, &cfg.tls)?
352 .connect(&cfg.servers)
353 .await
354 .map_err(|e| err(&format!("nats connect({:?}): {e}", cfg.servers)))
355}
356
357fn build_pull_config(
358 cfg: &NatsSourceConfig,
359 consumer_name: &str,
360) -> Result<pull::Config, ConnectorError> {
361 let filter_subjects = if cfg.subject_filters.is_empty() {
362 cfg.subject.iter().cloned().collect()
363 } else {
364 cfg.subject_filters.clone()
365 };
366
367 Ok(pull::Config {
368 durable_name: Some(consumer_name.to_string()),
369 filter_subjects,
370 deliver_policy: map_deliver_policy(cfg)?,
371 ack_policy: map_ack_policy(cfg.ack_policy),
372 ack_wait: cfg.ack_wait,
373 max_deliver: cfg.max_deliver,
374 max_ack_pending: cfg.max_ack_pending,
375 ..Default::default()
376 })
377}
378
379fn map_deliver_policy(
380 cfg: &NatsSourceConfig,
381) -> Result<async_nats::jetstream::consumer::DeliverPolicy, ConnectorError> {
382 use async_nats::jetstream::consumer::DeliverPolicy as Nats;
383 Ok(match cfg.deliver_policy {
384 DeliverPolicy::All => Nats::All,
385 DeliverPolicy::New => Nats::New,
386 DeliverPolicy::ByStartSequence => Nats::ByStartSequence {
387 start_sequence: cfg.start_sequence.unwrap_or(1),
388 },
389 DeliverPolicy::ByStartTime => {
390 let raw = cfg
391 .start_time
392 .as_deref()
393 .ok_or_else(|| err("deliver.policy=by_start_time requires 'start.time'"))?;
394 let start_time =
395 time::OffsetDateTime::parse(raw, &time::format_description::well_known::Rfc3339)
396 .map_err(|e| err(&format!("start.time '{raw}' is not valid RFC3339: {e}")))?;
397 Nats::ByStartTime { start_time }
398 }
399 })
400}
401
402fn map_ack_policy(p: AckPolicy) -> async_nats::jetstream::consumer::AckPolicy {
403 use async_nats::jetstream::consumer::AckPolicy as Nats;
404 match p {
405 AckPolicy::Explicit => Nats::Explicit,
406 AckPolicy::None => Nats::None,
407 }
408}
409
/// Un-jittered backoff for the given number of consecutive fetch errors.
///
/// Starts at 500 ms for the first error (0 is treated like 1), doubles with
/// each further error, and is capped at 5 s.
fn fetch_backoff_base(consecutive_errors: u32) -> Duration {
    const INITIAL_MS: u64 = 500;
    const CEILING_MS: u64 = 5000;
    // Beyond four doublings the ceiling always wins, so the exponent is clamped.
    const MAX_DOUBLINGS: u32 = 4;

    let doublings = consecutive_errors.saturating_sub(1).min(MAX_DOUBLINGS);
    let uncapped_ms = INITIAL_MS.saturating_mul(2u64.pow(doublings));
    Duration::from_millis(uncapped_ms.min(CEILING_MS))
}
416
/// Shifts `base` by a pseudo-random amount derived from `entropy`.
///
/// The shift lies in `[-r, +r]` milliseconds where `r` is one fifth of `base`
/// in milliseconds (at least 1 ms). The result has millisecond resolution and
/// saturates at zero and at `u64::MAX` milliseconds.
fn with_jitter(base: Duration, entropy: u64) -> Duration {
    let base_ms = match u64::try_from(base.as_millis()) {
        Ok(ms) => ms,
        Err(_) => u64::MAX,
    };
    let half_width = std::cmp::max(base_ms / 5, 1);
    // `shift` is in 0..=2*half_width; subtracting half_width re-centres it on `base`.
    let shift = entropy % (half_width * 2 + 1);
    Duration::from_millis(base_ms.saturating_add(shift).saturating_sub(half_width))
}
426
427fn fetch_backoff(consecutive_errors: u32, entropy: u64) -> Duration {
428 with_jitter(fetch_backoff_base(consecutive_errors), entropy)
429}
430
431fn classify_create_consumer_error(
434 e: &async_nats::jetstream::stream::ConsumerError,
435 consumer_name: &str,
436) -> ConnectorError {
437 use async_nats::jetstream::stream::ConsumerErrorKind;
438 use async_nats::jetstream::ErrorCode;
439
440 let drift_code = match e.kind() {
441 ConsumerErrorKind::JetStream(server_err) => matches!(
442 server_err.error_code(),
443 ErrorCode::CONSUMER_ALREADY_EXISTS | ErrorCode::CONSUMER_NAME_EXIST
444 ),
445 _ => false,
446 };
447 if drift_code {
448 err(&format!(
449 "[LDB-5070] consumer '{consumer_name}' exists with incompatible config; \
450 rotate the durable name or delete the consumer out-of-band. \
451 Server said: {e}"
452 ))
453 } else {
454 err(&format!("create_consumer('{consumer_name}') failed: {e}"))
455 }
456}
457
/// Cheap entropy source for backoff jitter: the low 64 bits of the
/// nanoseconds elapsed since the Unix epoch (0 if the clock reads before it).
// Truncating the u128 nanosecond count is intentional; only the low bits matter here.
#[allow(clippy::cast_possible_truncation)]
fn entropy_now() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_nanos() as u64
}
467
/// Background task state for JetStream mode: repeatedly fetches from a pull
/// consumer and forwards the messages to the source over a bounded channel.
struct JsReader {
    /// Pull consumer that messages are fetched from.
    consumer: jetstream::consumer::Consumer<pull::Config>,
    /// Sending end of the channel drained by `poll_batch`.
    tx: MAsyncTx<mpsc::Array<Incoming>>,
    /// Signalled by the source to stop the task.
    shutdown: Arc<Notify>,
    /// Count of consecutive failed fetches; drives the backoff and is reset
    /// once a fetch forwards at least one message.
    consecutive_errors: Arc<AtomicU32>,
    /// Signalled after a fetch has forwarded at least one message.
    data_ready: Arc<Notify>,
    /// Metrics handle for fetch errors and consumer lag.
    metrics: NatsSourceMetrics,
    /// Maximum number of messages requested per fetch.
    batch_size: usize,
    /// Expiry passed to each fetch request.
    max_wait: Duration,
    /// Minimum time between consumer-lag polls; zero disables lag polling.
    lag_poll_interval: Duration,
}
480
481impl JsReader {
    /// Reader loop for JetStream mode. Runs until shutdown is signalled or the
    /// downstream channel is closed.
    ///
    /// Each iteration issues one fetch of up to `batch_size` messages, forwards
    /// every successfully received message to `tx`, then optionally refreshes
    /// the consumer-lag metric. Fetch failures, and fetches that produced only
    /// errors, trigger a jittered backoff that grows with `consecutive_errors`.
    async fn run(self) {
        let Self {
            mut consumer,
            tx,
            shutdown,
            consecutive_errors,
            data_ready,
            metrics,
            batch_size,
            max_wait,
            lag_poll_interval,
        } = self;

        let mut last_lag_poll = Instant::now();
        // A zero interval turns lag polling off entirely.
        let lag_poll_enabled = !lag_poll_interval.is_zero();

        loop {
            // `biased` makes the branches be polled in order, so a pending
            // shutdown wins over starting another fetch.
            let fetch_result = tokio::select! {
                biased;
                () = shutdown.notified() => break,
                r = consumer.fetch().max_messages(batch_size).expires(max_wait).messages() => r,
            };

            let mut stream = match fetch_result {
                Ok(s) => s,
                Err(e) => {
                    // The fetch request itself failed: count it, back off, retry.
                    let errs = consecutive_errors.fetch_add(1, Ordering::AcqRel) + 1;
                    metrics.record_fetch_error();
                    warn!(
                        error = %e,
                        consecutive_errors = errs,
                        "nats fetch() errored; backing off",
                    );
                    let backoff = fetch_backoff(errs, entropy_now());
                    // The backoff sleep is interruptible by shutdown.
                    tokio::select! {
                        biased;
                        () = shutdown.notified() => break,
                        () = tokio::time::sleep(backoff) => {}
                    }
                    continue;
                }
            };

            // Per-fetch tallies used below to decide between reset and backoff.
            let mut forwarded = 0usize;
            let mut stream_errors = 0usize;
            loop {
                let msg_result = tokio::select! {
                    biased;
                    () = shutdown.notified() => return,
                    r = stream.next() => match r {
                        Some(r) => r,
                        // The fetch's message stream is exhausted.
                        None => break,
                    },
                };
                let msg = match msg_result {
                    Ok(m) => m,
                    Err(e) => {
                        // A single bad item does not end the fetch; keep draining.
                        metrics.record_fetch_error();
                        stream_errors += 1;
                        warn!(error = %e, "nats message error");
                        continue;
                    }
                };
                let incoming = Incoming {
                    subject: msg.subject.to_string(),
                    payload: msg.payload.clone(),
                    // Sequence is best-effort: `None` if the message info cannot be read.
                    stream_seq: msg.info().ok().map(|i| i.stream_sequence),
                    // The message itself travels along so it can be acked later.
                    ack: Some(msg),
                };
                // Waits while the channel is full; an error means the receiver is gone.
                if tx.send(incoming).await.is_err() {
                    debug!("nats reader: downstream channel closed");
                    return;
                }
                forwarded += 1;
            }

            if forwarded > 0 {
                // Only a fetch that delivered something clears the error streak;
                // an empty fetch without errors leaves the counter unchanged.
                consecutive_errors.store(0, Ordering::Release);
                data_ready.notify_one();
            } else if stream_errors > 0 {
                // Nothing delivered and at least one error: treat like a failed fetch.
                let errs = consecutive_errors.fetch_add(1, Ordering::AcqRel) + 1;
                let backoff = fetch_backoff(errs, entropy_now());
                tokio::select! {
                    biased;
                    () = shutdown.notified() => break,
                    () = tokio::time::sleep(backoff) => {}
                }
            }

            if lag_poll_enabled && last_lag_poll.elapsed() >= lag_poll_interval {
                // The timer is reset before the call, so a failed poll is not retried
                // until the next interval has elapsed.
                last_lag_poll = Instant::now();
                match consumer.info().await {
                    Ok(info) => metrics.set_consumer_lag(info.num_pending),
                    Err(e) => warn!(error = %e, "consumer.info() failed; skipping lag update"),
                }
            }
        }
    }
582}
583
/// Background task state for core NATS mode: forwards every message from a
/// subscription to the source over a bounded channel.
struct CoreReader {
    /// Subscription that messages are read from.
    subscriber: async_nats::Subscriber,
    /// Sending end of the channel drained by `poll_batch`.
    tx: MAsyncTx<mpsc::Array<Incoming>>,
    /// Signalled by the source to stop the task.
    shutdown: Arc<Notify>,
    /// Signalled after each forwarded message.
    data_ready: Arc<Notify>,
}
590
591impl CoreReader {
592 async fn run(self) {
593 let Self {
594 mut subscriber,
595 tx,
596 shutdown,
597 data_ready,
598 } = self;
599
600 loop {
601 let msg = tokio::select! {
602 biased;
603 () = shutdown.notified() => break,
604 m = subscriber.next() => match m {
605 Some(m) => m,
606 None => break,
607 },
608 };
609 let incoming = Incoming {
610 subject: msg.subject.to_string(),
611 payload: msg.payload,
612 stream_seq: None,
613 ack: None,
614 };
615 if tx.send(incoming).await.is_err() {
616 return;
617 }
618 data_ready.notify_one();
619 }
620 }
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::Schema;

    /// A checkpoint taken with nothing pending must not enqueue an empty sealed batch.
    #[test]
    fn checkpoint_empty_pending_is_noop() {
        let source = NatsSource::new(Arc::new(Schema::empty()), None);
        let _ = source.checkpoint();
        assert_eq!(source.sealed.lock().len(), 0);
    }

    /// The base backoff doubles from 500 ms and never exceeds 5 s.
    #[test]
    fn backoff_base_grows_then_caps_at_5s() {
        let schedule: [(u32, u64); 6] = [
            (1, 500),
            (2, 1000),
            (3, 2000),
            (4, 4000),
            (5, 5000),
            (100, 5000),
        ];
        for (errors, expected_ms) in schedule {
            assert_eq!(
                fetch_backoff_base(errors),
                Duration::from_millis(expected_ms)
            );
        }
    }

    /// Jitter keeps a 1 s base within 800..=1200 ms for a spread of entropy values.
    #[test]
    fn jitter_stays_within_plus_minus_20_percent() {
        let base = Duration::from_millis(1000);
        let allowed = Duration::from_millis(800)..=Duration::from_millis(1200);
        for entropy in [0u64, 1, 99, 12345, u64::MAX] {
            let j = with_jitter(base, entropy);
            assert!(
                allowed.contains(&j),
                "entropy {entropy}: jittered = {j:?} outside ±20% of {base:?}"
            );
        }
    }
}