// laminar_connectors/nats/source.rs
1//! NATS source: `JetStream` pull consumer with ack-on-commit, or core
2//! subscribe (at-most-once). A background task forwards messages through
3//! an `mpsc` channel; JS message handles are retained until
4//! `notify_epoch_committed` fires, then acked in bulk.
5
6use std::collections::VecDeque;
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use arrow_schema::SchemaRef;
12use async_nats::jetstream::{self, consumer::pull};
13use async_trait::async_trait;
14use bytes::Bytes;
15use crossfire::{mpsc, AsyncRx, MAsyncTx, TryRecvError};
16use futures_util::StreamExt;
17use rustc_hash::FxHashMap;
18use tokio::sync::Notify;
19use tokio::task::JoinHandle;
20use tracing::{debug, warn};
21
22use super::config::{build_connect_options, AckPolicy, DeliverPolicy, Mode, NatsSourceConfig};
23use super::metrics::NatsSourceMetrics;
24use crate::checkpoint::SourceCheckpoint;
25use crate::config::ConnectorConfig;
26use crate::connector::{PartitionInfo, SourceBatch, SourceConnector};
27use crate::error::ConnectorError;
28use crate::health::HealthStatus;
29use crate::metrics::ConnectorMetrics;
30use crate::serde::{self, RecordDeserializer};
31
/// One message forwarded from a background reader task to `poll_batch`.
///
/// `ack` is `Some` only on the `JetStream` path.
struct Incoming {
    /// Subject the message arrived on.
    subject: String,
    /// Raw payload bytes; deserialized later in `poll_batch`.
    payload: Bytes,
    /// JetStream stream sequence; `None` on the core path.
    stream_seq: Option<u64>,
    /// Message handle retained so it can be acked after epoch commit.
    ack: Option<jetstream::Message>,
}
39
/// State that exists only while a background reader task is live
/// (between `open` and `close`).
struct Running {
    /// Turns raw payload bytes into Arrow batches in `poll_batch`.
    deserializer: Box<dyn RecordDeserializer>,
    /// Receiving end of the channel fed by the reader task.
    rx: AsyncRx<mpsc::Array<Incoming>>,
    /// Fired by `close` to ask the reader task to stop.
    shutdown: Arc<Notify>,
    /// `Some` on `JetStream`; `None` on core.
    consecutive_errors: Option<Arc<AtomicU32>>,
    /// Join handle of the spawned reader task; awaited in `close`.
    handle: JoinHandle<()>,
}
48
/// NATS source — core and `JetStream` modes.
pub struct NatsSource {
    /// Arrow schema for deserialization; DDL may override it in `open`.
    schema: SchemaRef,
    /// Parsed connector config; `None` until `open` succeeds.
    config: Option<NatsSourceConfig>,
    /// Notified by reader tasks when new data is available to poll.
    data_ready: Arc<Notify>,
    metrics: NatsSourceMetrics,
    /// `Some` while a reader task is live.
    running: Option<Running>,

    // Interior mutability so `checkpoint()` (takes `&self`) can seal
    // `pending` into `sealed` atomically with offset capture.
    /// Ack handles accumulated since the last checkpoint.
    pending: parking_lot::Mutex<Vec<jetstream::Message>>,
    /// Per-checkpoint ack batches awaiting `notify_epoch_committed`.
    sealed: parking_lot::Mutex<VecDeque<Vec<jetstream::Message>>>,
    /// Highest observed stream sequence per subject (checkpoint offsets).
    offsets: parking_lot::Mutex<FxHashMap<String, u64>>,
}
63
impl NatsSource {
    /// Metrics register on `registry` if provided.
    #[must_use]
    pub fn new(schema: SchemaRef, registry: Option<&prometheus::Registry>) -> Self {
        Self {
            schema,
            config: None,
            data_ready: Arc::new(Notify::new()),
            metrics: NatsSourceMetrics::new(registry),
            running: None,
            pending: parking_lot::Mutex::new(Vec::new()),
            sealed: parking_lot::Mutex::new(VecDeque::new()),
            offsets: parking_lot::Mutex::new(FxHashMap::default()),
        }
    }

