Skip to main content

laminar_connectors/mongodb/
sink.rs

1//! `MongoDB` sink connector implementation.
2//!
3//! Implements [`SinkConnector`] for writing Arrow `RecordBatch` data to
4//! `MongoDB` collections. Supports insert, upsert, replace, and CDC replay
5//! write modes, with optional time series collection support.
6//!
7//! # Architecture
8//!
9//! - **Ring 0**: No sink code — data arrives via SPSC channel (~5ns push)
10//! - **Ring 1**: Batch buffering, write dispatch, flush management
11//! - **Ring 2**: Connection pool, collection creation, write concern config
12//!
13//! # Batching
14//!
15//! Writes are buffered up to `batch_size` records and flushed when:
16//! - The batch is full
//! - `flush_interval` has elapsed (checked only when a new batch arrives)
18//! - A shutdown signal or epoch boundary is reached
19//!
20//! Insert mode uses `insert_many` for batch efficiency. Upsert, replace,
21//! and CDC replay modes issue individual operations per document.
22
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use arrow_array::RecordBatch;
27use arrow_schema::{DataType, SchemaRef};
28use async_trait::async_trait;
29use tracing::{debug, info};
30
31use crate::config::{ConnectorConfig, ConnectorState};
32use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
33use crate::error::ConnectorError;
34
35use super::config::MongoDbSinkConfig;
36use super::metrics::MongoDbSinkMetrics;
37use super::timeseries::CollectionKind;
38use super::write_model::WriteMode;
39
/// `MongoDB` sink connector.
///
/// Writes Arrow `RecordBatch` records to a `MongoDB` collection using
/// configurable write modes. Supports standard and time series collections.
///
/// Incoming batches are buffered and converted to JSON/BSON documents only
/// when the buffer is flushed.
pub struct MongoDbSink {
    /// Sink configuration: connection URI, target database/collection, write
    /// mode, write concern, and batching limits.
    config: MongoDbSinkConfig,

    /// Arrow schema reported through `SinkConnector::schema`. Incoming batches
    /// are not validated against it.
    schema: SchemaRef,

    /// Connector lifecycle state: `Created` on construction, `Running` after
    /// `open`, `Closed` after `close`.
    state: ConnectorState,

    /// Whole batches awaiting the next flush.
    buffer: Vec<RecordBatch>,

    /// Sum of `num_rows()` over `buffer`; compared against `batch_size`.
    buffered_rows: usize,

    /// When the buffer was last cleared (or the sink was created); compared
    /// against `flush_interval`.
    last_flush: Instant,

    /// Sink metrics.
    metrics: MongoDbSinkMetrics,

    /// `MongoDB` client (feature-gated). `None` until `open` connects.
    #[cfg(feature = "mongodb-cdc")]
    client: Option<mongodb::Client>,

