1use std::collections::VecDeque;
6use std::future::IntoFuture;
7use std::time::Duration;
8
9use arrow_array::{cast::AsArray, Array, RecordBatch, StringArray};
10use arrow_schema::SchemaRef;
11use async_nats::jetstream::{self, context::PublishAckFuture};
12use async_nats::{Client, HeaderMap};
13use async_trait::async_trait;
14use futures_util::stream::{FuturesUnordered, StreamExt};
15
16use super::config::{build_connect_options, Mode, NatsSinkConfig, SubjectSpec};
17use super::metrics::NatsSinkMetrics;
18use crate::config::ConnectorConfig;
19use crate::connector::{DeliveryGuarantee, SinkConnector, SinkConnectorCapabilities, WriteResult};
20use crate::error::ConnectorError;
21use crate::health::HealthStatus;
22use crate::metrics::ConnectorMetrics;
23use crate::serde::{self, RecordSerializer};
24
/// Sink connector that publishes Arrow record batches to NATS, either over
/// core NATS (fire-and-forget from the broker's perspective) or JetStream
/// (publish acks tracked for at-least-once / exactly-once delivery).
pub struct NatsSink {
    /// Arrow schema of incoming batches; echoed back by `schema()`.
    schema: SchemaRef,
    /// Parsed sink configuration; `None` until `open()` succeeds.
    config: Option<NatsSinkConfig>,
    /// Row serializer chosen from `config.format`; `None` until `open()`.
    serializer: Option<Box<dyn RecordSerializer>>,
    /// Live connection state (core client or JetStream context);
    /// `None` before `open()` and after `close()`.
    runtime: Option<Runtime>,
    /// Prometheus-backed counters/gauges for publishes, acks, and rollbacks.
    metrics: NatsSinkMetrics,
    /// JetStream publish-ack futures not yet awaited; drained when the
    /// configured `max_pending` is reached, on flush, and on pre-commit.
    pending_acks: VecDeque<PublishAckFuture>,
}
35
/// Active publishing backend, selected in `open()` from the configured `Mode`.
enum Runtime {
    /// Core NATS: publishes go directly through the client; no ack futures
    /// are tracked (delivery confirmation is limited to `client.flush()`).
    Core { client: Client },
    /// JetStream: publishes return `PublishAckFuture`s which are buffered in
    /// `pending_acks` and awaited later.
    JetStream { context: jetstream::Context },
}
40
41impl NatsSink {
42 #[must_use]
44 pub fn new(schema: SchemaRef, registry: Option<&prometheus::Registry>) -> Self {
45 Self {
46 schema,
47 config: None,
48 serializer: None,
49 runtime: None,
50 metrics: NatsSinkMetrics::new(registry),
51 pending_acks: VecDeque::new(),
52 }
53 }
54
55 #[must_use]
57 pub fn config(&self) -> Option<&NatsSinkConfig> {
58 self.config.as_ref()
59 }
60}
61
#[async_trait]
impl SinkConnector for NatsSink {
    /// Parses the connector config, builds the row serializer, connects to
    /// NATS, and selects the runtime (core vs. JetStream).
    ///
    /// For JetStream with a configured stream name, the stream is looked up
    /// eagerly; under exactly-once delivery its `duplicate_window` must be at
    /// least `min.duplicate.window.ms`, otherwise rollback redelivery could
    /// land outside the dedup horizon and `open` fails with LDB-5056.
    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
        let cfg = NatsSinkConfig::from_config(config)?;
        self.serializer = Some(
            serde::create_serializer(cfg.format)
                .map_err(|e| err(&format!("serializer for format {:?}: {e}", cfg.format)))?,
        );
        let client = build_connect_options(&cfg.auth, &cfg.tls)?
            .connect(&cfg.servers)
            .await
            .map_err(|e| err(&format!("nats connect({:?}): {e}", cfg.servers)))?;
        self.runtime = Some(match cfg.mode {
            Mode::Core => Runtime::Core { client },
            Mode::JetStream => {
                let context = jetstream::new(client);
                if let Some(stream_name) = cfg.stream.as_deref() {
                    let stream = context
                        .get_stream(stream_name)
                        .await
                        .map_err(|e| err(&format!("get_stream('{stream_name}') failed: {e}")))?;
                    if cfg.delivery_guarantee == DeliveryGuarantee::ExactlyOnce {
                        // cached_info() avoids a second round-trip; the info
                        // was just fetched by get_stream above.
                        let info = stream.cached_info();
                        let actual = info.config.duplicate_window;
                        if actual < cfg.min_duplicate_window {
                            return Err(err(&format!(
                                "[LDB-5056] stream '{stream_name}' has duplicate_window={actual:?}, \
                                 below the configured minimum {:?}. Rollback redelivery could land \
                                 outside the dedup horizon. Reconfigure the stream or lower \
                                 'min.duplicate.window.ms'.",
                                cfg.min_duplicate_window,
                            )));
                        }
                    }
                }
                Runtime::JetStream { context }
            }
        });
        // Config is stored last so a failed open leaves `config` unset and
        // later calls keep reporting "open() first".
        self.config = Some(cfg);
        Ok(())
    }

    /// Serializes every row of `batch` and publishes it to NATS.
    ///
    /// The subject is either a literal or taken per-row from a Utf8 column;
    /// optional header columns and (under exactly-once) a per-row
    /// `Nats-Msg-Id` dedup header are attached. Core mode awaits each publish
    /// directly; JetStream mode buffers the returned ack futures in
    /// `pending_acks`, draining them inline once `max_pending` is reached.
    ///
    /// Returns the number of rows and payload bytes written. Errors on a
    /// missing/non-Utf8 subject, header, or dedup column, on null values in
    /// the subject/dedup columns, and on any publish failure.
    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
        // Destructure `self` so the fields can be borrowed independently
        // (e.g. `pending_acks` mutably while `metrics` is shared).
        let Self {
            config,
            serializer,
            runtime,
            pending_acks,
            metrics,
            schema: _,
        } = self;
        let cfg = config.as_ref().ok_or_else(|| err("sink: open() first"))?;
        let ser = serializer
            .as_ref()
            .ok_or_else(|| err("sink: open() first"))?;
        let rt = runtime.as_mut().ok_or_else(|| err("sink: open() first"))?;

        // Resolve all referenced columns up front so column errors surface
        // before any row is published.
        let subject_col = match &cfg.subject {
            SubjectSpec::Column(name) => Some(resolve_utf8(batch, name)?),
            SubjectSpec::Literal(_) => None,
        };
        let header_cols: Vec<(&str, &StringArray)> = cfg
            .header_columns
            .iter()
            .map(|n| resolve_utf8(batch, n).map(|arr| (n.as_str(), arr)))
            .collect::<Result<_, _>>()?;
        let expected_stream = cfg.expected_stream.as_deref();
        // The dedup-id column is only consulted for exactly-once delivery.
        let dedup_col = if cfg.delivery_guarantee == DeliveryGuarantee::ExactlyOnce {
            cfg.dedup_id_column
                .as_deref()
                .map(|n| resolve_utf8(batch, n).map(|arr| (n, arr)))
                .transpose()?
        } else {
            None
        };

        let records = ser
            .serialize(batch)
            .map_err(|e| err(&format!("serialize batch: {e}")))?;

        let mut bytes_total: u64 = 0;
        let mut rows_written: usize = 0;
        for (row, payload) in records.into_iter().enumerate() {
            let subject: &str = match (&cfg.subject, subject_col) {
                (SubjectSpec::Literal(s), _) => s.as_str(),
                (SubjectSpec::Column(name), Some(arr)) => {
                    non_null(arr, row, "subject.column", name)?
                }
                // subject_col is Some whenever cfg.subject is Column (set above).
                (SubjectSpec::Column(_), None) => unreachable!("resolved above"),
            };
            let msg_id = dedup_col
                .map(|(name, arr)| non_null(arr, row, "dedup.id.column", name))
                .transpose()?;
            let headers = build_headers(expected_stream, msg_id, &header_cols, row);
            let payload_len = payload.len() as u64;
            let payload = bytes::Bytes::from(payload);

            match rt {
                Runtime::Core { client } => {
                    let result = if let Some(h) = headers {
                        client
                            .publish_with_headers(subject.to_string(), h, payload)
                            .await
                    } else {
                        client.publish(subject.to_string(), payload).await
                    };
                    if let Err(e) = result {
                        metrics.record_publish_error();
                        return Err(err(&format!("core publish: {e}")));
                    }
                }
                Runtime::JetStream { context } => {
                    // Backpressure: block and drain every outstanding ack
                    // before admitting more publishes past `max_pending`.
                    if pending_acks.len() >= cfg.max_pending {
                        drain_acks(pending_acks, metrics, cfg.ack_timeout).await?;
                    }
                    let publish_result = if let Some(h) = headers {
                        context
                            .publish_with_headers(subject.to_string(), h, payload)
                            .await
                    } else {
                        context.publish(subject.to_string(), payload).await
                    };
                    match publish_result {
                        Ok(fut) => pending_acks.push_back(fut),
                        Err(e) => {
                            metrics.record_publish_error();
                            return Err(err(&format!("jetstream publish: {e}")));
                        }
                    }
                }
            }

            metrics.record_published_row(payload_len);
            rows_written += 1;
            bytes_total += payload_len;
        }

        metrics.set_pending_futures(pending_acks.len());
        Ok(WriteResult::new(rows_written, bytes_total))
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Advertises idempotent, partitioned writes; adds exactly-once and
    /// two-phase commit only when the opened config requested exactly-once.
    /// Before `open()`, only the base capabilities are reported.
    fn capabilities(&self) -> SinkConnectorCapabilities {
        let mut caps = SinkConnectorCapabilities::new(Duration::from_secs(5))
            .with_idempotent()
            .with_partitioned();
        if matches!(
            self.config.as_ref().map(|c| c.delivery_guarantee),
            Some(DeliveryGuarantee::ExactlyOnce)
        ) {
            caps = caps.with_exactly_once().with_two_phase_commit();
        }
        caps
    }

    /// Flushes the core client (no-op for JetStream) and then awaits all
    /// buffered JetStream publish acks, with a 30s default ack timeout when
    /// the sink was never opened.
    async fn flush(&mut self) -> Result<(), ConnectorError> {
        match self.runtime.as_ref() {
            Some(Runtime::Core { client }) => client
                .flush()
                .await
                .map_err(|e| err(&format!("core flush: {e}")))?,
            Some(Runtime::JetStream { .. }) | None => {}
        }
        let timeout = self
            .config
            .as_ref()
            .map_or(Duration::from_secs(30), |c| c.ack_timeout);
        drain_acks(&mut self.pending_acks, &self.metrics, timeout).await
    }

    /// Two-phase-commit prepare: the epoch may only commit once every
    /// outstanding publish ack has been confirmed.
    async fn pre_commit(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
        let timeout = self
            .config
            .as_ref()
            .map_or(Duration::from_secs(30), |c| c.ack_timeout);
        drain_acks(&mut self.pending_acks, &self.metrics, timeout).await
    }

    /// Abandons all buffered publish acks without awaiting them. The rows may
    /// already be on the broker; redelivery after rollback is presumably
    /// absorbed by the JetStream duplicate window validated in `open()`
    /// (see the LDB-5056 check) — confirm against the rollback protocol.
    async fn rollback_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
        self.pending_acks.clear();
        self.metrics.set_pending_futures(0);
        self.metrics.record_rollback();
        Ok(())
    }

    /// Best-effort shutdown: flushes/drains with a 5s default ack timeout,
    /// swallowing errors, then drops the connection state.
    async fn close(&mut self) -> Result<(), ConnectorError> {
        let timeout = self
            .config
            .as_ref()
            .map_or(Duration::from_secs(5), |c| c.ack_timeout);
        if let Some(Runtime::Core { client }) = self.runtime.as_ref() {
            let _ = client.flush().await;
        }
        let _ = drain_acks(&mut self.pending_acks, &self.metrics, timeout).await;
        self.runtime = None;
        Ok(())
    }

    /// Unknown before `open()`; Degraded once the pending-ack gauge reaches
    /// half of `max_pending` (a drain is then imminent and pre_commit may
    /// stall); Healthy otherwise.
    fn health_check(&self) -> HealthStatus {
        match self.config.as_ref() {
            None => HealthStatus::Unknown,
            Some(cfg) => {
                let cap = cfg.max_pending.max(1) as u64;
                #[allow(clippy::cast_sign_loss)]
                let pending = self.metrics.pending_futures.get().max(0) as u64;
                if pending * 2 >= cap {
                    HealthStatus::Degraded(format!(
                        "pending publish acks {pending}/{cap} — ack drain may stall pre_commit"
                    ))
                } else {
                    HealthStatus::Healthy
                }
            }
        }
    }

    fn metrics(&self) -> ConnectorMetrics {
        self.metrics.to_connector_metrics()
    }
}
292
293async fn drain_acks(
298 pending: &mut VecDeque<PublishAckFuture>,
299 metrics: &NatsSinkMetrics,
300 timeout: Duration,
301) -> Result<(), ConnectorError> {
302 if pending.is_empty() {
303 return Ok(());
304 }
305 let mut set: FuturesUnordered<_> = pending.drain(..).map(IntoFuture::into_future).collect();
306 let deadline = tokio::time::Instant::now() + timeout;
307 let mut first_err: Option<ConnectorError> = None;
308
309 loop {
310 if set.is_empty() {
311 break;
312 }
313 match tokio::time::timeout_at(deadline, set.next()).await {
314 Ok(Some(Ok(ack))) => {
315 if ack.duplicate {
316 metrics.record_dedup();
317 }
318 }
319 Ok(Some(Err(e))) => {
320 metrics.record_ack_error();
321 if first_err.is_none() {
322 first_err = Some(err(&format!("jetstream publish ack: {e}")));
323 }
324 }
325 Ok(None) => break,
326 Err(_) => {
327 let lost = set.len();
328 for _ in 0..lost {
329 metrics.record_ack_error();
330 }
331 metrics.set_pending_futures(pending.len());
332 return Err(err(&format!(
333 "jetstream publish ack: timed out with {lost} still in flight"
334 )));
335 }
336 }
337 }
338
339 metrics.set_pending_futures(pending.len());
340 first_err.map_or(Ok(()), Err)
341}
342
343fn resolve_utf8<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray, ConnectorError> {
344 let col = batch
345 .column_by_name(name)
346 .ok_or_else(|| err(&format!("column '{name}' not in batch schema")))?;
347 col.as_string_opt::<i32>()
348 .ok_or_else(|| err(&format!("column '{name}' must be Utf8")))
349}
350
351fn non_null<'a>(
352 arr: &'a StringArray,
353 row: usize,
354 kind: &str,
355 name: &str,
356) -> Result<&'a str, ConnectorError> {
357 if arr.is_null(row) {
358 Err(err(&format!("{kind} '{name}' is null at row {row}")))
359 } else {
360 Ok(arr.value(row))
361 }
362}
363
364fn build_headers(
365 expected_stream: Option<&str>,
366 msg_id: Option<&str>,
367 header_cols: &[(&str, &StringArray)],
368 row: usize,
369) -> Option<HeaderMap> {
370 if header_cols.is_empty() && expected_stream.is_none() && msg_id.is_none() {
371 return None;
372 }
373 let mut h = HeaderMap::new();
374 if let Some(s) = expected_stream {
375 h.insert("Nats-Expected-Stream", s);
376 }
377 if let Some(id) = msg_id {
378 h.insert("Nats-Msg-Id", id);
379 }
380 for (name, arr) in header_cols {
381 if !arr.is_null(row) {
382 h.insert(*name, arr.value(row));
383 }
384 }
385 Some(h)
386}
387
388fn err(msg: &str) -> ConnectorError {
389 ConnectorError::ConfigurationError(msg.to_string())
390}