    /// Push the combined pending + sealed ack count to the gauge.
    fn update_pending_gauge(&self) {
        // Two short, non-overlapping lock scopes; the counts may skew
        // slightly between them, which is acceptable for a gauge.
        let pending_len = self.pending.lock().len();
        let sealed_total: usize = self.sealed.lock().iter().map(Vec::len).sum();
        self.metrics.set_pending_acks(pending_len + sealed_total);
    }

    /// Available after [`SourceConnector::open`].
    #[must_use]
    pub fn config(&self) -> Option<&NatsSourceConfig> {
        self.config.as_ref()
    }

    /// Connect, bind the durable pull consumer, and spawn a `JsReader`
    /// that forwards messages (with their ack handles) into the channel.
    ///
    /// # Errors
    /// Configuration errors for missing stream/consumer names, a failed
    /// `get_stream`, or a `create_consumer` conflict (classified by
    /// `classify_create_consumer_error`).
    async fn open_jetstream(
        &mut self,
        cfg: &NatsSourceConfig,
        deserializer: Box<dyn RecordDeserializer>,
    ) -> Result<(), ConnectorError> {
        let client = connect(cfg).await?;
        let js = jetstream::new(client);

        // Names are validated upstream; absence here is an internal
        // inconsistency, not a user error.
        let stream_name = cfg
            .stream
            .as_deref()
            .ok_or_else(|| err("stream name missing after validation"))?;
        let consumer_name = cfg
            .consumer
            .as_deref()
            .ok_or_else(|| err("consumer name missing after validation"))?;

        let pull_cfg = build_pull_config(cfg, consumer_name)?;
        let stream = js
            .get_stream(stream_name)
            .await
            .map_err(|e| err(&format!("get_stream('{stream_name}') failed: {e}")))?;
        let consumer = stream
            .create_consumer(pull_cfg)
            .await
            .map_err(|e| classify_create_consumer_error(&e, consumer_name))?;

        // Channel holds up to two fetch batches.
        let (tx, rx) = mpsc::bounded_async::<Incoming>(cfg.fetch_batch * 2);
        let shutdown = Arc::new(Notify::new());
        let consecutive_errors = Arc::new(AtomicU32::new(0));

        let reader = JsReader {
            consumer,
            tx,
            shutdown: Arc::clone(&shutdown),
            consecutive_errors: Arc::clone(&consecutive_errors),
            data_ready: Arc::clone(&self.data_ready),
            metrics: self.metrics.clone(),
            batch_size: cfg.fetch_batch,
            max_wait: cfg.fetch_max_wait,
            lag_poll_interval: cfg.lag_poll_interval,
        };
        let handle = tokio::spawn(reader.run());

        self.running = Some(Running {
            deserializer,
            rx,
            shutdown,
            consecutive_errors: Some(consecutive_errors),
            handle,
        });
        Ok(())
    }