    /// Target collection handle carrying the configured write concern
    /// (feature-gated). `None` until `open` connects; writes fail while unset.
    #[cfg(feature = "mongodb-cdc")]
    collection: Option<mongodb::Collection<mongodb::bson::Document>>,
}
74
75impl MongoDbSink {
76    /// Creates a new `MongoDB` sink connector.
77    #[must_use]
78    pub fn new(
79        schema: SchemaRef,
80        config: MongoDbSinkConfig,
81        registry: Option<&prometheus::Registry>,
82    ) -> Self {
83        let buf_capacity = (config.batch_size / 128).max(4);
84        Self {
85            config,
86            schema,
87            state: ConnectorState::Created,
88            buffer: Vec::with_capacity(buf_capacity),
89            buffered_rows: 0,
90            last_flush: Instant::now(),
91            metrics: MongoDbSinkMetrics::new(registry),
92            #[cfg(feature = "mongodb-cdc")]
93            client: None,
94            #[cfg(feature = "mongodb-cdc")]
95            collection: None,
96        }
97    }
98
99    /// Creates a new sink from a generic [`ConnectorConfig`].
100    ///
101    /// # Errors
102    ///
103    /// Returns `ConnectorError` if the configuration is invalid.
104    pub fn from_config(
105        schema: SchemaRef,
106        config: &ConnectorConfig,
107    ) -> Result<Self, ConnectorError> {
108        let mongo_config = MongoDbSinkConfig::from_config(config)?;
109        Ok(Self::new(schema, mongo_config, None))
110    }
111
112    /// Returns a reference to the sink configuration.
113    #[must_use]
114    pub fn config(&self) -> &MongoDbSinkConfig {
115        &self.config
116    }
117
118    /// Returns the number of buffered rows.
119    #[must_use]
120    pub fn buffered_rows(&self) -> usize {
121        self.buffered_rows
122    }
123
124    /// Returns whether a flush is needed based on batch size or interval.
125    #[must_use]
126    fn should_flush(&self) -> bool {
127        self.buffered_rows >= self.config.batch_size
128            || self.last_flush.elapsed() >= self.config.flush_interval()
129    }
130
131    /// Converts buffered Arrow batches to JSON documents for writing.
132    ///
133    /// Returns `(docs, byte_estimate)`. The byte estimate is accumulated
134    /// during conversion to avoid a redundant serialization pass later.
135    fn batches_to_json_docs(&self) -> (Vec<serde_json::Value>, u64) {
136        let mut docs = Vec::with_capacity(self.buffered_rows);
137        let mut byte_estimate: u64 = 0;
138
139        for batch in &self.buffer {
140            for row_idx in 0..batch.num_rows() {
141                let mut doc = serde_json::Map::new();
142
143                for (col_idx, field) in batch.schema().fields().iter().enumerate() {
144                    let col = batch.column(col_idx);
145                    let value = arrow_value_to_json(col, row_idx);
146                    doc.insert(field.name().clone(), value);
147                }
148
149                let val = serde_json::Value::Object(doc);
150                byte_estimate += serde_json::to_string(&val).map_or(0, |s| s.len() as u64);
151                docs.push(val);
152            }
153        }
154
155        (docs, byte_estimate)
156    }
157
158    /// Clears the internal buffer after a flush.
159    fn clear_buffer(&mut self) {
160        self.buffer.clear();
161        self.buffered_rows = 0;
162        self.last_flush = Instant::now();
163    }
164
165    /// Internal flush that returns a [`WriteResult`] with actual counts.
166    ///
167    /// Both `write_batch` (on auto-flush) and `flush` delegate here.
168    async fn flush_inner(&mut self) -> Result<WriteResult, ConnectorError> {
169        if self.buffer.is_empty() {
170            return Ok(WriteResult::new(0, 0));
171        }
172
173        let (docs, byte_estimate) = self.batches_to_json_docs();
174        let doc_count = docs.len();
175
176        #[cfg(feature = "mongodb-cdc")]
177        {
178            self.write_docs(&docs).await?;
179        }
180
181        #[cfg(not(feature = "mongodb-cdc"))]
182        {
183            debug!(
184                count = doc_count,
185                "flush (no-op without mongodb-cdc feature)"
186            );
187        }
188
189        self.metrics.record_flush(doc_count as u64, byte_estimate);
190        self.clear_buffer();
191
192        Ok(WriteResult::new(doc_count, byte_estimate))
193    }
194}
195
196/// Extracts a JSON value from an Arrow array at the given row index.
197fn arrow_value_to_json(col: &dyn arrow_array::Array, row: usize) -> serde_json::Value {
198    use arrow_array::{BooleanArray, LargeStringArray, StringArray};
199
200    if col.is_null(row) {
201        return serde_json::Value::Null;
202    }
203
204    match col.data_type() {
205        DataType::Boolean => {
206            let arr = col.as_any().downcast_ref::<BooleanArray>().unwrap();
207            serde_json::Value::Bool(arr.value(row))
208        }
209        DataType::Int8
210        | DataType::Int16
211        | DataType::Int32
212        | DataType::Int64
213        | DataType::UInt8
214        | DataType::UInt16
215        | DataType::UInt32
216        | DataType::UInt64 => json_from_primitive(col, row),
217        DataType::Float32 | DataType::Float64 => json_from_float(col, row),
218        DataType::Utf8 => {
219            let arr = col.as_any().downcast_ref::<StringArray>().unwrap();
220            serde_json::Value::String(arr.value(row).to_string())
221        }
222        DataType::LargeUtf8 => {
223            let arr = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
224            serde_json::Value::String(arr.value(row).to_string())
225        }
226        _ => {
227            // Fallback: use Arrow's display format.
228            let formatted = arrow_cast::display::ArrayFormatter::try_new(
229                col,
230                &arrow_cast::display::FormatOptions::default(),
231            );
232            match formatted {
233                Ok(fmt) => serde_json::Value::String(fmt.value(row).to_string()),
234                Err(_) => serde_json::Value::Null,
235            }
236        }
237    }
238}
239
240/// Helper to extract a JSON number from a primitive Arrow array.
241fn json_from_primitive(col: &dyn arrow_array::Array, row: usize) -> serde_json::Value {
242    let formatted = arrow_cast::display::ArrayFormatter::try_new(
243        col,
244        &arrow_cast::display::FormatOptions::default(),
245    );
246    match formatted {
247        Ok(fmt) => {
248            let s = fmt.value(row).to_string();
249            if let Ok(n) = s.parse::<i64>() {
250                serde_json::Value::Number(n.into())
251            } else {
252                serde_json::Value::String(s)
253            }
254        }
255        Err(_) => serde_json::Value::Null,
256    }
257}
258
259/// Helper to extract a JSON number from a float Arrow array.
260fn json_from_float(col: &dyn arrow_array::Array, row: usize) -> serde_json::Value {
261    let formatted = arrow_cast::display::ArrayFormatter::try_new(
262        col,
263        &arrow_cast::display::FormatOptions::default(),
264    );
265    match formatted {
266        Ok(fmt) => {
267            let s = fmt.value(row).to_string();
268            if let Ok(n) = s.parse::<f64>() {
269                serde_json::json!(n)
270            } else {
271                serde_json::Value::String(s)
272            }
273        }
274        Err(_) => serde_json::Value::Null,
275    }
276}
277
278#[async_trait]
279impl SinkConnector for MongoDbSink {
280    async fn open(&mut self, _config: &ConnectorConfig) -> Result<(), ConnectorError> {
281        self.config.validate()?;
282
283        #[cfg(feature = "mongodb-cdc")]
284        {
285            self.connect().await?;
286        }
287
288        self.state = ConnectorState::Running;
289        info!(
290            database = %self.config.database,
291            collection = %self.config.collection,
292            write_mode = ?self.config.write_mode,
293            ordered = self.config.ordered,
294            "MongoDB sink opened"
295        );
296
297        Ok(())
298    }
299
300    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
301        let rows = batch.num_rows();
302        if rows == 0 {
303            return Ok(WriteResult::new(0, 0));
304        }
305
306        self.buffer.push(batch.clone());
307        self.buffered_rows += rows;
308
309        if self.should_flush() {
310            return self.flush_inner().await;
311        }
312
313        // Just buffered, nothing written yet.
314        Ok(WriteResult::new(0, 0))
315    }
316
317    fn schema(&self) -> SchemaRef {
318        Arc::clone(&self.schema)
319    }
320
321    fn capabilities(&self) -> SinkConnectorCapabilities {
322        let mut caps = SinkConnectorCapabilities::new(Duration::from_secs(30)).with_idempotent();
323
324        if matches!(self.config.write_mode, WriteMode::Upsert { .. }) {
325            caps = caps.with_upsert();
326        }
327        if matches!(self.config.write_mode, WriteMode::CdcReplay) {
328            caps = caps.with_changelog();
329        }
330
331        caps
332    }
333
334    async fn flush(&mut self) -> Result<(), ConnectorError> {
335        self.flush_inner().await.map(|_| ())
336    }
337
338    async fn close(&mut self) -> Result<(), ConnectorError> {
339        // Flush remaining buffered data.
340        if !self.buffer.is_empty() {
341            self.flush().await?;
342        }
343
344        self.state = ConnectorState::Closed;
345        info!("MongoDB sink closed");
346        Ok(())
347    }
348}
349
// ── Feature-gated I/O (real MongoDB driver) ──

