Skip to main content

laminar_connectors/nats/
sink.rs

1//! NATS sink. Core publishes are fire-and-forget; `JetStream` collects
2//! `PublishAckFuture`s and drains them in `flush` / `pre_commit`.
3//! Exactly-once uses server-side `Nats-Msg-Id` dedup (see `LDB-5056`).
4
5use std::collections::VecDeque;
6use std::future::IntoFuture;
7use std::time::Duration;
8
9use arrow_array::{cast::AsArray, Array, RecordBatch, StringArray};
10use arrow_schema::SchemaRef;
11use async_nats::jetstream::{self, context::PublishAckFuture};
12use async_nats::{Client, HeaderMap};
13use async_trait::async_trait;
14use futures_util::stream::{FuturesUnordered, StreamExt};
15
16use super::config::{build_connect_options, Mode, NatsSinkConfig, SubjectSpec};
17use super::metrics::NatsSinkMetrics;
18use crate::config::ConnectorConfig;
19use crate::connector::{DeliveryGuarantee, SinkConnector, SinkConnectorCapabilities, WriteResult};
20use crate::error::ConnectorError;
21use crate::serde::{self, RecordSerializer};
22
23/// NATS sink — core and `JetStream` modes.
24pub struct NatsSink {
25    schema: SchemaRef,
26    config: Option<NatsSinkConfig>,
27    serializer: Option<Box<dyn RecordSerializer>>,
28    runtime: Option<Runtime>,
29    metrics: NatsSinkMetrics,
30    /// Drained in `flush` / `pre_commit`.
31    pending_acks: VecDeque<PublishAckFuture>,
32}
33
34enum Runtime {
35    Core { client: Client },
36    JetStream { context: jetstream::Context },
37}
38
39impl NatsSink {
40    /// Metrics register on `registry` if provided.
41    #[must_use]
42    pub fn new(schema: SchemaRef, registry: Option<&prometheus::Registry>) -> Self {
43        Self {
44            schema,
45            config: None,
46            serializer: None,
47            runtime: None,
48            metrics: NatsSinkMetrics::new(registry),
49            pending_acks: VecDeque::new(),
50        }
51    }
52
53    /// Available after [`SinkConnector::open`].
54    #[must_use]
55    pub fn config(&self) -> Option<&NatsSinkConfig> {
56        self.config.as_ref()
57    }
58
59    /// Snapshot accessor for the prometheus-backed metrics struct.
60    #[must_use]
61    pub fn metrics_handle(&self) -> &NatsSinkMetrics {
62        &self.metrics
63    }
64}
65
66#[async_trait]
67impl SinkConnector for NatsSink {
68    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
69        let cfg = NatsSinkConfig::from_config(config)?;
70        self.serializer = Some(
71            serde::create_serializer(cfg.format)
72                .map_err(|e| err(&format!("serializer for format {:?}: {e}", cfg.format)))?,
73        );
74        let client = build_connect_options(&cfg.auth, &cfg.tls)?
75            .connect(&cfg.servers)
76            .await
77            .map_err(|e| err(&format!("nats connect({:?}): {e}", cfg.servers)))?;
78        self.runtime = Some(match cfg.mode {
79            Mode::Core => Runtime::Core { client },
80            Mode::JetStream => {
81                let context = jetstream::new(client);
82                if let Some(stream_name) = cfg.stream.as_deref() {
83                    // Fail fast on a bad name.
84                    let stream = context
85                        .get_stream(stream_name)
86                        .await
87                        .map_err(|e| err(&format!("get_stream('{stream_name}') failed: {e}")))?;
88                    if cfg.delivery_guarantee == DeliveryGuarantee::ExactlyOnce {
89                        let info = stream.cached_info();
90                        let actual = info.config.duplicate_window;
91                        if actual < cfg.min_duplicate_window {
92                            return Err(err(&format!(
93                                "[LDB-5056] stream '{stream_name}' has duplicate_window={actual:?}, \
94                                 below the configured minimum {:?}. Rollback redelivery could land \
95                                 outside the dedup horizon. Reconfigure the stream or lower \
96                                 'min.duplicate.window.ms'.",
97                                cfg.min_duplicate_window,
98                            )));
99                        }
100                    }
101                }
102                Runtime::JetStream { context }
103            }
104        });
105        self.config = Some(cfg);
106        Ok(())
107    }
108
109    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
110        // Split-borrow: `runtime` &mut, config/serializer &.
111        let Self {
112            config,
113            serializer,
114            runtime,
115            pending_acks,
116            metrics,
117            schema: _,
118        } = self;
119        let cfg = config.as_ref().ok_or_else(|| err("sink: open() first"))?;
120        let ser = serializer
121            .as_ref()
122            .ok_or_else(|| err("sink: open() first"))?;
123        let rt = runtime.as_mut().ok_or_else(|| err("sink: open() first"))?;
124
125        let subject_col = match &cfg.subject {
126            SubjectSpec::Column(name) => Some(resolve_utf8(batch, name)?),
127            SubjectSpec::Literal(_) => None,
128        };
129        let header_cols: Vec<(&str, &StringArray)> = cfg
130            .header_columns
131            .iter()
132            .map(|n| resolve_utf8(batch, n).map(|arr| (n.as_str(), arr)))
133            .collect::<Result<_, _>>()?;
134        let expected_stream = cfg.expected_stream.as_deref();
135        // `Nats-Msg-Id` only under exactly-once — LDB-5056 validates the
136        // stream's `duplicate_window` at open, and isn't checked otherwise.
137        let dedup_col = if cfg.delivery_guarantee == DeliveryGuarantee::ExactlyOnce {
138            cfg.dedup_id_column
139                .as_deref()
140                .map(|n| resolve_utf8(batch, n).map(|arr| (n, arr)))
141                .transpose()?
142        } else {
143            None
144        };
145
146        let records = ser
147            .serialize(batch)
148            .map_err(|e| err(&format!("serialize batch: {e}")))?;
149
150        let mut bytes_total: u64 = 0;
151        let mut rows_written: usize = 0;
152        for (row, payload) in records.into_iter().enumerate() {
153            let subject: &str = match (&cfg.subject, subject_col) {
154                (SubjectSpec::Literal(s), _) => s.as_str(),
155                (SubjectSpec::Column(name), Some(arr)) => {
156                    non_null(arr, row, "subject.column", name)?
157                }
158                (SubjectSpec::Column(_), None) => unreachable!("resolved above"),
159            };
160            let msg_id = dedup_col
161                .map(|(name, arr)| non_null(arr, row, "dedup.id.column", name))
162                .transpose()?;
163            let headers = build_headers(expected_stream, msg_id, &header_cols, row);
164            let payload_len = payload.len() as u64;
165            let payload = bytes::Bytes::from(payload);
166
167            match rt {
168                Runtime::Core { client } => {
169                    let result = if let Some(h) = headers {
170                        client
171                            .publish_with_headers(subject.to_string(), h, payload)
172                            .await
173                    } else {
174                        client.publish(subject.to_string(), payload).await
175                    };
176                    if let Err(e) = result {
177                        metrics.record_publish_error();
178                        return Err(err(&format!("core publish: {e}")));
179                    }
180                }
181                Runtime::JetStream { context } => {
182                    if pending_acks.len() >= cfg.max_pending {
183                        drain_acks(pending_acks, metrics, cfg.ack_timeout).await?;
184                    }
185                    let publish_result = if let Some(h) = headers {
186                        context
187                            .publish_with_headers(subject.to_string(), h, payload)
188                            .await
189                    } else {
190                        context.publish(subject.to_string(), payload).await
191                    };
192                    match publish_result {
193                        Ok(fut) => pending_acks.push_back(fut),
194                        Err(e) => {
195                            metrics.record_publish_error();
196                            return Err(err(&format!("jetstream publish: {e}")));
197                        }
198                    }
199                }
200            }
201
202            // Per-row so partial failures still credit successes.
203            metrics.record_published_row(payload_len);
204            rows_written += 1;
205            bytes_total += payload_len;
206        }
207
208        metrics.set_pending_futures(pending_acks.len());
209        Ok(WriteResult::new(rows_written, bytes_total))
210    }
211
212    fn schema(&self) -> SchemaRef {
213        self.schema.clone()
214    }
215
216    fn capabilities(&self) -> SinkConnectorCapabilities {
217        let mut caps = SinkConnectorCapabilities::new(Duration::from_secs(5))
218            .with_idempotent()
219            .with_partitioned();
220        if matches!(
221            self.config.as_ref().map(|c| c.delivery_guarantee),
222            Some(DeliveryGuarantee::ExactlyOnce)
223        ) {
224            caps = caps.with_exactly_once().with_two_phase_commit();
225        }
226        caps
227    }
228
229    async fn flush(&mut self) -> Result<(), ConnectorError> {
230        match self.runtime.as_ref() {
231            Some(Runtime::Core { client }) => client
232                .flush()
233                .await
234                .map_err(|e| err(&format!("core flush: {e}")))?,
235            Some(Runtime::JetStream { .. }) | None => {}
236        }
237        let timeout = self
238            .config
239            .as_ref()
240            .map_or(Duration::from_secs(30), |c| c.ack_timeout);
241        drain_acks(&mut self.pending_acks, &self.metrics, timeout).await
242    }
243
244    async fn pre_commit(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
245        let timeout = self
246            .config
247            .as_ref()
248            .map_or(Duration::from_secs(30), |c| c.ack_timeout);
249        drain_acks(&mut self.pending_acks, &self.metrics, timeout).await
250    }
251
252    async fn rollback_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
253        // Retry safely: LDB-5056 ensures the dedup window covers this gap.
254        self.pending_acks.clear();
255        self.metrics.set_pending_futures(0);
256        self.metrics.record_rollback();
257        Ok(())
258    }
259
260    async fn close(&mut self) -> Result<(), ConnectorError> {
261        let timeout = self
262            .config
263            .as_ref()
264            .map_or(Duration::from_secs(5), |c| c.ack_timeout);
265        // async-nats buffers core publishes client-side; flush before drop.
266        if let Some(Runtime::Core { client }) = self.runtime.as_ref() {
267            let _ = client.flush().await;
268        }
269        let _ = drain_acks(&mut self.pending_acks, &self.metrics, timeout).await;
270        self.runtime = None;
271        Ok(())
272    }
273}
274
275/// Drain `pending` concurrently, bounded by `timeout`. On deadline,
276/// each still-unresolved ack bumps `record_ack_error` once; the publish
277/// may have landed server-side, so exactly-once depends on dedup to
278/// swallow the retry.
279async fn drain_acks(
280    pending: &mut VecDeque<PublishAckFuture>,
281    metrics: &NatsSinkMetrics,
282    timeout: Duration,
283) -> Result<(), ConnectorError> {
284    if pending.is_empty() {
285        return Ok(());
286    }
287    let mut set: FuturesUnordered<_> = pending.drain(..).map(IntoFuture::into_future).collect();
288    let deadline = tokio::time::Instant::now() + timeout;
289    let mut first_err: Option<ConnectorError> = None;
290
291    loop {
292        if set.is_empty() {
293            break;
294        }
295        match tokio::time::timeout_at(deadline, set.next()).await {
296            Ok(Some(Ok(ack))) => {
297                if ack.duplicate {
298                    metrics.record_dedup();
299                }
300            }
301            Ok(Some(Err(e))) => {
302                metrics.record_ack_error();
303                if first_err.is_none() {
304                    first_err = Some(err(&format!("jetstream publish ack: {e}")));
305                }
306            }
307            Ok(None) => break,
308            Err(_) => {
309                let lost = set.len();
310                for _ in 0..lost {
311                    metrics.record_ack_error();
312                }
313                metrics.set_pending_futures(pending.len());
314                return Err(err(&format!(
315                    "jetstream publish ack: timed out with {lost} still in flight"
316                )));
317            }
318        }
319    }
320
321    metrics.set_pending_futures(pending.len());
322    first_err.map_or(Ok(()), Err)
323}
324
325fn resolve_utf8<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray, ConnectorError> {
326    let col = batch
327        .column_by_name(name)
328        .ok_or_else(|| err(&format!("column '{name}' not in batch schema")))?;
329    col.as_string_opt::<i32>()
330        .ok_or_else(|| err(&format!("column '{name}' must be Utf8")))
331}
332
333fn non_null<'a>(
334    arr: &'a StringArray,
335    row: usize,
336    kind: &str,
337    name: &str,
338) -> Result<&'a str, ConnectorError> {
339    if arr.is_null(row) {
340        Err(err(&format!("{kind} '{name}' is null at row {row}")))
341    } else {
342        Ok(arr.value(row))
343    }
344}
345
346fn build_headers(
347    expected_stream: Option<&str>,
348    msg_id: Option<&str>,
349    header_cols: &[(&str, &StringArray)],
350    row: usize,
351) -> Option<HeaderMap> {
352    if header_cols.is_empty() && expected_stream.is_none() && msg_id.is_none() {
353        return None;
354    }
355    let mut h = HeaderMap::new();
356    if let Some(s) = expected_stream {
357        h.insert("Nats-Expected-Stream", s);
358    }
359    if let Some(id) = msg_id {
360        h.insert("Nats-Msg-Id", id);
361    }
362    for (name, arr) in header_cols {
363        if !arr.is_null(row) {
364            h.insert(*name, arr.value(row));
365        }
366    }
367    Some(h)
368}
369
370fn err(msg: &str) -> ConnectorError {
371    ConnectorError::ConfigurationError(msg.to_string())
372}