    /// Connect and (queue-)subscribe on core NATS, then spawn a
    /// `CoreReader`. At-most-once: no acks, no sequences.
    async fn open_core(
        &mut self,
        cfg: &NatsSourceConfig,
        deserializer: Box<dyn RecordDeserializer>,
    ) -> Result<(), ConnectorError> {
        let client = connect(cfg).await?;
        let subject = cfg
            .subject
            .clone()
            .ok_or_else(|| err("subject missing after validation"))?;
        // Queue group load-balances delivery across subscribers when set.
        let subscriber = if let Some(group) = cfg.queue_group.as_deref() {
            client
                .queue_subscribe(subject, group.to_string())
                .await
                .map_err(|e| err(&format!("queue_subscribe: {e}")))?
        } else {
            client
                .subscribe(subject)
                .await
                .map_err(|e| err(&format!("subscribe: {e}")))?
        };

        let (tx, rx) = mpsc::bounded_async::<Incoming>(cfg.fetch_batch * 2);
        let shutdown = Arc::new(Notify::new());

        let reader = CoreReader {
            subscriber,
            tx,
            shutdown: Arc::clone(&shutdown),
            data_ready: Arc::clone(&self.data_ready),
        };
        let handle = tokio::spawn(reader.run());

        self.running = Some(Running {
            deserializer,
            rx,
            shutdown,
            // No fetch-error tracking on the core path.
            consecutive_errors: None,
            handle,
        });
        Ok(())
    }
}
189
#[async_trait]
impl SourceConnector for NatsSource {
    /// Parse config, build the deserializer, and spawn the reader task
    /// for the configured mode.
    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
        let cfg = NatsSourceConfig::from_config(config)?;
        // SQL DDL schema overrides the registry placeholder.
        if let Some(schema) = config.arrow_schema() {
            self.schema = schema;
        }
        let deserializer = serde::create_deserializer(cfg.format)
            .map_err(|e| err(&format!("deserializer for format {:?}: {e}", cfg.format)))?;
        match cfg.mode {
            Mode::JetStream => self.open_jetstream(&cfg, deserializer).await?,
            Mode::Core => self.open_core(&cfg, deserializer).await?,
        }
        self.config = Some(cfg);
        Ok(())
    }

    /// Drain up to `max_records` buffered messages and deserialize them
    /// into one Arrow batch. Non-blocking: returns `Ok(None)` when no
    /// reader is running or the channel is currently empty.
    async fn poll_batch(
        &mut self,
        max_records: usize,
    ) -> Result<Option<SourceBatch>, ConnectorError> {
        let Some(running) = self.running.as_mut() else {
            return Ok(None);
        };

        let mut payloads: Vec<Bytes> = Vec::new();
        let mut partition: Option<PartitionInfo> = None;
        let mut new_acks: Vec<jetstream::Message> = Vec::new();
        let mut offset_updates: Vec<(String, u64)> = Vec::new();

        while payloads.len() < max_records {
            // NOTE(review): Disconnected is folded into Empty, so a dead
            // reader task looks like a permanently idle source rather
            // than an error — confirm that is intended.
            let incoming = match running.rx.try_recv() {
                Ok(m) => m,
                Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
            };
            if let Some(seq) = incoming.stream_seq {
                offset_updates.push((incoming.subject.clone(), seq));
                // The first sequenced message labels the whole batch.
                if partition.is_none() {
                    partition = Some(PartitionInfo::new(&incoming.subject, seq.to_string()));
                }
            }
            payloads.push(incoming.payload);
            if let Some(msg) = incoming.ack {
                new_acks.push(msg);
            }
        }

        if payloads.is_empty() {
            return Ok(None);
        }

        let records: Vec<&[u8]> = payloads.iter().map(Bytes::as_ref).collect();
        let bytes_total: u64 = records.iter().map(|r| r.len() as u64).sum();
        // Deserialize before parking acks: on failure the handles drop
        // un-acked and the broker redelivers after ack_wait.
        let batch = running
            .deserializer
            .deserialize_batch(&records, &self.schema)
            .map_err(|e| err(&format!("deserialize batch: {e}")))?;

        // Keep the max stream sequence per subject for checkpoints.
        if !offset_updates.is_empty() {
            let mut offsets = self.offsets.lock();
            for (subject, seq) in offset_updates {
                offsets
                    .entry(subject)
                    .and_modify(|s| *s = (*s).max(seq))
                    .or_insert(seq);
            }
        }
        if !new_acks.is_empty() {
            self.pending.lock().extend(new_acks);
        }

        self.metrics
            .record_poll(batch.num_rows() as u64, bytes_total);
        self.update_pending_gauge();

        Ok(Some(SourceBatch {
            records: batch,
            partition,
        }))
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Capture offsets and seal the current pending-ack batch for the
    /// epoch being checkpointed.
    fn checkpoint(&self) -> SourceCheckpoint {
        // Seal now-pending acks into this epoch; anything arriving after
        // goes into a fresh batch committed with a later epoch. Without
        // this split, an ack could fire before its manifest record lands.
        {
            let mut pending = self.pending.lock();
            if !pending.is_empty() {
                let batch = std::mem::take(&mut *pending);
                self.sealed.lock().push_back(batch);
            }
        }
        let mut cp = SourceCheckpoint::default();
        for (subject, seq) in &*self.offsets.lock() {
            cp.set_offset(subject.as_str(), seq.to_string());
        }
        cp
    }

    async fn restore(&mut self, _checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
        // Durable consumer ack floor on the server is authoritative.
        Ok(())
    }

    /// Stop the reader task, waiting up to 5s for it to exit.
    async fn close(&mut self) -> Result<(), ConnectorError> {
        let Some(running) = self.running.take() else {
            return Ok(());
        };
        running.shutdown.notify_one();
        // NOTE(review): a reader parked in `tx.send(...)` does not
        // select on `shutdown`, so it ignores this notify; it only exits
        // once `running.rx` drops at the end of this function. The 5s
        // timeout bounds the wait either way.
        let _ = tokio::time::timeout(Duration::from_secs(5), running.handle).await;
        Ok(())
    }

    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
        Some(Arc::clone(&self.data_ready))
    }

    fn supports_replay(&self) -> bool {
        // JetStream is replayable (durable consumer); core NATS is not.
        !matches!(self.config.as_ref().map(|c| c.mode), Some(Mode::Core))
    }

    /// Ack every sealed batch after the epoch's commit lands.
    async fn notify_epoch_committed(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
        // Per-msg ack errors don't roll the epoch back; the broker
        // redelivers on ack_wait.
        // NOTE(review): `_epoch` is ignored — ALL sealed batches drain
        // here, including any sealed by a later, not-yet-committed
        // checkpoint. Confirm epochs commit strictly in order and that a
        // later seal cannot be in flight when an earlier commit fires.
        loop {
            // The lock guard is a statement-scoped temporary, so it is
            // not held across the `.await`s below.
            let Some(batch) = self.sealed.lock().pop_front() else {
                break;
            };
            for msg in batch {
                match msg.ack().await {
                    Ok(()) => self.metrics.record_ack(),
                    Err(e) => {
                        self.metrics.record_ack_error();
                        warn!(error = %e, "JetStream ack failed; broker will redeliver");
                    }
                }
            }
        }
        self.update_pending_gauge();
        Ok(())
    }

    /// Unknown before `open`; Unhealthy on a fetch-error streak;
    /// Degraded when pending acks approach `max_ack_pending`.
    fn health_check(&self) -> HealthStatus {
        let Some(cfg) = self.config.as_ref() else {
            return HealthStatus::Unknown;
        };

        // Threshold of zero disables the flip.
        if cfg.fetch_error_threshold > 0 {
            if let Some(errors) = self
                .running
                .as_ref()
                .and_then(|r| r.consecutive_errors.as_ref())
            {
                let errs = errors.load(Ordering::Acquire);
                if errs >= cfg.fetch_error_threshold {
                    return HealthStatus::Unhealthy(format!(
                        "{errs} consecutive fetch errors (threshold {})",
                        cfg.fetch_error_threshold
                    ));
                }
            }
        }

        // Flag at 50% of `max_ack_pending`; -1 means unlimited.
        if cfg.max_ack_pending > 0 {
            #[allow(clippy::cast_sign_loss)]
            let cap = cfg.max_ack_pending as u64;
            #[allow(clippy::cast_sign_loss)]
            let pending = self.metrics.pending_acks.get().max(0) as u64;
            if pending * 2 >= cap {
                return HealthStatus::Degraded(format!(
                    "pending acks {pending}/{cap} — broker may throttle delivery"
                ));
            }
        }
        HealthStatus::Healthy
    }

    fn metrics(&self) -> ConnectorMetrics {
        self.metrics.to_connector_metrics()
    }
}
381
382// ── helpers ──
383
384fn err(msg: &str) -> ConnectorError {
385    ConnectorError::ConfigurationError(msg.to_string())
386}
387
388async fn connect(cfg: &NatsSourceConfig) -> Result<async_nats::Client, ConnectorError> {
389    build_connect_options(&cfg.auth, &cfg.tls)?
390        .connect(&cfg.servers)
391        .await
392        .map_err(|e| err(&format!("nats connect({:?}): {e}", cfg.servers)))
393}
394
395fn build_pull_config(
396    cfg: &NatsSourceConfig,
397    consumer_name: &str,
398) -> Result<pull::Config, ConnectorError> {
399    let filter_subjects = if cfg.subject_filters.is_empty() {
400        cfg.subject.iter().cloned().collect()
401    } else {
402        cfg.subject_filters.clone()
403    };
404
405    Ok(pull::Config {
406        durable_name: Some(consumer_name.to_string()),
407        filter_subjects,
408        deliver_policy: map_deliver_policy(cfg)?,
409        ack_policy: map_ack_policy(cfg.ack_policy),
410        ack_wait: cfg.ack_wait,
411        max_deliver: cfg.max_deliver,
412        max_ack_pending: cfg.max_ack_pending,
413        ..Default::default()
414    })
415}
416
417fn map_deliver_policy(
418    cfg: &NatsSourceConfig,
419) -> Result<async_nats::jetstream::consumer::DeliverPolicy, ConnectorError> {
420    use async_nats::jetstream::consumer::DeliverPolicy as Nats;
421    Ok(match cfg.deliver_policy {
422        DeliverPolicy::All => Nats::All,
423        DeliverPolicy::New => Nats::New,
424        DeliverPolicy::ByStartSequence => Nats::ByStartSequence {
425            start_sequence: cfg.start_sequence.unwrap_or(1),
426        },
427        DeliverPolicy::ByStartTime => {
428            let raw = cfg
429                .start_time
430                .as_deref()
431                .ok_or_else(|| err("deliver.policy=by_start_time requires 'start.time'"))?;
432            let start_time =
433                time::OffsetDateTime::parse(raw, &time::format_description::well_known::Rfc3339)
434                    .map_err(|e| err(&format!("start.time '{raw}' is not valid RFC3339: {e}")))?;
435            Nats::ByStartTime { start_time }
436        }
437    })
438}
439
440fn map_ack_policy(p: AckPolicy) -> async_nats::jetstream::consumer::AckPolicy {
441    use async_nats::jetstream::consumer::AckPolicy as Nats;
442    match p {
443        AckPolicy::Explicit => Nats::Explicit,
444        AckPolicy::None => Nats::None,
445    }
446}
447
/// 500ms, 1s, 2s, 4s, cap 5s.
fn fetch_backoff_base(consecutive_errors: u32) -> Duration {
    // The first failure gets the smallest delay; doubling stops after
    // four steps, and the final value is clamped to 5s.
    let shift = consecutive_errors.saturating_sub(1).min(4);
    let millis = (500u64 << shift).min(5000);
    Duration::from_millis(millis)
}
454
/// `base ± 20%`. Tests pass a fixed `entropy` seed.
fn with_jitter(base: Duration, entropy: u64) -> Duration {
    let millis = u64::try_from(base.as_millis()).unwrap_or(u64::MAX);
    // Half-width of the jitter window; at least 1ms so entropy always
    // has an effect.
    let half = std::cmp::max(millis / 5, 1);
    let offset = entropy % (half * 2 + 1);
    // Shift into [base - half, base + half] without under/overflow.
    let adjusted = millis.saturating_add(offset).saturating_sub(half);
    Duration::from_millis(adjusted)
}
464
465fn fetch_backoff(consecutive_errors: u32, entropy: u64) -> Duration {
466    with_jitter(fetch_backoff_base(consecutive_errors), entropy)
467}
468
469/// Server 10148 / 10013 → consumer exists with a conflicting config;
470/// raise LDB-5070 with an operator fix-up.
471fn classify_create_consumer_error(
472    e: &async_nats::jetstream::stream::ConsumerError,
473    consumer_name: &str,
474) -> ConnectorError {
475    use async_nats::jetstream::stream::ConsumerErrorKind;
476    use async_nats::jetstream::ErrorCode;
477
478    let drift_code = match e.kind() {
479        ConsumerErrorKind::JetStream(server_err) => matches!(
480            server_err.error_code(),
481            ErrorCode::CONSUMER_ALREADY_EXISTS | ErrorCode::CONSUMER_NAME_EXIST
482        ),
483        _ => false,
484    };
485    if drift_code {
486        err(&format!(
487            "[LDB-5070] consumer '{consumer_name}' exists with incompatible config; \
488             rotate the durable name or delete the consumer out-of-band. \
489             Server said: {e}"
490        ))
491    } else {
492        err(&format!("create_consumer('{consumer_name}') failed: {e}"))
493    }
494}
495
/// Wall-clock nanos for `with_jitter`. `Instant::now().elapsed()` is ~0
/// and produces correlated jitter across tasks.
#[allow(clippy::cast_possible_truncation)]
fn entropy_now() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    // A pre-epoch clock yields Err; treat it as zero entropy.
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos() as u64,
        Err(_) => 0,
    }
}
505
/// Background pull-fetch loop for the `JetStream` mode.
struct JsReader {
    consumer: jetstream::consumer::Consumer<pull::Config>,
    /// Sends each fetched message (with its ack handle) to `poll_batch`.
    tx: MAsyncTx<mpsc::Array<Incoming>>,
    /// Fired by `close` to stop the loop.
    shutdown: Arc<Notify>,
    /// Shared with `health_check`; reset to 0 on forward progress.
    consecutive_errors: Arc<AtomicU32>,
    /// Notified after a fetch iteration that forwarded messages.
    data_ready: Arc<Notify>,
    metrics: NatsSourceMetrics,
    /// Max messages requested per fetch.
    batch_size: usize,
    /// Server-side expiry for each fetch request.
    max_wait: Duration,
    /// `Duration::ZERO` disables the poll.
    lag_poll_interval: Duration,
}
518
impl JsReader {
    /// Fetch loop: pull batches from the consumer and forward each
    /// message (with its ack handle) into the channel until shutdown.
    async fn run(self) {
        let Self {
            mut consumer,
            tx,
            shutdown,
            consecutive_errors,
            data_ready,
            metrics,
            batch_size,
            max_wait,
            lag_poll_interval,
        } = self;

        let mut last_lag_poll = Instant::now();
        let lag_poll_enabled = !lag_poll_interval.is_zero();

        loop {
            // `biased` checks shutdown first so a stop request wins over
            // an already-ready fetch.
            let fetch_result = tokio::select! {
                biased;
                () = shutdown.notified() => break,
                r = consumer.fetch().max_messages(batch_size).expires(max_wait).messages() => r,
            };

            let mut stream = match fetch_result {
                Ok(s) => s,
                Err(e) => {
                    // fetch_add returns the previous value; +1 yields the
                    // count including this failure.
                    let errs = consecutive_errors.fetch_add(1, Ordering::AcqRel) + 1;
                    metrics.record_fetch_error();
                    warn!(
                        error = %e,
                        consecutive_errors = errs,
                        "nats fetch() errored; backing off",
                    );
                    let backoff = fetch_backoff(errs, entropy_now());
                    // Backoff sleep remains interruptible by shutdown.
                    tokio::select! {
                        biased;
                        () = shutdown.notified() => break,
                        () = tokio::time::sleep(backoff) => {}
                    }
                    continue;
                }
            };

            let mut forwarded = 0usize;
            let mut stream_errors = 0usize;
            loop {
                let msg_result = tokio::select! {
                    biased;
                    () = shutdown.notified() => return,
                    r = stream.next() => match r {
                        Some(r) => r,
                        None => break,
                    },
                };
                let msg = match msg_result {
                    Ok(m) => m,
                    Err(e) => {
                        // Count per-message errors but keep draining the
                        // rest of the fetch.
                        metrics.record_fetch_error();
                        stream_errors += 1;
                        warn!(error = %e, "nats message error");
                        continue;
                    }
                };
                let incoming = Incoming {
                    subject: msg.subject.to_string(),
                    payload: msg.payload.clone(),
                    stream_seq: msg.info().ok().map(|i| i.stream_sequence),
                    // Whole message retained as the ack handle.
                    ack: Some(msg),
                };
                // NOTE(review): this send does not select on `shutdown`;
                // a full channel blocks here until the receiver drains
                // or drops.
                if tx.send(incoming).await.is_err() {
                    debug!("nats reader: downstream channel closed");
                    return;
                }
                forwarded += 1;
            }

            // Reset on progress; an iteration with only errors counts
            // as one failure; idle iterations don't bump.
            if forwarded > 0 {
                consecutive_errors.store(0, Ordering::Release);
                data_ready.notify_one();
            } else if stream_errors > 0 {
                let errs = consecutive_errors.fetch_add(1, Ordering::AcqRel) + 1;
                let backoff = fetch_backoff(errs, entropy_now());
                tokio::select! {
                    biased;
                    () = shutdown.notified() => break,
                    () = tokio::time::sleep(backoff) => {}
                }
            }

            // Periodic consumer-lag sampling; skipped when disabled.
            if lag_poll_enabled && last_lag_poll.elapsed() >= lag_poll_interval {
                last_lag_poll = Instant::now();
                match consumer.info().await {
                    Ok(info) => metrics.set_consumer_lag(info.num_pending),
                    Err(e) => warn!(error = %e, "consumer.info() failed; skipping lag update"),
                }
            }
        }
    }
}
621
/// Background subscribe loop for the core (at-most-once) mode.
struct CoreReader {
    subscriber: async_nats::Subscriber,
    /// Sends each message (no ack handle, no sequence) to `poll_batch`.
    tx: MAsyncTx<mpsc::Array<Incoming>>,
    /// Fired by `close` to stop the loop.
    shutdown: Arc<Notify>,
    /// Notified after every forwarded message.
    data_ready: Arc<Notify>,
}
628
629impl CoreReader {
630    async fn run(self) {
631        let Self {
632            mut subscriber,
633            tx,
634            shutdown,
635            data_ready,
636        } = self;
637
638        loop {
639            let msg = tokio::select! {
640                biased;
641                () = shutdown.notified() => break,
642                m = subscriber.next() => match m {
643                    Some(m) => m,
644                    None => break,
645                },
646            };
647            let incoming = Incoming {
648                subject: msg.subject.to_string(),
649                payload: msg.payload,
650                stream_seq: None,
651                ack: None,
652            };
653            if tx.send(incoming).await.is_err() {
654                return;
655            }
656            data_ready.notify_one();
657        }
658    }
659}
660
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::Schema;

    #[test]
    fn checkpoint_empty_pending_is_noop() {
        // `jetstream::Message` can't be constructed without a server,
        // so we only cover the empty path here; non-empty sealing is
        // exercised in `tests/nats_integration.rs`.
        let source = NatsSource::new(Arc::new(Schema::empty()), None);
        let _ = source.checkpoint();
        assert!(source.sealed.lock().is_empty());
    }

    #[test]
    fn backoff_base_grows_then_caps_at_5s() {
        let cases = [
            (1u32, 500u64),
            (2, 1000),
            (3, 2000),
            (4, 4000),
            (5, 5000),
            (100, 5000),
        ];
        for (errs, expected_ms) in cases {
            assert_eq!(fetch_backoff_base(errs), Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn jitter_stays_within_plus_minus_20_percent() {
        let base = Duration::from_millis(1000);
        let lo = Duration::from_millis(800);
        let hi = Duration::from_millis(1200);
        for entropy in [0u64, 1, 99, 12345, u64::MAX] {
            let j = with_jitter(base, entropy);
            assert!(
                j >= lo && j <= hi,
                "entropy {entropy}: jittered = {j:?} outside ±20% of {base:?}"
            );
        }
    }
}