#[cfg(feature = "mongodb-cdc")]
impl MongoDbSink {
    /// Builds the `MongoDB` client and the target collection handle.
    ///
    /// Steps:
    /// - Parse `connection_uri` and create the client.
    /// - When `collection_kind` is `TimeSeries`, ensure the time series
    ///   collection exists.
    /// - Attach the configured write concern (`w`, `journal`, `timeout_ms`) to
    ///   the collection handle used by `write_docs`.
    ///
    /// NOTE(review): the driver may establish connections lazily. If so, an
    /// unreachable server is only reported by the first operation — confirm.
    ///
    /// # Errors
    ///
    /// Returns `ConnectorError::ConnectionFailed` in any of these cases:
    /// - the URI cannot be parsed;
    /// - the client cannot be created;
    /// - time series collection creation fails.
    async fn connect(&mut self) -> Result<(), ConnectorError> {
        use mongodb::options::{ClientOptions, CollectionOptions};

        let client_options = ClientOptions::parse(&self.config.connection_uri)
            .await
            .map_err(|e| ConnectorError::ConnectionFailed(format!("parse URI: {e}")))?;

        let client = mongodb::Client::with_options(client_options)
            .map_err(|e| ConnectorError::ConnectionFailed(format!("create client: {e}")))?;

        let db = client.database(&self.config.database);

        // Ensure time series collection exists if configured.
        if let CollectionKind::TimeSeries(ref ts_config) = self.config.collection_kind {
            self.ensure_timeseries_collection(&db, ts_config).await?;
        }

        // Translate the configured write concern into the driver's type.
        let wc = {
            use super::config::WriteConcernLevel;
            let mut wc = mongodb::options::WriteConcern::default();
            wc.w = Some(match &self.config.write_concern.w {
                WriteConcernLevel::Majority => mongodb::options::Acknowledgment::Majority,
                WriteConcernLevel::Nodes(n) => mongodb::options::Acknowledgment::Nodes(*n),
            });
            wc.journal = Some(self.config.write_concern.journal);
            wc.w_timeout = self
                .config
                .write_concern
                .timeout_ms
                .map(std::time::Duration::from_millis);
            wc
        };

        let coll_opts = CollectionOptions::builder().write_concern(wc).build();

        let collection = db
            .collection_with_options::<mongodb::bson::Document>(&self.config.collection, coll_opts);

        self.client = Some(client);
        self.collection = Some(collection);

        Ok(())
    }

