Skip to main content

laminar_connectors/nats/
source.rs

1//! NATS source: `JetStream` pull consumer with ack-on-commit, or core
2//! subscribe (at-most-once). A background task forwards messages through
3//! an `mpsc` channel; JS message handles are retained until
4//! `notify_epoch_committed` fires, then acked in bulk.
5
6use std::collections::VecDeque;
7use std::sync::atomic::{AtomicU32, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use arrow_schema::SchemaRef;
12use async_nats::jetstream::{self, consumer::pull};
13use async_trait::async_trait;
14use bytes::Bytes;
15use crossfire::{mpsc, AsyncRx, MAsyncTx, TryRecvError};
16use futures_util::StreamExt;
17use rustc_hash::FxHashMap;
18use tokio::sync::Notify;
19use tokio::task::JoinHandle;
20use tracing::{debug, warn};
21
22use super::config::{build_connect_options, AckPolicy, DeliverPolicy, Mode, NatsSourceConfig};
23use super::metrics::NatsSourceMetrics;
24use crate::checkpoint::SourceCheckpoint;
25use crate::config::ConnectorConfig;
26use crate::connector::{PartitionInfo, SourceBatch, SourceConnector};
27use crate::error::ConnectorError;
28use crate::serde::{self, RecordDeserializer};
29
/// One message handed from a reader task to `poll_batch` over the channel.
///
/// `ack` is `Some` only on the `JetStream` path.
struct Incoming {
    subject: String,
    payload: Bytes,
    /// `JetStream` stream sequence; `None` in core mode, or when
    /// `msg.info()` failed for a `JetStream` message.
    stream_seq: Option<u64>,
    ack: Option<jetstream::Message>,
}

/// State that exists between a successful `open` and `close`: the reader
/// task plus everything `poll_batch` and `close` need to talk to it.
struct Running {
    deserializer: Box<dyn RecordDeserializer>,
    /// Receiving end of the bounded channel the reader task feeds.
    rx: AsyncRx<mpsc::Array<Incoming>>,
    /// Notified by `close` to ask the reader task to stop.
    shutdown: Arc<Notify>,
    /// The spawned reader task.
    handle: JoinHandle<()>,
}
44
45/// NATS source — core and `JetStream` modes.
46pub struct NatsSource {
47    schema: SchemaRef,
48    config: Option<NatsSourceConfig>,
49    data_ready: Arc<Notify>,
50    metrics: NatsSourceMetrics,
51    running: Option<Running>,
52
53    // Interior mutability so `checkpoint()` (takes `&self`) can seal
54    // `pending` into `sealed` atomically with offset capture.
55    pending: parking_lot::Mutex<Vec<jetstream::Message>>,
56    sealed: parking_lot::Mutex<VecDeque<Vec<jetstream::Message>>>,
57    offsets: parking_lot::Mutex<FxHashMap<String, u64>>,
58}
59
60impl NatsSource {
61    /// Metrics register on `registry` if provided.
62    #[must_use]
63    pub fn new(schema: SchemaRef, registry: Option<&prometheus::Registry>) -> Self {
64        Self {
65            schema,
66            config: None,
67            data_ready: Arc::new(Notify::new()),
68            metrics: NatsSourceMetrics::new(registry),
69            running: None,
70            pending: parking_lot::Mutex::new(Vec::new()),
71            sealed: parking_lot::Mutex::new(VecDeque::new()),
72            offsets: parking_lot::Mutex::new(FxHashMap::default()),
73        }
74    }
75
76    fn update_pending_gauge(&self) {
77        let pending_len = self.pending.lock().len();
78        let sealed_total: usize = self.sealed.lock().iter().map(Vec::len).sum();
79        self.metrics.set_pending_acks(pending_len + sealed_total);
80    }
81
82    /// Available after [`SourceConnector::open`].
83    #[must_use]
84    pub fn config(&self) -> Option<&NatsSourceConfig> {
85        self.config.as_ref()
86    }
87
88    /// Snapshot accessor for the prometheus-backed metrics struct.
89    #[must_use]
90    pub fn metrics_handle(&self) -> &NatsSourceMetrics {
91        &self.metrics
92    }
93
94    async fn open_jetstream(
95        &mut self,
96        cfg: &NatsSourceConfig,
97        deserializer: Box<dyn RecordDeserializer>,
98    ) -> Result<(), ConnectorError> {
99        let client = connect(cfg).await?;
100        let js = jetstream::new(client);
101
102        let stream_name = cfg
103            .stream
104            .as_deref()
105            .ok_or_else(|| err("stream name missing after validation"))?;
106        let consumer_name = cfg
107            .consumer
108            .as_deref()
109            .ok_or_else(|| err("consumer name missing after validation"))?;
110
111        let pull_cfg = build_pull_config(cfg, consumer_name)?;
112        let stream = js
113            .get_stream(stream_name)
114            .await
115            .map_err(|e| err(&format!("get_stream('{stream_name}') failed: {e}")))?;
116        let consumer = stream
117            .create_consumer(pull_cfg)
118            .await
119            .map_err(|e| classify_create_consumer_error(&e, consumer_name))?;
120
121        let (tx, rx) = mpsc::bounded_async::<Incoming>(cfg.fetch_batch * 2);
122        let shutdown = Arc::new(Notify::new());
123
124        let reader = JsReader {
125            consumer,
126            tx,
127            shutdown: Arc::clone(&shutdown),
128            consecutive_errors: Arc::new(AtomicU32::new(0)),
129            data_ready: Arc::clone(&self.data_ready),
130            metrics: self.metrics.clone(),
131            batch_size: cfg.fetch_batch,
132            max_wait: cfg.fetch_max_wait,
133            lag_poll_interval: cfg.lag_poll_interval,
134        };
135        let handle = tokio::spawn(reader.run());
136
137        self.running = Some(Running {
138            deserializer,
139            rx,
140            shutdown,
141            handle,
142        });
143        Ok(())
144    }
145
146    async fn open_core(
147        &mut self,
148        cfg: &NatsSourceConfig,
149        deserializer: Box<dyn RecordDeserializer>,
150    ) -> Result<(), ConnectorError> {
151        let client = connect(cfg).await?;
152        let subject = cfg
153            .subject
154            .clone()
155            .ok_or_else(|| err("subject missing after validation"))?;
156        let subscriber = if let Some(group) = cfg.queue_group.as_deref() {
157            client
158                .queue_subscribe(subject, group.to_string())
159                .await
160                .map_err(|e| err(&format!("queue_subscribe: {e}")))?
161        } else {
162            client
163                .subscribe(subject)
164                .await
165                .map_err(|e| err(&format!("subscribe: {e}")))?
166        };
167
168        let (tx, rx) = mpsc::bounded_async::<Incoming>(cfg.fetch_batch * 2);
169        let shutdown = Arc::new(Notify::new());
170
171        let reader = CoreReader {
172            subscriber,
173            tx,
174            shutdown: Arc::clone(&shutdown),
175            data_ready: Arc::clone(&self.data_ready),
176        };
177        let handle = tokio::spawn(reader.run());
178
179        self.running = Some(Running {
180            deserializer,
181            rx,
182            shutdown,
183            handle,
184        });
185        Ok(())
186    }
187}
188
189#[async_trait]
190impl SourceConnector for NatsSource {
191    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
192        let cfg = NatsSourceConfig::from_config(config)?;
193        // SQL DDL schema overrides the registry placeholder.
194        if let Some(schema) = config.arrow_schema() {
195            self.schema = schema;
196        }
197        let deserializer = serde::create_deserializer(cfg.format)
198            .map_err(|e| err(&format!("deserializer for format {:?}: {e}", cfg.format)))?;
199        match cfg.mode {
200            Mode::JetStream => self.open_jetstream(&cfg, deserializer).await?,
201            Mode::Core => self.open_core(&cfg, deserializer).await?,
202        }
203        self.config = Some(cfg);
204        Ok(())
205    }
206
207    async fn poll_batch(
208        &mut self,
209        max_records: usize,
210    ) -> Result<Option<SourceBatch>, ConnectorError> {
211        let Some(running) = self.running.as_mut() else {
212            return Ok(None);
213        };
214
215        let mut payloads: Vec<Bytes> = Vec::new();
216        let mut partition: Option<PartitionInfo> = None;
217        let mut new_acks: Vec<jetstream::Message> = Vec::new();
218        let mut offset_updates: Vec<(String, u64)> = Vec::new();
219
220        while payloads.len() < max_records {
221            let incoming = match running.rx.try_recv() {
222                Ok(m) => m,
223                Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
224            };
225            if let Some(seq) = incoming.stream_seq {
226                offset_updates.push((incoming.subject.clone(), seq));
227                if partition.is_none() {
228                    partition = Some(PartitionInfo::new(&incoming.subject, seq.to_string()));
229                }
230            }
231            payloads.push(incoming.payload);
232            if let Some(msg) = incoming.ack {
233                new_acks.push(msg);
234            }
235        }
236
237        if payloads.is_empty() {
238            return Ok(None);
239        }
240
241        let records: Vec<&[u8]> = payloads.iter().map(Bytes::as_ref).collect();
242        let bytes_total: u64 = records.iter().map(|r| r.len() as u64).sum();
243        // Deserialize before parking acks: on failure the handles drop
244        // un-acked and the broker redelivers after ack_wait.
245        let batch = running
246            .deserializer
247            .deserialize_batch(&records, &self.schema)
248            .map_err(|e| err(&format!("deserialize batch: {e}")))?;
249
250        if !offset_updates.is_empty() {
251            let mut offsets = self.offsets.lock();
252            for (subject, seq) in offset_updates {
253                offsets
254                    .entry(subject)
255                    .and_modify(|s| *s = (*s).max(seq))
256                    .or_insert(seq);
257            }
258        }
259        if !new_acks.is_empty() {
260            self.pending.lock().extend(new_acks);
261        }
262
263        self.metrics
264            .record_poll(batch.num_rows() as u64, bytes_total);
265        self.update_pending_gauge();
266
267        Ok(Some(SourceBatch {
268            records: batch,
269            partition,
270        }))
271    }
272
273    fn schema(&self) -> SchemaRef {
274        self.schema.clone()
275    }
276
277    fn checkpoint(&self) -> SourceCheckpoint {
278        // Seal now-pending acks into this epoch; anything arriving after
279        // goes into a fresh batch committed with a later epoch. Without
280        // this split, an ack could fire before its manifest record lands.
281        {
282            let mut pending = self.pending.lock();
283            if !pending.is_empty() {
284                let batch = std::mem::take(&mut *pending);
285                self.sealed.lock().push_back(batch);
286            }
287        }
288        let mut cp = SourceCheckpoint::default();
289        for (subject, seq) in &*self.offsets.lock() {
290            cp.set_offset(subject.as_str(), seq.to_string());
291        }
292        cp
293    }
294
295    async fn restore(&mut self, _checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
296        // Durable consumer ack floor on the server is authoritative.
297        Ok(())
298    }
299
300    async fn close(&mut self) -> Result<(), ConnectorError> {
301        let Some(running) = self.running.take() else {
302            return Ok(());
303        };
304        running.shutdown.notify_one();
305        let _ = tokio::time::timeout(Duration::from_secs(5), running.handle).await;
306        Ok(())
307    }
308
309    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
310        Some(Arc::clone(&self.data_ready))
311    }
312
313    fn supports_replay(&self) -> bool {
314        // JetStream is replayable (durable consumer); core NATS is not.
315        !matches!(self.config.as_ref().map(|c| c.mode), Some(Mode::Core))
316    }
317
318    async fn notify_epoch_committed(
319        &mut self,
320        _epoch: u64,
321        _checkpoint: &crate::checkpoint::SourceCheckpoint,
322    ) -> Result<(), ConnectorError> {
323        // Per-msg ack errors don't roll the epoch back; the broker
324        // redelivers on ack_wait.
325        loop {
326            let Some(batch) = self.sealed.lock().pop_front() else {
327                break;
328            };
329            for msg in batch {
330                match msg.ack().await {
331                    Ok(()) => self.metrics.record_ack(),
332                    Err(e) => {
333                        self.metrics.record_ack_error();
334                        warn!(error = %e, "JetStream ack failed; broker will redeliver");
335                    }
336                }
337            }
338        }
339        self.update_pending_gauge();
340        Ok(())
341    }
342}
343
344// ── helpers ──
345
346fn err(msg: &str) -> ConnectorError {
347    ConnectorError::ConfigurationError(msg.to_string())
348}
349
350async fn connect(cfg: &NatsSourceConfig) -> Result<async_nats::Client, ConnectorError> {
351    build_connect_options(&cfg.auth, &cfg.tls)?
352        .connect(&cfg.servers)
353        .await
354        .map_err(|e| err(&format!("nats connect({:?}): {e}", cfg.servers)))
355}
356
357fn build_pull_config(
358    cfg: &NatsSourceConfig,
359    consumer_name: &str,
360) -> Result<pull::Config, ConnectorError> {
361    let filter_subjects = if cfg.subject_filters.is_empty() {
362        cfg.subject.iter().cloned().collect()
363    } else {
364        cfg.subject_filters.clone()
365    };
366
367    Ok(pull::Config {
368        durable_name: Some(consumer_name.to_string()),
369        filter_subjects,
370        deliver_policy: map_deliver_policy(cfg)?,
371        ack_policy: map_ack_policy(cfg.ack_policy),
372        ack_wait: cfg.ack_wait,
373        max_deliver: cfg.max_deliver,
374        max_ack_pending: cfg.max_ack_pending,
375        ..Default::default()
376    })
377}
378
379fn map_deliver_policy(
380    cfg: &NatsSourceConfig,
381) -> Result<async_nats::jetstream::consumer::DeliverPolicy, ConnectorError> {
382    use async_nats::jetstream::consumer::DeliverPolicy as Nats;
383    Ok(match cfg.deliver_policy {
384        DeliverPolicy::All => Nats::All,
385        DeliverPolicy::New => Nats::New,
386        DeliverPolicy::ByStartSequence => Nats::ByStartSequence {
387            start_sequence: cfg.start_sequence.unwrap_or(1),
388        },
389        DeliverPolicy::ByStartTime => {
390            let raw = cfg
391                .start_time
392                .as_deref()
393                .ok_or_else(|| err("deliver.policy=by_start_time requires 'start.time'"))?;
394            let start_time =
395                time::OffsetDateTime::parse(raw, &time::format_description::well_known::Rfc3339)
396                    .map_err(|e| err(&format!("start.time '{raw}' is not valid RFC3339: {e}")))?;
397            Nats::ByStartTime { start_time }
398        }
399    })
400}
401
402fn map_ack_policy(p: AckPolicy) -> async_nats::jetstream::consumer::AckPolicy {
403    use async_nats::jetstream::consumer::AckPolicy as Nats;
404    match p {
405        AckPolicy::Explicit => Nats::Explicit,
406        AckPolicy::None => Nats::None,
407    }
408}
409
/// Un-jittered backoff for the `consecutive_errors`-th failure:
/// 500ms, 1s, 2s, 4s, then capped at 5s. Zero errors maps to the first step.
fn fetch_backoff_base(consecutive_errors: u32) -> Duration {
    const STEP_MS: u64 = 500;
    const CAP_MS: u64 = 5_000;
    let doublings = consecutive_errors.saturating_sub(1).min(4);
    let ms = STEP_MS.saturating_mul(2u64.pow(doublings));
    Duration::from_millis(ms.min(CAP_MS))
}

