1use std::collections::VecDeque;
6use std::future::IntoFuture;
7use std::time::Duration;
8
9use arrow_array::{cast::AsArray, Array, RecordBatch, StringArray};
10use arrow_schema::SchemaRef;
11use async_nats::jetstream::{self, context::PublishAckFuture};
12use async_nats::{Client, HeaderMap};
13use async_trait::async_trait;
14use futures_util::stream::{FuturesUnordered, StreamExt};
15
16use super::config::{build_connect_options, Mode, NatsSinkConfig, SubjectSpec};
17use super::metrics::NatsSinkMetrics;
18use crate::config::ConnectorConfig;
19use crate::connector::{DeliveryGuarantee, SinkConnector, SinkConnectorCapabilities, WriteResult};
20use crate::error::ConnectorError;
21use crate::serde::{self, RecordSerializer};
22
23pub struct NatsSink {
25 schema: SchemaRef,
26 config: Option<NatsSinkConfig>,
27 serializer: Option<Box<dyn RecordSerializer>>,
28 runtime: Option<Runtime>,
29 metrics: NatsSinkMetrics,
30 pending_acks: VecDeque<PublishAckFuture>,
32}
33
34enum Runtime {
35 Core { client: Client },
36 JetStream { context: jetstream::Context },
37}
38
39impl NatsSink {
40 #[must_use]
42 pub fn new(schema: SchemaRef, registry: Option<&prometheus::Registry>) -> Self {
43 Self {
44 schema,
45 config: None,
46 serializer: None,
47 runtime: None,
48 metrics: NatsSinkMetrics::new(registry),
49 pending_acks: VecDeque::new(),
50 }
51 }
52
53 #[must_use]
55 pub fn config(&self) -> Option<&NatsSinkConfig> {
56 self.config.as_ref()
57 }
58
59 #[must_use]
61 pub fn metrics_handle(&self) -> &NatsSinkMetrics {
62 &self.metrics
63 }
64}
65
66#[async_trait]
67impl SinkConnector for NatsSink {
68 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
69 let cfg = NatsSinkConfig::from_config(config)?;
70 self.serializer = Some(
71 serde::create_serializer(cfg.format)
72 .map_err(|e| err(&format!("serializer for format {:?}: {e}", cfg.format)))?,
73 );
74 let client = build_connect_options(&cfg.auth, &cfg.tls)?
75 .connect(&cfg.servers)
76 .await
77 .map_err(|e| err(&format!("nats connect({:?}): {e}", cfg.servers)))?;
78 self.runtime = Some(match cfg.mode {
79 Mode::Core => Runtime::Core { client },
80 Mode::JetStream => {
81 let context = jetstream::new(client);
82 if let Some(stream_name) = cfg.stream.as_deref() {
83 let stream = context
85 .get_stream(stream_name)
86 .await
87 .map_err(|e| err(&format!("get_stream('{stream_name}') failed: {e}")))?;
88 if cfg.delivery_guarantee == DeliveryGuarantee::ExactlyOnce {
89 let info = stream.cached_info();
90 let actual = info.config.duplicate_window;
91 if actual < cfg.min_duplicate_window {
92 return Err(err(&format!(
93 "[LDB-5056] stream '{stream_name}' has duplicate_window={actual:?}, \
94 below the configured minimum {:?}. Rollback redelivery could land \
95 outside the dedup horizon. Reconfigure the stream or lower \
96 'min.duplicate.window.ms'.",
97 cfg.min_duplicate_window,
98 )));
99 }
100 }
101 }
102 Runtime::JetStream { context }
103 }
104 });
105 self.config = Some(cfg);
106 Ok(())
107 }
108
109 async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
110 let Self {
112 config,
113 serializer,
114 runtime,
115 pending_acks,
116 metrics,
117 schema: _,
118 } = self;
119 let cfg = config.as_ref().ok_or_else(|| err("sink: open() first"))?;
120 let ser = serializer
121 .as_ref()
122 .ok_or_else(|| err("sink: open() first"))?;
123 let rt = runtime.as_mut().ok_or_else(|| err("sink: open() first"))?;
124
125 let subject_col = match &cfg.subject {
126 SubjectSpec::Column(name) => Some(resolve_utf8(batch, name)?),
127 SubjectSpec::Literal(_) => None,
128 };
129 let header_cols: Vec<(&str, &StringArray)> = cfg
130 .header_columns
131 .iter()
132 .map(|n| resolve_utf8(batch, n).map(|arr| (n.as_str(), arr)))
133 .collect::<Result<_, _>>()?;
134 let expected_stream = cfg.expected_stream.as_deref();
135 let dedup_col = if cfg.delivery_guarantee == DeliveryGuarantee::ExactlyOnce {
138 cfg.dedup_id_column
139 .as_deref()
140 .map(|n| resolve_utf8(batch, n).map(|arr| (n, arr)))
141 .transpose()?
142 } else {
143 None
144 };
145
146 let records = ser
147 .serialize(batch)
148 .map_err(|e| err(&format!("serialize batch: {e}")))?;
149
150 let mut bytes_total: u64 = 0;
151 let mut rows_written: usize = 0;
152 for (row, payload) in records.into_iter().enumerate() {
153 let subject: &str = match (&cfg.subject, subject_col) {
154 (SubjectSpec::Literal(s), _) => s.as_str(),
155 (SubjectSpec::Column(name), Some(arr)) => {
156 non_null(arr, row, "subject.column", name)?
157 }
158 (SubjectSpec::Column(_), None) => unreachable!("resolved above"),
159 };
160 let msg_id = dedup_col
161 .map(|(name, arr)| non_null(arr, row, "dedup.id.column", name))
162 .transpose()?;
163 let headers = build_headers(expected_stream, msg_id, &header_cols, row);
164 let payload_len = payload.len() as u64;
165 let payload = bytes::Bytes::from(payload);
166
167 match rt {
168 Runtime::Core { client } => {
169 let result = if let Some(h) = headers {
170 client
171 .publish_with_headers(subject.to_string(), h, payload)
172 .await
173 } else {
174 client.publish(subject.to_string(), payload).await
175 };
176 if let Err(e) = result {
177 metrics.record_publish_error();
178 return Err(err(&format!("core publish: {e}")));
179 }
180 }
181 Runtime::JetStream { context } => {
182 if pending_acks.len() >= cfg.max_pending {
183 drain_acks(pending_acks, metrics, cfg.ack_timeout).await?;
184 }
185 let publish_result = if let Some(h) = headers {
186 context
187 .publish_with_headers(subject.to_string(), h, payload)
188 .await
189 } else {
190 context.publish(subject.to_string(), payload).await
191 };
192 match publish_result {
193 Ok(fut) => pending_acks.push_back(fut),
194 Err(e) => {
195 metrics.record_publish_error();
196 return Err(err(&format!("jetstream publish: {e}")));
197 }
198 }
199 }
200 }
201
202 metrics.record_published_row(payload_len);
204 rows_written += 1;
205 bytes_total += payload_len;
206 }
207
208 metrics.set_pending_futures(pending_acks.len());
209 Ok(WriteResult::new(rows_written, bytes_total))
210 }
211
212 fn schema(&self) -> SchemaRef {
213 self.schema.clone()
214 }
215
216 fn capabilities(&self) -> SinkConnectorCapabilities {
217 let mut caps = SinkConnectorCapabilities::new(Duration::from_secs(5))
218 .with_idempotent()
219 .with_partitioned();
220 if matches!(
221 self.config.as_ref().map(|c| c.delivery_guarantee),
222 Some(DeliveryGuarantee::ExactlyOnce)
223 ) {
224 caps = caps.with_exactly_once().with_two_phase_commit();
225 }
226 caps
227 }
228
229 async fn flush(&mut self) -> Result<(), ConnectorError> {
230 match self.runtime.as_ref() {
231 Some(Runtime::Core { client }) => client
232 .flush()
233 .await
234 .map_err(|e| err(&format!("core flush: {e}")))?,
235 Some(Runtime::JetStream { .. }) | None => {}
236 }
237 let timeout = self
238 .config
239 .as_ref()
240 .map_or(Duration::from_secs(30), |c| c.ack_timeout);
241 drain_acks(&mut self.pending_acks, &self.metrics, timeout).await
242 }
243
244 async fn pre_commit(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
245 let timeout = self
246 .config
247 .as_ref()
248 .map_or(Duration::from_secs(30), |c| c.ack_timeout);
249 drain_acks(&mut self.pending_acks, &self.metrics, timeout).await
250 }
251
252 async fn rollback_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
253 self.pending_acks.clear();
255 self.metrics.set_pending_futures(0);
256 self.metrics.record_rollback();
257 Ok(())
258 }
259
260 async fn close(&mut self) -> Result<(), ConnectorError> {
261 let timeout = self
262 .config
263 .as_ref()
264 .map_or(Duration::from_secs(5), |c| c.ack_timeout);
265 if let Some(Runtime::Core { client }) = self.runtime.as_ref() {
267 let _ = client.flush().await;
268 }
269 let _ = drain_acks(&mut self.pending_acks, &self.metrics, timeout).await;
270 self.runtime = None;
271 Ok(())
272 }
273}
274
275async fn drain_acks(
280 pending: &mut VecDeque<PublishAckFuture>,
281 metrics: &NatsSinkMetrics,
282 timeout: Duration,
283) -> Result<(), ConnectorError> {
284 if pending.is_empty() {
285 return Ok(());
286 }
287 let mut set: FuturesUnordered<_> = pending.drain(..).map(IntoFuture::into_future).collect();
288 let deadline = tokio::time::Instant::now() + timeout;
289 let mut first_err: Option<ConnectorError> = None;
290
291 loop {
292 if set.is_empty() {
293 break;
294 }
295 match tokio::time::timeout_at(deadline, set.next()).await {
296 Ok(Some(Ok(ack))) => {
297 if ack.duplicate {
298 metrics.record_dedup();
299 }
300 }
301 Ok(Some(Err(e))) => {
302 metrics.record_ack_error();
303 if first_err.is_none() {
304 first_err = Some(err(&format!("jetstream publish ack: {e}")));
305 }
306 }
307 Ok(None) => break,
308 Err(_) => {
309 let lost = set.len();
310 for _ in 0..lost {
311 metrics.record_ack_error();
312 }
313 metrics.set_pending_futures(pending.len());
314 return Err(err(&format!(
315 "jetstream publish ack: timed out with {lost} still in flight"
316 )));
317 }
318 }
319 }
320
321 metrics.set_pending_futures(pending.len());
322 first_err.map_or(Ok(()), Err)
323}
324
325fn resolve_utf8<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray, ConnectorError> {
326 let col = batch
327 .column_by_name(name)
328 .ok_or_else(|| err(&format!("column '{name}' not in batch schema")))?;
329 col.as_string_opt::<i32>()
330 .ok_or_else(|| err(&format!("column '{name}' must be Utf8")))
331}
332
333fn non_null<'a>(
334 arr: &'a StringArray,
335 row: usize,
336 kind: &str,
337 name: &str,
338) -> Result<&'a str, ConnectorError> {
339 if arr.is_null(row) {
340 Err(err(&format!("{kind} '{name}' is null at row {row}")))
341 } else {
342 Ok(arr.value(row))
343 }
344}
345
346fn build_headers(
347 expected_stream: Option<&str>,
348 msg_id: Option<&str>,
349 header_cols: &[(&str, &StringArray)],
350 row: usize,
351) -> Option<HeaderMap> {
352 if header_cols.is_empty() && expected_stream.is_none() && msg_id.is_none() {
353 return None;
354 }
355 let mut h = HeaderMap::new();
356 if let Some(s) = expected_stream {
357 h.insert("Nats-Expected-Stream", s);
358 }
359 if let Some(id) = msg_id {
360 h.insert("Nats-Msg-Id", id);
361 }
362 for (name, arr) in header_cols {
363 if !arr.is_null(row) {
364 h.insert(*name, arr.value(row));
365 }
366 }
367 Some(h)
368}
369
370fn err(msg: &str) -> ConnectorError {
371 ConnectorError::ConfigurationError(msg.to_string())
372}