    /// Creates the time series collection described by `ts_config`.
    ///
    /// Issues a `create` command containing:
    /// - `timeField`, plus the optional `metaField`;
    /// - either a named `granularity` or custom `bucketMaxSpanSeconds` /
    ///   `bucketRoundingSeconds`;
    /// - the optional `expireAfterSeconds`.
    ///
    /// A failure whose message mentions `already exists` or `NamespaceExists`
    /// is treated as success. The existing collection's options are not compared
    /// with `ts_config`.
    ///
    /// # Errors
    ///
    /// Returns `ConnectorError::ConnectionFailed` if `expireAfterSeconds` does
    /// not fit in an `i64`, or if the `create` command fails for any other
    /// reason.
    async fn ensure_timeseries_collection(
        &self,
        db: &mongodb::Database,
        ts_config: &super::timeseries::TimeSeriesConfig,
    ) -> Result<(), ConnectorError> {
        use mongodb::bson::doc;

        let mut ts_opts = doc! {
            "timeField": &ts_config.time_field,
        };

        if let Some(ref meta) = ts_config.meta_field {
            ts_opts.insert("metaField", meta);
        }

        match ts_config.granularity {
            super::timeseries::TimeSeriesGranularity::Seconds => {
                ts_opts.insert("granularity", "seconds");
            }
            super::timeseries::TimeSeriesGranularity::Minutes => {
                ts_opts.insert("granularity", "minutes");
            }
            super::timeseries::TimeSeriesGranularity::Hours => {
                ts_opts.insert("granularity", "hours");
            }
            super::timeseries::TimeSeriesGranularity::Custom {
                bucket_max_span_seconds,
                bucket_rounding_seconds,
            } => {
                ts_opts.insert("bucketMaxSpanSeconds", i64::from(bucket_max_span_seconds));
                ts_opts.insert("bucketRoundingSeconds", i64::from(bucket_rounding_seconds));
            }
        }

        let mut create_opts = doc! {
            "create": &self.config.collection,
            "timeseries": ts_opts,
        };

        if let Some(ttl) = ts_config.expire_after_seconds {
            // Checked conversion instead of `as`: an out-of-range TTL would
            // otherwise wrap to a negative value and be sent to the server.
            let ttl = i64::try_from(ttl).map_err(|e| {
                ConnectorError::ConnectionFailed(format!(
                    "create time series collection: expireAfterSeconds out of range: {e}"
                ))
            })?;
            create_opts.insert("expireAfterSeconds", ttl);
        }

        // Try creating; ignore "already exists" errors.
        match db.run_command(create_opts).await {
            Ok(_) => {
                info!(
                    collection = %self.config.collection,
                    time_field = %ts_config.time_field,
                    granularity = %ts_config.granularity,
                    "created time series collection"
                );
            }
            Err(e) => {
                let msg = e.to_string();
                if !msg.contains("already exists") && !msg.contains("NamespaceExists") {
                    return Err(ConnectorError::ConnectionFailed(format!(
                        "create time series collection: {e}"
                    )));
                }
                debug!(
                    collection = %self.config.collection,
                    "time series collection already exists"
                );
            }
        }

        Ok(())
    }