/// `base ± 20%`, picked from `entropy`. Tests pass a fixed `entropy` seed.
///
/// The spread is at least 1ms, so a zero `base` can come back as 1ms.
fn with_jitter(base: Duration, entropy: u64) -> Duration {
    let base_ms = u64::try_from(base.as_millis()).unwrap_or(u64::MAX);
    let spread = std::cmp::max(base_ms / 5, 1);
    // `2 * spread + 1` slots cover [base - spread, base + spread] inclusive.
    let slot = entropy % (2 * spread + 1);
    Duration::from_millis(base_ms.saturating_add(slot).saturating_sub(spread))
}

/// Jittered backoff used by the reader loop after fetch failures.
fn fetch_backoff(consecutive_errors: u32, entropy: u64) -> Duration {
    let base = fetch_backoff_base(consecutive_errors);
    with_jitter(base, entropy)
}
430
431/// Server 10148 / 10013 → consumer exists with a conflicting config;
432/// raise LDB-5070 with an operator fix-up.
433fn classify_create_consumer_error(
434    e: &async_nats::jetstream::stream::ConsumerError,
435    consumer_name: &str,
436) -> ConnectorError {
437    use async_nats::jetstream::stream::ConsumerErrorKind;
438    use async_nats::jetstream::ErrorCode;
439
440    let drift_code = match e.kind() {
441        ConsumerErrorKind::JetStream(server_err) => matches!(
442            server_err.error_code(),
443            ErrorCode::CONSUMER_ALREADY_EXISTS | ErrorCode::CONSUMER_NAME_EXIST
444        ),
445        _ => false,
446    };
447    if drift_code {
448        err(&format!(
449            "[LDB-5070] consumer '{consumer_name}' exists with incompatible config; \
450             rotate the durable name or delete the consumer out-of-band. \
451             Server said: {e}"
452        ))
453    } else {
454        err(&format!("create_consumer('{consumer_name}') failed: {e}"))
455    }
456}
457
/// Wall-clock nanoseconds since the Unix epoch, used as the `with_jitter` seed.
///
/// `Instant::now().elapsed()` is ~0 and would produce correlated jitter
/// across tasks. A clock before the epoch yields 0.
// Truncating to u64 is deliberate: only the fast-changing low bits matter.
#[allow(clippy::cast_possible_truncation)]
fn entropy_now() -> u64 {
    let since_epoch = std::time::UNIX_EPOCH.elapsed().unwrap_or_default();
    since_epoch.as_nanos() as u64
}
467
468struct JsReader {
469    consumer: jetstream::consumer::Consumer<pull::Config>,
470    tx: MAsyncTx<mpsc::Array<Incoming>>,
471    shutdown: Arc<Notify>,
472    consecutive_errors: Arc<AtomicU32>,
473    data_ready: Arc<Notify>,
474    metrics: NatsSourceMetrics,
475    batch_size: usize,
476    max_wait: Duration,
477    /// `Duration::ZERO` disables the poll.
478    lag_poll_interval: Duration,
479}
480
481impl JsReader {
482    async fn run(self) {
483        let Self {
484            mut consumer,
485            tx,
486            shutdown,
487            consecutive_errors,
488            data_ready,
489            metrics,
490            batch_size,
491            max_wait,
492            lag_poll_interval,
493        } = self;
494
495        let mut last_lag_poll = Instant::now();
496        let lag_poll_enabled = !lag_poll_interval.is_zero();
497
498        loop {
499            let fetch_result = tokio::select! {
500                biased;
501                () = shutdown.notified() => break,
502                r = consumer.fetch().max_messages(batch_size).expires(max_wait).messages() => r,
503            };
504
505            let mut stream = match fetch_result {
506                Ok(s) => s,
507                Err(e) => {
508                    let errs = consecutive_errors.fetch_add(1, Ordering::AcqRel) + 1;
509                    metrics.record_fetch_error();
510                    warn!(
511                        error = %e,
512                        consecutive_errors = errs,
513                        "nats fetch() errored; backing off",
514                    );
515                    let backoff = fetch_backoff(errs, entropy_now());
516                    tokio::select! {
517                        biased;
518                        () = shutdown.notified() => break,
519                        () = tokio::time::sleep(backoff) => {}
520                    }
521                    continue;
522                }
523            };
524
525            let mut forwarded = 0usize;
526            let mut stream_errors = 0usize;
527            loop {
528                let msg_result = tokio::select! {
529                    biased;
530                    () = shutdown.notified() => return,
531                    r = stream.next() => match r {
532                        Some(r) => r,
533                        None => break,
534                    },
535                };
536                let msg = match msg_result {
537                    Ok(m) => m,
538                    Err(e) => {
539                        metrics.record_fetch_error();
540                        stream_errors += 1;
541                        warn!(error = %e, "nats message error");
542                        continue;
543                    }
544                };
545                let incoming = Incoming {
546                    subject: msg.subject.to_string(),
547                    payload: msg.payload.clone(),
548                    stream_seq: msg.info().ok().map(|i| i.stream_sequence),
549                    ack: Some(msg),
550                };
551                if tx.send(incoming).await.is_err() {
552                    debug!("nats reader: downstream channel closed");
553                    return;
554                }
555                forwarded += 1;
556            }
557
558            // Reset on progress; an iteration with only errors counts
559            // as one failure; idle iterations don't bump.
560            if forwarded > 0 {
561                consecutive_errors.store(0, Ordering::Release);
562                data_ready.notify_one();
563            } else if stream_errors > 0 {
564                let errs = consecutive_errors.fetch_add(1, Ordering::AcqRel) + 1;
565                let backoff = fetch_backoff(errs, entropy_now());
566                tokio::select! {
567                    biased;
568                    () = shutdown.notified() => break,
569                    () = tokio::time::sleep(backoff) => {}
570                }
571            }
572
573            if lag_poll_enabled && last_lag_poll.elapsed() >= lag_poll_interval {
574                last_lag_poll = Instant::now();
575                match consumer.info().await {
576                    Ok(info) => metrics.set_consumer_lag(info.num_pending),
577                    Err(e) => warn!(error = %e, "consumer.info() failed; skipping lag update"),
578                }
579            }
580        }
581    }
582}
583
584struct CoreReader {
585    subscriber: async_nats::Subscriber,
586    tx: MAsyncTx<mpsc::Array<Incoming>>,
587    shutdown: Arc<Notify>,
588    data_ready: Arc<Notify>,
589}
590
591impl CoreReader {
592    async fn run(self) {
593        let Self {
594            mut subscriber,
595            tx,
596            shutdown,
597            data_ready,
598        } = self;
599
600        loop {
601            let msg = tokio::select! {
602                biased;
603                () = shutdown.notified() => break,
604                m = subscriber.next() => match m {
605                    Some(m) => m,
606                    None => break,
607                },
608            };
609            let incoming = Incoming {
610                subject: msg.subject.to_string(),
611                payload: msg.payload,
612                stream_seq: None,
613                ack: None,
614            };
615            if tx.send(incoming).await.is_err() {
616                return;
617            }
618            data_ready.notify_one();
619        }
620    }
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::Schema;

    #[test]
    fn checkpoint_empty_pending_is_noop() {
        // `jetstream::Message` can't be constructed without a server,
        // so we only cover the empty path here; non-empty sealing is
        // exercised in `tests/nats_integration.rs`.
        let src = NatsSource::new(Arc::new(Schema::empty()), None);
        let _ = src.checkpoint();
        assert!(src.sealed.lock().is_empty());
    }

    #[test]
    fn backoff_base_grows_then_caps_at_5s() {
        assert_eq!(fetch_backoff_base(1), Duration::from_millis(500));
        assert_eq!(fetch_backoff_base(2), Duration::from_millis(1000));
        assert_eq!(fetch_backoff_base(3), Duration::from_millis(2000));
        assert_eq!(fetch_backoff_base(4), Duration::from_millis(4000));
        assert_eq!(fetch_backoff_base(5), Duration::from_millis(5000));
        assert_eq!(fetch_backoff_base(100), Duration::from_millis(5000));
    }

    #[test]
    fn backoff_base_zero_and_max_errors() {
        // Zero is clamped to the first step; u32::MAX must not overflow.
        assert_eq!(fetch_backoff_base(0), Duration::from_millis(500));
        assert_eq!(fetch_backoff_base(u32::MAX), Duration::from_millis(5000));
    }

    #[test]
    fn jitter_stays_within_plus_minus_20_percent() {
        let base = Duration::from_millis(1000);
        for entropy in [0u64, 1, 99, 12345, u64::MAX] {
            let j = with_jitter(base, entropy);
            assert!(
                j >= Duration::from_millis(800) && j <= Duration::from_millis(1200),
                "entropy {entropy}: jittered = {j:?} outside ±20% of {base:?}"
            );
        }
    }

    #[test]
    fn jitter_reaches_both_window_edges() {
        // 1000ms → spread 200ms → 401 slots covering [800, 1200] inclusive.
        let base = Duration::from_millis(1000);
        assert_eq!(with_jitter(base, 0), Duration::from_millis(800));
        assert_eq!(with_jitter(base, 400), Duration::from_millis(1200));
        assert_eq!(with_jitter(base, 401), Duration::from_millis(800));
    }

    #[test]
    fn jitter_of_zero_base_does_not_underflow() {
        for entropy in 0u64..3 {
            assert!(with_jitter(Duration::ZERO, entropy) <= Duration::from_millis(1));
        }
    }

    #[test]
    fn fetch_backoff_applies_jitter_to_base() {
        assert_eq!(fetch_backoff(1, 0), Duration::from_millis(400));
        assert_eq!(fetch_backoff(2, 200), Duration::from_millis(1000));
    }
}