// laminar_connectors/nats/sink.rs
//! NATS sink. Core publishes are fire-and-forget; `JetStream` collects
//! `PublishAckFuture`s and drains them in `flush` / `pre_commit`.
//! Exactly-once uses server-side `Nats-Msg-Id` dedup (see `LDB-5056`).

5use std::collections::VecDeque;
6use std::future::IntoFuture;
7use std::time::Duration;
8
9use arrow_array::{cast::AsArray, Array, RecordBatch, StringArray};
10use arrow_schema::SchemaRef;
11use async_nats::jetstream::{self, context::PublishAckFuture};
12use async_nats::{Client, HeaderMap};
13use async_trait::async_trait;
14use futures_util::stream::{FuturesUnordered, StreamExt};
15
16use super::config::{build_connect_options, Mode, NatsSinkConfig, SubjectSpec};
17use super::metrics::NatsSinkMetrics;
18use crate::config::ConnectorConfig;
19use crate::connector::{DeliveryGuarantee, SinkConnector, SinkConnectorCapabilities, WriteResult};
20use crate::error::ConnectorError;
21use crate::health::HealthStatus;
22use crate::metrics::ConnectorMetrics;
23use crate::serde::{self, RecordSerializer};
24
/// NATS sink — core and `JetStream` modes.
pub struct NatsSink {
    // Arrow schema of the batches this sink accepts; returned by `schema()`.
    schema: SchemaRef,
    // Parsed configuration; `None` until `open` succeeds.
    config: Option<NatsSinkConfig>,
    // Row serializer chosen from `config.format`; set in `open`.
    serializer: Option<Box<dyn RecordSerializer>>,
    // Connected backend; `None` before `open` and after `close`.
    runtime: Option<Runtime>,
    metrics: NatsSinkMetrics,
    /// Outstanding `JetStream` publish acks. Drained in `flush` / `pre_commit`.
    pending_acks: VecDeque<PublishAckFuture>,
}
35
// Connected publish backend, selected from `Mode` during `open`.
enum Runtime {
    // Fire-and-forget core NATS publishes.
    Core { client: Client },
    // JetStream publishes whose ack futures are queued on `pending_acks`.
    JetStream { context: jetstream::Context },
}
40
41impl NatsSink {
42    /// Metrics register on `registry` if provided.
43    #[must_use]
44    pub fn new(schema: SchemaRef, registry: Option<&prometheus::Registry>) -> Self {
45        Self {
46            schema,
47            config: None,
48            serializer: None,
49            runtime: None,
50            metrics: NatsSinkMetrics::new(registry),
51            pending_acks: VecDeque::new(),
52        }
53    }
54
55    /// Available after [`SinkConnector::open`].
56    #[must_use]
57    pub fn config(&self) -> Option<&NatsSinkConfig> {
58        self.config.as_ref()
59    }
60}
61
#[async_trait]
impl SinkConnector for NatsSink {
    /// Parses `config`, builds the row serializer, connects to NATS, and
    /// selects the runtime mode. In `JetStream` mode with a configured
    /// stream name, the stream is fetched eagerly so a bad name fails at
    /// open rather than on first publish; under exactly-once delivery the
    /// stream's `duplicate_window` must be at least
    /// `min_duplicate_window`, or open refuses to start (LDB-5056).
    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
        let cfg = NatsSinkConfig::from_config(config)?;
        self.serializer = Some(
            serde::create_serializer(cfg.format)
                .map_err(|e| err(&format!("serializer for format {:?}: {e}", cfg.format)))?,
        );
        let client = build_connect_options(&cfg.auth, &cfg.tls)?
            .connect(&cfg.servers)
            .await
            .map_err(|e| err(&format!("nats connect({:?}): {e}", cfg.servers)))?;
        self.runtime = Some(match cfg.mode {
            Mode::Core => Runtime::Core { client },
            Mode::JetStream => {
                let context = jetstream::new(client);
                if let Some(stream_name) = cfg.stream.as_deref() {
                    // Fail fast on a bad name.
                    let stream = context
                        .get_stream(stream_name)
                        .await
                        .map_err(|e| err(&format!("get_stream('{stream_name}') failed: {e}")))?;
                    if cfg.delivery_guarantee == DeliveryGuarantee::ExactlyOnce {
                        // NOTE(review): `cached_info` presumably reuses the
                        // info fetched by `get_stream` above (no extra round
                        // trip) — confirm against the async-nats docs.
                        let info = stream.cached_info();
                        let actual = info.config.duplicate_window;
                        if actual < cfg.min_duplicate_window {
                            return Err(err(&format!(
                                "[LDB-5056] stream '{stream_name}' has duplicate_window={actual:?}, \
                                 below the configured minimum {:?}. Rollback redelivery could land \
                                 outside the dedup horizon. Reconfigure the stream or lower \
                                 'min.duplicate.window.ms'.",
                                cfg.min_duplicate_window,
                            )));
                        }
                    }
                }
                Runtime::JetStream { context }
            }
        });
        self.config = Some(cfg);
        Ok(())
    }

    /// Publishes one message per serialized row of `batch`.
    ///
    /// The subject is the configured literal or, per row, a non-null
    /// `Utf8` column value. Headers (built only when needed) carry
    /// `Nats-Expected-Stream`, the `Nats-Msg-Id` dedup id (exactly-once
    /// only), and any configured header columns. Core publishes are
    /// fire-and-forget; JetStream ack futures are queued on
    /// `pending_acks`, and the queue is drained inline once it reaches
    /// `max_pending`.
    ///
    /// Errors on a missing/non-Utf8 column, a null subject or dedup id,
    /// serialization failure, a failed publish, or a failed inline drain.
    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
        // Split-borrow: `runtime` &mut, config/serializer &.
        let Self {
            config,
            serializer,
            runtime,
            pending_acks,
            metrics,
            schema: _,
        } = self;
        let cfg = config.as_ref().ok_or_else(|| err("sink: open() first"))?;
        let ser = serializer
            .as_ref()
            .ok_or_else(|| err("sink: open() first"))?;
        let rt = runtime.as_mut().ok_or_else(|| err("sink: open() first"))?;

        // Resolve all referenced columns up front so a bad column name
        // fails before any row is published.
        let subject_col = match &cfg.subject {
            SubjectSpec::Column(name) => Some(resolve_utf8(batch, name)?),
            SubjectSpec::Literal(_) => None,
        };
        let header_cols: Vec<(&str, &StringArray)> = cfg
            .header_columns
            .iter()
            .map(|n| resolve_utf8(batch, n).map(|arr| (n.as_str(), arr)))
            .collect::<Result<_, _>>()?;
        let expected_stream = cfg.expected_stream.as_deref();
        // `Nats-Msg-Id` only under exactly-once — LDB-5056 validates the
        // stream's `duplicate_window` at open, and isn't checked otherwise.
        let dedup_col = if cfg.delivery_guarantee == DeliveryGuarantee::ExactlyOnce {
            cfg.dedup_id_column
                .as_deref()
                .map(|n| resolve_utf8(batch, n).map(|arr| (n, arr)))
                .transpose()?
        } else {
            None
        };

        let records = ser
            .serialize(batch)
            .map_err(|e| err(&format!("serialize batch: {e}")))?;

        let mut bytes_total: u64 = 0;
        let mut rows_written: usize = 0;
        for (row, payload) in records.into_iter().enumerate() {
            let subject: &str = match (&cfg.subject, subject_col) {
                (SubjectSpec::Literal(s), _) => s.as_str(),
                (SubjectSpec::Column(name), Some(arr)) => {
                    non_null(arr, row, "subject.column", name)?
                }
                (SubjectSpec::Column(_), None) => unreachable!("resolved above"),
            };
            let msg_id = dedup_col
                .map(|(name, arr)| non_null(arr, row, "dedup.id.column", name))
                .transpose()?;
            let headers = build_headers(expected_stream, msg_id, &header_cols, row);
            let payload_len = payload.len() as u64;
            let payload = bytes::Bytes::from(payload);

            match rt {
                Runtime::Core { client } => {
                    let result = if let Some(h) = headers {
                        client
                            .publish_with_headers(subject.to_string(), h, payload)
                            .await
                    } else {
                        client.publish(subject.to_string(), payload).await
                    };
                    if let Err(e) = result {
                        metrics.record_publish_error();
                        return Err(err(&format!("core publish: {e}")));
                    }
                }
                Runtime::JetStream { context } => {
                    // Backpressure: cap the in-flight ack queue before
                    // publishing another message.
                    if pending_acks.len() >= cfg.max_pending {
                        drain_acks(pending_acks, metrics, cfg.ack_timeout).await?;
                    }
                    let publish_result = if let Some(h) = headers {
                        context
                            .publish_with_headers(subject.to_string(), h, payload)
                            .await
                    } else {
                        context.publish(subject.to_string(), payload).await
                    };
                    match publish_result {
                        Ok(fut) => pending_acks.push_back(fut),
                        Err(e) => {
                            metrics.record_publish_error();
                            return Err(err(&format!("jetstream publish: {e}")));
                        }
                    }
                }
            }

            // Per-row so partial failures still credit successes.
            metrics.record_published_row(payload_len);
            rows_written += 1;
            bytes_total += payload_len;
        }

        metrics.set_pending_futures(pending_acks.len());
        Ok(WriteResult::new(rows_written, bytes_total))
    }

    /// Schema the sink was constructed with.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Base capabilities, upgraded to exactly-once + two-phase commit
    /// when the opened config asks for `ExactlyOnce`. Before `open`
    /// (no config yet) only the base capabilities are reported.
    fn capabilities(&self) -> SinkConnectorCapabilities {
        // NOTE(review): the meaning of the 5s duration isn't visible here
        // (flush interval?) — confirm against SinkConnectorCapabilities::new.
        let mut caps = SinkConnectorCapabilities::new(Duration::from_secs(5))
            .with_idempotent()
            .with_partitioned();
        if matches!(
            self.config.as_ref().map(|c| c.delivery_guarantee),
            Some(DeliveryGuarantee::ExactlyOnce)
        ) {
            caps = caps.with_exactly_once().with_two_phase_commit();
        }
        caps
    }

    /// Core mode: flush the client's send buffer first. Both modes then
    /// drain any pending JetStream acks (a no-op for core, whose queue
    /// stays empty), bounded by the configured ack timeout (30s fallback
    /// when called before `open`).
    async fn flush(&mut self) -> Result<(), ConnectorError> {
        match self.runtime.as_ref() {
            Some(Runtime::Core { client }) => client
                .flush()
                .await
                .map_err(|e| err(&format!("core flush: {e}")))?,
            Some(Runtime::JetStream { .. }) | None => {}
        }
        let timeout = self
            .config
            .as_ref()
            .map_or(Duration::from_secs(30), |c| c.ack_timeout);
        drain_acks(&mut self.pending_acks, &self.metrics, timeout).await
    }

    /// Two-phase-commit prepare: the epoch is only safe to commit once
    /// every outstanding publish ack has resolved, so drain them all.
    async fn pre_commit(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
        let timeout = self
            .config
            .as_ref()
            .map_or(Duration::from_secs(30), |c| c.ack_timeout);
        drain_acks(&mut self.pending_acks, &self.metrics, timeout).await
    }

    /// Abandons outstanding acks without awaiting them.
    async fn rollback_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
        // Retry safely: LDB-5056 ensures the dedup window covers this gap.
        self.pending_acks.clear();
        self.metrics.set_pending_futures(0);
        self.metrics.record_rollback();
        Ok(())
    }

    /// Best-effort shutdown: flush core-mode buffers, drain remaining
    /// acks (errors deliberately ignored — the connection is going away),
    /// then drop the runtime.
    async fn close(&mut self) -> Result<(), ConnectorError> {
        let timeout = self
            .config
            .as_ref()
            .map_or(Duration::from_secs(5), |c| c.ack_timeout);
        // async-nats buffers core publishes client-side; flush before drop.
        if let Some(Runtime::Core { client }) = self.runtime.as_ref() {
            let _ = client.flush().await;
        }
        let _ = drain_acks(&mut self.pending_acks, &self.metrics, timeout).await;
        self.runtime = None;
        Ok(())
    }

    /// `Unknown` before `open`; `Degraded` once the pending-ack gauge
    /// reaches half of `max_pending` (a full drain could stall
    /// `pre_commit`); otherwise `Healthy`.
    fn health_check(&self) -> HealthStatus {
        match self.config.as_ref() {
            None => HealthStatus::Unknown,
            Some(cfg) => {
                let cap = cfg.max_pending.max(1) as u64;
                #[allow(clippy::cast_sign_loss)]
                let pending = self.metrics.pending_futures.get().max(0) as u64;
                if pending * 2 >= cap {
                    HealthStatus::Degraded(format!(
                        "pending publish acks {pending}/{cap} — ack drain may stall pre_commit"
                    ))
                } else {
                    HealthStatus::Healthy
                }
            }
        }
    }

    /// Snapshot of the sink's metrics in the framework-common shape.
    fn metrics(&self) -> ConnectorMetrics {
        self.metrics.to_connector_metrics()
    }
}
292
293/// Drain `pending` concurrently, bounded by `timeout`. On deadline,
294/// each still-unresolved ack bumps `record_ack_error` once; the publish
295/// may have landed server-side, so exactly-once depends on dedup to
296/// swallow the retry.
297async fn drain_acks(
298    pending: &mut VecDeque<PublishAckFuture>,
299    metrics: &NatsSinkMetrics,
300    timeout: Duration,
301) -> Result<(), ConnectorError> {
302    if pending.is_empty() {
303        return Ok(());
304    }
305    let mut set: FuturesUnordered<_> = pending.drain(..).map(IntoFuture::into_future).collect();
306    let deadline = tokio::time::Instant::now() + timeout;
307    let mut first_err: Option<ConnectorError> = None;
308
309    loop {
310        if set.is_empty() {
311            break;
312        }
313        match tokio::time::timeout_at(deadline, set.next()).await {
314            Ok(Some(Ok(ack))) => {
315                if ack.duplicate {
316                    metrics.record_dedup();
317                }
318            }
319            Ok(Some(Err(e))) => {
320                metrics.record_ack_error();
321                if first_err.is_none() {
322                    first_err = Some(err(&format!("jetstream publish ack: {e}")));
323                }
324            }
325            Ok(None) => break,
326            Err(_) => {
327                let lost = set.len();
328                for _ in 0..lost {
329                    metrics.record_ack_error();
330                }
331                metrics.set_pending_futures(pending.len());
332                return Err(err(&format!(
333                    "jetstream publish ack: timed out with {lost} still in flight"
334                )));
335            }
336        }
337    }
338
339    metrics.set_pending_futures(pending.len());
340    first_err.map_or(Ok(()), Err)
341}
342
343fn resolve_utf8<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray, ConnectorError> {
344    let col = batch
345        .column_by_name(name)
346        .ok_or_else(|| err(&format!("column '{name}' not in batch schema")))?;
347    col.as_string_opt::<i32>()
348        .ok_or_else(|| err(&format!("column '{name}' must be Utf8")))
349}
350
351fn non_null<'a>(
352    arr: &'a StringArray,
353    row: usize,
354    kind: &str,
355    name: &str,
356) -> Result<&'a str, ConnectorError> {
357    if arr.is_null(row) {
358        Err(err(&format!("{kind} '{name}' is null at row {row}")))
359    } else {
360        Ok(arr.value(row))
361    }
362}
363
364fn build_headers(
365    expected_stream: Option<&str>,
366    msg_id: Option<&str>,
367    header_cols: &[(&str, &StringArray)],
368    row: usize,
369) -> Option<HeaderMap> {
370    if header_cols.is_empty() && expected_stream.is_none() && msg_id.is_none() {
371        return None;
372    }
373    let mut h = HeaderMap::new();
374    if let Some(s) = expected_stream {
375        h.insert("Nats-Expected-Stream", s);
376    }
377    if let Some(id) = msg_id {
378        h.insert("Nats-Msg-Id", id);
379    }
380    for (name, arr) in header_cols {
381        if !arr.is_null(row) {
382            h.insert(*name, arr.value(row));
383        }
384    }
385    Some(h)
386}
387
388fn err(msg: &str) -> ConnectorError {
389    ConnectorError::ConfigurationError(msg.to_string())
390}