    /// Extracts a CDC envelope field from `val`.
    ///
    /// The field may be either a JSON object (borrowed as-is) or a JSON string
    /// from a Utf8 Arrow column. A string is parsed into an object for BSON
    /// conversion.
    ///
    /// # Errors
    ///
    /// Returns `ConnectorError::WriteError` in any of these cases:
    /// - the field is missing;
    /// - a string value is not valid JSON;
    /// - the value is neither an object nor a string.
    fn parse_cdc_field<'a>(
        val: &'a serde_json::Value,
        field: &str,
    ) -> Result<std::borrow::Cow<'a, serde_json::Value>, ConnectorError> {
        let v = val.get(field).ok_or_else(|| {
            ConnectorError::WriteError(format!("CDC event missing {field} field"))
        })?;
        match v {
            serde_json::Value::Object(_) => Ok(std::borrow::Cow::Borrowed(v)),
            serde_json::Value::String(s) => {
                let parsed: serde_json::Value = serde_json::from_str(s)
                    .map_err(|e| ConnectorError::WriteError(format!("parse {field} JSON: {e}")))?;
                Ok(std::borrow::Cow::Owned(parsed))
            }
            _ => Err(ConnectorError::WriteError(format!(
                "{field} must be a JSON object or JSON string, got {v}"
            ))),
        }
    }

    /// Writes JSON documents to `MongoDB` according to the configured write mode.
    ///
    /// - `Insert`: converts all documents to BSON and issues one `insert_many`
    ///   honoring `ordered`.
    /// - `Upsert`: one `replace_one` with `upsert = true` per document. The
    ///   filter uses every configured key field.
    /// - `Replace`: one `replace_one` per document, filtered on a non-null
    ///   `_id`. It upserts only when `upsert_on_missing` is set.
    /// - `CdcReplay`: dispatches on each document's `_op` field:
    ///   - `I` inserts the full document; a missing `_op` is treated as `I`.
    ///   - `U` applies `$set`/`$unset` updates.
    ///   - `R` replaces the document, upserting if it is missing.
    ///   - `D` deletes the document.
    ///   - Any other value is logged and skipped.
    ///
    /// In every mode, driver write failures are counted with `record_error`.
    /// Documents are processed sequentially in the per-document modes, so
    /// documents written before a failing one remain written.
    ///
    /// # Errors
    ///
    /// Returns `ConnectorError::Internal` if the collection was never
    /// initialized. Returns `ConnectorError::WriteError` in any of these cases:
    /// - BSON conversion fails;
    /// - a CDC envelope is malformed;
    /// - an upsert key field is missing;
    /// - a Replace document has no non-null `_id`;
    /// - the driver returns an error.
    #[allow(clippy::too_many_lines)]
    async fn write_docs(&self, docs: &[serde_json::Value]) -> Result<(), ConnectorError> {
        use mongodb::bson::{doc, Document};

        let collection = self
            .collection
            .as_ref()
            .ok_or_else(|| ConnectorError::Internal("collection not initialized".to_string()))?;

        match &self.config.write_mode {
            WriteMode::Insert => {
                let bson_docs: Vec<Document> = docs
                    .iter()
                    .map(|v| {
                        mongodb::bson::to_document(v)
                            .map_err(|e| ConnectorError::WriteError(format!("to BSON: {e}")))
                    })
                    .collect::<Result<Vec<_>, _>>()?;

                let opts = mongodb::options::InsertManyOptions::builder()
                    .ordered(Some(self.config.ordered))
                    .build();

                collection
                    .insert_many(bson_docs)
                    .with_options(opts)
                    .await
                    .map_err(|e| {
                        self.metrics.record_error();
                        ConnectorError::WriteError(format!("insert_many: {e}"))
                    })?;

                self.metrics.record_inserts(docs.len() as u64);
            }

            WriteMode::Upsert { ref key_fields } => {
                for val in docs {
                    let bson_doc = mongodb::bson::to_document(val)
                        .map_err(|e| ConnectorError::WriteError(format!("to BSON: {e}")))?;

                    // Every key field must be present. A filter built from only
                    // the fields that happen to exist could match, and then
                    // overwrite, an unrelated document.
                    let mut filter = Document::new();
                    for key in key_fields {
                        match bson_doc.get(key) {
                            Some(v) => {
                                filter.insert(key, v.clone());
                            }
                            None => {
                                return Err(ConnectorError::WriteError(format!(
                                    "upsert key field {key:?} is missing from the document \
                                     (key_fields: {key_fields:?})"
                                )));
                            }
                        }
                    }
                    // Only reachable when `key_fields` itself is empty.
                    if filter.is_empty() {
                        return Err(ConnectorError::WriteError(format!(
                            "upsert filter is empty: none of the key_fields {key_fields:?} \
                             exist in the document"
                        )));
                    }

                    let opts = mongodb::options::ReplaceOptions::builder()
                        .upsert(Some(true))
                        .build();

                    collection
                        .replace_one(filter, bson_doc)
                        .with_options(opts)
                        .await
                        .map_err(|e| {
                            self.metrics.record_error();
                            ConnectorError::WriteError(format!("upsert: {e}"))
                        })?;
                }

                self.metrics.record_upserts(docs.len() as u64);
            }

            WriteMode::Replace { upsert_on_missing } => {
                for val in docs {
                    let bson_doc = mongodb::bson::to_document(val)
                        .map_err(|e| ConnectorError::WriteError(format!("to BSON: {e}")))?;

                    // Use _id as the filter for replacement.
                    let filter = match bson_doc.get("_id") {
                        Some(id) if *id != mongodb::bson::Bson::Null => {
                            doc! { "_id": id.clone() }
                        }
                        _ => {
                            return Err(ConnectorError::WriteError(
                                "Replace mode requires a non-null _id field in document"
                                    .to_string(),
                            ));
                        }
                    };

                    let opts = mongodb::options::ReplaceOptions::builder()
                        .upsert(Some(*upsert_on_missing))
                        .build();

                    collection
                        .replace_one(filter, bson_doc)
                        .with_options(opts)
                        .await
                        .map_err(|e| {
                            self.metrics.record_error();
                            ConnectorError::WriteError(format!("replace: {e}"))
                        })?;
                }
            }

            WriteMode::CdcReplay => {
                // CDC replay processes each document based on its _op field.
                for val in docs {
                    let op = val.get("_op").and_then(|v| v.as_str()).unwrap_or("I");

                    match op {
                        "I" => {
                            let full_doc = Self::parse_cdc_field(val, "_full_document")?;
                            let bson_doc = mongodb::bson::to_document(full_doc.as_ref())
                                .map_err(|e| ConnectorError::WriteError(format!("to BSON: {e}")))?;
                            collection.insert_one(bson_doc).await.map_err(|e| {
                                self.metrics.record_error();
                                ConnectorError::WriteError(format!("cdc insert: {e}"))
                            })?;
                            self.metrics.record_inserts(1);
                        }
                        "U" => {
                            let dk = Self::parse_cdc_field(val, "_document_key")?;
                            let ud = Self::parse_cdc_field(val, "_update_desc")?;
                            let filter = mongodb::bson::to_document(dk.as_ref()).map_err(|e| {
                                ConnectorError::WriteError(format!("filter BSON: {e}"))
                            })?;

                            // Transform updateDescription into update operators.
                            // Raw format: { "updatedFields": {...}, "removedFields": [...] }
                            // Required:   { "$set": {...}, "$unset": {...} }
                            let mut update = Document::new();
                            match ud.get("updatedFields") {
                                // Nothing to set. Servers before 5.0 reject an
                                // empty `$set`, and `$set: null` is never valid.
                                None | Some(serde_json::Value::Null) => {}
                                Some(serde_json::Value::Object(fields)) => {
                                    if !fields.is_empty() {
                                        let bson = mongodb::bson::to_bson(fields).map_err(|e| {
                                            ConnectorError::WriteError(format!(
                                                "updatedFields BSON: {e}"
                                            ))
                                        })?;
                                        update.insert("$set", bson);
                                    }
                                }
                                Some(other) => {
                                    return Err(ConnectorError::WriteError(format!(
                                        "updatedFields must be a JSON object, got {other}"
                                    )));
                                }
                            }
                            if let Some(removed) =
                                ud.get("removedFields").and_then(|v| v.as_array())
                            {
                                if !removed.is_empty() {
                                    // `$unset` ignores the values; "" is the conventional placeholder.
                                    let unset_doc: Document = removed
                                        .iter()
                                        .filter_map(|f| f.as_str())
                                        .map(|f| {
                                            (
                                                f.to_string(),
                                                mongodb::bson::Bson::String(String::new()),
                                            )
                                        })
                                        .collect();
                                    update.insert("$unset", unset_doc);
                                }
                            }
                            if update.is_empty() {
                                continue;
                            }

                            collection.update_one(filter, update).await.map_err(|e| {
                                self.metrics.record_error();
                                ConnectorError::WriteError(format!("cdc update: {e}"))
                            })?;
                            self.metrics.record_upserts(1);
                        }
                        "R" => {
                            let dk = Self::parse_cdc_field(val, "_document_key")?;
                            let full_doc = Self::parse_cdc_field(val, "_full_document")?;
                            let filter = mongodb::bson::to_document(dk.as_ref()).map_err(|e| {
                                ConnectorError::WriteError(format!("filter BSON: {e}"))
                            })?;
                            let replacement = mongodb::bson::to_document(full_doc.as_ref())
                                .map_err(|e| {
                                    ConnectorError::WriteError(format!("replace BSON: {e}"))
                                })?;
                            let opts = mongodb::options::ReplaceOptions::builder()
                                .upsert(Some(true))
                                .build();
                            collection
                                .replace_one(filter, replacement)
                                .with_options(opts)
                                .await
                                .map_err(|e| {
                                    self.metrics.record_error();
                                    ConnectorError::WriteError(format!("cdc replace: {e}"))
                                })?;
                            self.metrics.record_upserts(1);
                        }
                        "D" => {
                            let dk = Self::parse_cdc_field(val, "_document_key")?;
                            let filter = mongodb::bson::to_document(dk.as_ref()).map_err(|e| {
                                ConnectorError::WriteError(format!("filter BSON: {e}"))
                            })?;
                            collection.delete_one(filter).await.map_err(|e| {
                                self.metrics.record_error();
                                ConnectorError::WriteError(format!("cdc delete: {e}"))
                            })?;
                            self.metrics.record_deletes(1);
                        }
                        _ => {
                            debug!(op = op, "lifecycle event — no write issued");
                        }
                    }
                }
            }
        }

        self.metrics.record_bulk_write();
        Ok(())
    }
}
703
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{Field, Schema};

    /// Two-column schema (`id: Int64`, `name: Utf8`) shared by the tests.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds an `n`-row batch with ids `0..n` and names `user_0..user_{n-1}`.
    fn test_batch(n: usize) -> RecordBatch {
        let (ids, names): (Vec<i64>, Vec<String>) = (0..n)
            .map(|i| {
                // Test sizes are tiny, so the cast cannot wrap.
                #[allow(clippy::cast_possible_wrap)]
                let id = i as i64;
                (id, format!("user_{i}"))
            })
            .unzip();

        let columns: Vec<arrow_array::ArrayRef> = vec![
            Arc::new(Int64Array::from(ids)),
            Arc::new(StringArray::from(names)),
        ];
        RecordBatch::try_new(test_schema(), columns).unwrap()
    }

    /// Constructs a sink over `test_schema()` without a metrics registry.
    fn sink_with(config: MongoDbSinkConfig) -> MongoDbSink {
        MongoDbSink::new(test_schema(), config, None)
    }

    #[test]
    fn test_new_sink() {
        let sink = sink_with(MongoDbSinkConfig::new("mongodb://localhost:27017", "db", "coll"));
        assert_eq!(sink.buffered_rows(), 0);
    }

    #[test]
    fn test_sink_capabilities_insert() {
        let caps = sink_with(MongoDbSinkConfig::default()).capabilities();
        assert!(caps.idempotent);
        assert!(!caps.upsert);
        assert!(!caps.changelog);
    }

    #[test]
    fn test_sink_capabilities_upsert() {
        let mut config = MongoDbSinkConfig::default();
        config.write_mode = WriteMode::Upsert {
            key_fields: vec!["id".to_string()],
        };
        assert!(sink_with(config).capabilities().upsert);
    }

    #[test]
    fn test_sink_capabilities_cdc_replay() {
        let mut config = MongoDbSinkConfig::default();
        config.write_mode = WriteMode::CdcReplay;
        assert!(sink_with(config).capabilities().changelog);
    }

    #[test]
    fn test_batches_to_json() {
        let mut sink = sink_with(MongoDbSinkConfig::default());
        sink.buffer.push(test_batch(3));
        sink.buffered_rows = 3;

        let (docs, bytes) = sink.batches_to_json_docs();
        assert_eq!(docs.len(), 3);
        assert!(bytes > 0);

        let first = &docs[0];
        assert_eq!(first["id"], 0);
        assert_eq!(first["name"], "user_0");
    }

    #[test]
    fn test_should_flush_batch_size() {
        let mut config = MongoDbSinkConfig::default();
        config.batch_size = 100;
        let mut sink = sink_with(config);

        // Freshly created and empty: not due yet.
        assert!(!sink.should_flush());
        // Reaching batch_size makes the flush due.
        sink.buffered_rows = 100;
        assert!(sink.should_flush());
    }

    #[test]
    fn test_clear_buffer() {
        let mut sink = sink_with(MongoDbSinkConfig::default());
        sink.buffer.push(test_batch(5));
        sink.buffered_rows = 5;

        sink.clear_buffer();

        assert!(sink.buffer.is_empty());
        assert_eq!(sink.buffered_rows, 0);
    }

    #[test]
    fn test_arrow_value_to_json_types() {
        use arrow_array::*;

        let cases = [
            (
                Arc::new(Int64Array::from(vec![42])) as ArrayRef,
                serde_json::json!(42),
            ),
            (
                Arc::new(StringArray::from(vec!["hello"])) as ArrayRef,
                serde_json::json!("hello"),
            ),
            (
                Arc::new(BooleanArray::from(vec![true])) as ArrayRef,
                serde_json::json!(true),
            ),
            (
                Arc::new(Int64Array::from(vec![None::<i64>])) as ArrayRef,
                serde_json::Value::Null,
            ),
        ];

        for (arr, expected) in &cases {
            assert_eq!(&arrow_value_to_json(&**arr, 0), expected);
        }
    }
}