Skip to main content

laminar_connectors/changelog/
collapse.rs

1//! Collapse a changelog epoch batch into a cardinality-safe, key-unique upsert
2//! batch. See the module docs for why this is necessary.
3
4use std::sync::Arc;
5
6use arrow_array::{Array, ArrayRef, Int64Array, RecordBatch, StringArray, UInt32Array};
7use arrow_row::{RowConverter, Rows, SortField};
8use arrow_schema::{DataType, Field, Schema};
9
10use laminar_core::changelog::WEIGHT_COLUMN;
11
12use crate::error::ConnectorError;
13
14/// Collapse a concatenated changelog epoch `batch` into a key-unique batch
15/// carrying a `_op` column of `U` (upsert) or `D` (delete), one row per
16/// `merge_key`.
17///
18/// Two input encodings are detected automatically:
19///
20/// - **Z-set** (the `__weight` column is present): identical full rows are
21///   consolidated by summing their weights, net-zero rows are dropped, then the
22///   survivors are grouped by `merge_key`. A key with a net-positive (live) row
23///   becomes a `U` carrying that value; a key with only net-negative rows
24///   becomes a `D`. The `__weight` column is stripped from the output.
25/// - **CDC** (no `__weight`): the last-arriving row per `merge_key` wins (row
26///   order is arrival order). Its op is normalized to `D` for deletes
27///   (`_op ∈ {D, U-}`) and `U` for everything else. A batch with neither column
28///   is treated as all-upsert.
29///
30/// The output reuses the existing key-by-key MERGE (`_op ∈ {U, D}`) unchanged,
31/// and contains at most one row per merge key, so the writer never sees a
32/// cardinality violation.
33///
34/// # Errors
35///
36/// - [`ConnectorError::ConfigurationError`] if `merge_key` is empty, names a
37///   column absent from the batch, or is not unique over the collapsed output
38///   (more than one live row for a single key — a misdeclared merge key).
39/// - [`ConnectorError::Internal`] if an Arrow row-conversion or take fails, or
40///   the `__weight` column is not Int64.
41pub fn collapse_changelog(
42    batch: &RecordBatch,
43    merge_key: &[String],
44) -> Result<RecordBatch, ConnectorError> {
45    if merge_key.is_empty() {
46        return Err(ConnectorError::ConfigurationError(
47            "changelog collapse requires at least one merge key column".into(),
48        ));
49    }
50    let schema = batch.schema();
51    for k in merge_key {
52        if is_metadata_column(k) {
53            return Err(ConnectorError::ConfigurationError(format!(
54                "merge key column '{k}' is reserved changelog metadata and cannot be a merge key"
55            )));
56        }
57        if schema.index_of(k).is_err() {
58            return Err(ConnectorError::ConfigurationError(format!(
59                "merge key column '{k}' is not present in the changelog output schema"
60            )));
61        }
62    }
63
64    if let Ok(weight_idx) = schema.index_of(WEIGHT_COLUMN) {
65        collapse_zset(batch, merge_key, weight_idx)
66    } else {
67        collapse_cdc(batch, merge_key)
68    }
69}
70
71/// Build comparable [`Rows`] over the given column indices of `batch`.
72fn rows_over(batch: &RecordBatch, indices: &[usize]) -> Result<Rows, ConnectorError> {
73    let schema = batch.schema();
74    let fields: Vec<SortField> = indices
75        .iter()
76        .map(|&i| SortField::new(schema.field(i).data_type().clone()))
77        .collect();
78    let arrays: Vec<ArrayRef> = indices.iter().map(|&i| batch.column(i).clone()).collect();
79    let converter = RowConverter::new(fields)
80        .map_err(|e| ConnectorError::Internal(format!("row converter: {e}")))?;
81    converter
82        .convert_columns(&arrays)
83        .map_err(|e| ConnectorError::Internal(format!("convert columns to rows: {e}")))
84}
85
86/// Column indices for the named `columns` (which the caller has validated exist).
87fn index_of_all(batch: &RecordBatch, columns: &[String]) -> Vec<usize> {
88    let schema = batch.schema();
89    columns
90        .iter()
91        .map(|name| schema.index_of(name).expect("merge key columns validated"))
92        .collect()
93}
94
95/// Changelog metadata columns, excluded from the collapsed output's user
96/// columns (`_op` is re-emitted normalized; `__weight`/`_ts_ms` are dropped).
97fn is_metadata_column(name: &str) -> bool {
98    name == "_op" || name == "_ts_ms" || name == WEIGHT_COLUMN
99}
100
101/// Z-set collapse: consolidate by full row, then pick one row per merge key.
102fn collapse_zset(
103    batch: &RecordBatch,
104    merge_key: &[String],
105    weight_idx: usize,
106) -> Result<RecordBatch, ConnectorError> {
107    let num_rows = batch.num_rows();
108    let weights = batch
109        .column(weight_idx)
110        .as_any()
111        .downcast_ref::<Int64Array>()
112        .ok_or_else(|| ConnectorError::Internal(format!("{WEIGHT_COLUMN} column is not Int64")))?;
113
114    // User columns = every column except changelog metadata.
115    let schema = batch.schema();
116    let user_indices: Vec<usize> = (0..batch.num_columns())
117        .filter(|&i| !is_metadata_column(schema.field(i).name()))
118        .collect();
119
120    // 1. Consolidate identical full rows, summing weights; keep net-nonzero.
121    //    `survivors` holds (representative row index, net weight).
122    let full_rows = rows_over(batch, &user_indices)?;
123    let mut order: Vec<usize> = (0..num_rows).collect();
124    order.sort_unstable_by(|&a, &b| full_rows.row(a).cmp(&full_rows.row(b)));
125
126    let mut survivors: Vec<(usize, i64)> = Vec::new();
127    let mut i = 0;
128    while i < order.len() {
129        let rep = order[i];
130        let rep_row = full_rows.row(rep);
131        let mut sum = 0i64;
132        let mut j = i;
133        while j < order.len() && full_rows.row(order[j]) == rep_row {
134            sum += weights.value(order[j]);
135            j += 1;
136        }
137        if sum != 0 {
138            survivors.push((rep, sum));
139        }
140        i = j;
141    }
142
143    if survivors.is_empty() {
144        return build_output(batch, &user_indices, &[], &[]);
145    }
146
147    // 2. Group survivors by merge key; one output row per key.
148    let key_rows = rows_over(batch, &index_of_all(batch, merge_key))?;
149    survivors
150        .sort_unstable_by(|&(a, _), &(b, _)| key_rows.row(a).cmp(&key_rows.row(b)).then(a.cmp(&b)));
151
152    let mut selected: Vec<usize> = Vec::new();
153    let mut ops: Vec<&str> = Vec::new();
154    let mut g = 0;
155    while g < survivors.len() {
156        let key = key_rows.row(survivors[g].0);
157        let mut live: Option<usize> = None;
158        let mut live_count = 0usize;
159        let mut first_negative: Option<usize> = None;
160        let mut h = g;
161        while h < survivors.len() && key_rows.row(survivors[h].0) == key {
162            let (idx, weight) = survivors[h];
163            if weight > 0 {
164                live_count += 1;
165                if live.is_none() {
166                    live = Some(idx);
167                }
168            } else if first_negative.is_none() {
169                first_negative = Some(idx);
170            }
171            h += 1;
172        }
173        if live_count > 1 {
174            return Err(ConnectorError::ConfigurationError(format!(
175                "changelog collapse: merge.key.columns {merge_key:?} is not unique — {live_count} \
176                 distinct live rows share one key in a single epoch; declare a merge key that is \
177                 unique over the materialized-view output"
178            )));
179        }
180        if let Some(idx) = live {
181            selected.push(idx);
182            ops.push("U");
183        } else if let Some(idx) = first_negative {
184            selected.push(idx);
185            ops.push("D");
186        }
187        g = h;
188    }
189
190    build_output(batch, &user_indices, &selected, &ops)
191}
192
193/// CDC collapse: last-arriving row per merge key wins; op normalized to U/D.
194fn collapse_cdc(batch: &RecordBatch, merge_key: &[String]) -> Result<RecordBatch, ConnectorError> {
195    let num_rows = batch.num_rows();
196    let schema = batch.schema();
197    let op_values = match schema.index_of("_op") {
198        Ok(idx) => Some(
199            batch
200                .column(idx)
201                .as_any()
202                .downcast_ref::<StringArray>()
203                .ok_or_else(|| ConnectorError::Internal("_op column is not Utf8".into()))?,
204        ),
205        Err(_) => None,
206    };
207
208    // Keep the highest (last-arriving) row index per merge key.
209    let key_rows = rows_over(batch, &index_of_all(batch, merge_key))?;
210    let mut order: Vec<usize> = (0..num_rows).collect();
211    order.sort_unstable_by(|&a, &b| key_rows.row(a).cmp(&key_rows.row(b)).then(a.cmp(&b)));
212
213    let mut selected: Vec<usize> = Vec::new();
214    let mut i = 0;
215    while i < order.len() {
216        let key = key_rows.row(order[i]);
217        let mut last = order[i];
218        let mut j = i;
219        while j < order.len() && key_rows.row(order[j]) == key {
220            last = order[j]; // order is index-ascending within a key group
221            j += 1;
222        }
223        selected.push(last);
224        i = j;
225    }
226    // Deterministic output order (correctness is unaffected — keys are unique).
227    selected.sort_unstable();
228
229    // Normalize to {U, D}: a delete iff the surviving op is D or U- (a before
230    // image); otherwise the row is the current image for its key.
231    let ops: Vec<&str> = selected
232        .iter()
233        .map(|&idx| match op_values {
234            Some(values) if !values.is_null(idx) => {
235                if matches!(values.value(idx), "D" | "U-") {
236                    "D"
237                } else {
238                    "U"
239                }
240            }
241            _ => "U",
242        })
243        .collect();
244
245    // User columns = every column except changelog metadata.
246    let user_indices: Vec<usize> = (0..batch.num_columns())
247        .filter(|&i| !is_metadata_column(schema.field(i).name()))
248        .collect();
249
250    build_output(batch, &user_indices, &selected, &ops)
251}
252
253/// Take `user_indices` columns at the `selected` rows of `batch` and append a
254/// fresh `_op` column built from `ops`. `selected` and `ops` must be parallel.
255fn build_output(
256    batch: &RecordBatch,
257    user_indices: &[usize],
258    selected: &[usize],
259    ops: &[&str],
260) -> Result<RecordBatch, ConnectorError> {
261    debug_assert_eq!(selected.len(), ops.len());
262    let schema = batch.schema();
263    // Row counts are bounded by the epoch buffer cap, well under u32::MAX.
264    #[allow(clippy::cast_possible_truncation)]
265    let take_idx = UInt32Array::from(selected.iter().map(|&i| i as u32).collect::<Vec<_>>());
266
267    let mut fields: Vec<Field> = Vec::with_capacity(user_indices.len() + 1);
268    let mut columns: Vec<ArrayRef> = Vec::with_capacity(user_indices.len() + 1);
269    for &idx in user_indices {
270        let taken = arrow_select::take::take(batch.column(idx), &take_idx, None)
271            .map_err(|e| ConnectorError::Internal(format!("take column: {e}")))?;
272        fields.push(schema.field(idx).as_ref().clone());
273        columns.push(taken);
274    }
275    fields.push(Field::new("_op", DataType::Utf8, false));
276    columns.push(Arc::new(StringArray::from(ops.to_vec())));
277
278    RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
279        .map_err(|e| ConnectorError::Internal(format!("build collapsed batch: {e}")))
280}
281
#[cfg(test)]
#[allow(clippy::too_many_lines)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};

    /// Turn a list of column names into the owned merge-key form.
    fn keys(cols: &[&str]) -> Vec<String> {
        cols.iter().map(|s| (*s).to_string()).collect()
    }

    /// Build a Z-set changelog batch: schema [region: Utf8, total: Int64, __weight: Int64].
    fn zset_batch(rows: &[(&str, i64, i64)]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("total", DataType::Int64, false),
            Field::new(WEIGHT_COLUMN, DataType::Int64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    rows.iter().map(|r| r.0).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(
                    rows.iter().map(|r| r.1).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(
                    rows.iter().map(|r| r.2).collect::<Vec<_>>(),
                )),
            ],
        )
        .unwrap()
    }

    /// Build a CDC changelog batch: schema [id: Int64, value: Float64, _op: Utf8].
    fn cdc_batch(rows: &[(i64, f64, &str)]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, false),
            Field::new("_op", DataType::Utf8, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(
                    rows.iter().map(|r| r.0).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(
                    rows.iter().map(|r| r.1).collect::<Vec<_>>(),
                )),
                Arc::new(StringArray::from(
                    rows.iter().map(|r| r.2).collect::<Vec<_>>(),
                )),
            ],
        )
        .unwrap()
    }

    /// Read the Utf8 column `name` of `batch` as owned strings.
    fn col_str(batch: &RecordBatch, name: &str) -> Vec<String> {
        let idx = batch.schema().index_of(name).unwrap();
        let arr = batch
            .column(idx)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
    }

    /// Read the Int64 column `name` of `batch` as a vector.
    fn col_i64(batch: &RecordBatch, name: &str) -> Vec<i64> {
        let idx = batch.schema().index_of(name).unwrap();
        let arr = batch
            .column(idx)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        (0..arr.len()).map(|i| arr.value(i)).collect()
    }

    /// Zip the three parallel columns into `(region, value, op)` tuples and
    /// sort them lexicographically so assertions are order-stable.
    fn sorted_pairs(
        regions: &[String],
        values: &[i64],
        ops: &[String],
    ) -> Vec<(String, i64, String)> {
        let mut v: Vec<_> = regions
            .iter()
            .zip(values)
            .zip(ops)
            .map(|((r, t), o)| (r.clone(), *t, o.clone()))
            .collect();
        v.sort();
        v
    }

    #[test]
    fn zset_multi_update_per_key_keeps_final_value() {
        // Two emit cycles concatenated: 10→20→35. The intermediate 20 cancels.
        let out = collapse_changelog(
            &zset_batch(&[
                ("east", 10, -1),
                ("east", 20, 1),
                ("east", 20, -1),
                ("east", 35, 1),
            ]),
            &keys(&["region"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 1);
        assert_eq!(col_str(&out, "_op"), vec!["U"]);
        assert_eq!(col_i64(&out, "total"), vec![35]);
    }

    #[test]
    fn zset_insert_then_delete_within_epoch_is_noop() {
        // +1 then -1 on the same full row nets zero → dropped entirely.
        let out = collapse_changelog(
            &zset_batch(&[("east", 10, 1), ("east", 10, -1)]),
            &keys(&["region"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 0);
        assert!(out.schema().index_of(WEIGHT_COLUMN).is_err());
        assert!(out.schema().index_of("_op").is_ok());
    }

    #[test]
    fn zset_multiple_keys_mixed_ops() {
        // east updated, west dropped, north newly inserted — in one epoch.
        let out = collapse_changelog(
            &zset_batch(&[
                ("east", 10, -1),
                ("east", 30, 1),
                ("west", 5, -1),
                ("north", 99, 1),
            ]),
            &keys(&["region"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 3);
        let got = sorted_pairs(
            &col_str(&out, "region"),
            &col_i64(&out, "total"),
            &col_str(&out, "_op"),
        );
        assert_eq!(
            got,
            vec![
                ("east".into(), 30, "U".into()),
                ("north".into(), 99, "U".into()),
                ("west".into(), 5, "D".into()),
            ]
        );
    }

    #[test]
    fn zset_higher_multiplicity_is_single_live_row() {
        // Cascaded aggregation can emit weight > 1; still one live row per key.
        let out = collapse_changelog(&zset_batch(&[("east", 10, 3)]), &keys(&["region"])).unwrap();
        assert_eq!(out.num_rows(), 1);
        assert_eq!(col_str(&out, "_op"), vec!["U"]);
    }

    #[test]
    fn zset_non_unique_merge_key_errors() {
        // Two distinct live rows for the same key → misdeclared merge key.
        let err = collapse_changelog(
            &zset_batch(&[("east", 10, 1), ("east", 20, 1)]),
            &keys(&["region"]),
        )
        .unwrap_err();
        assert!(
            matches!(err, ConnectorError::ConfigurationError(_)),
            "expected ConfigurationError, got {err:?}"
        );
        assert!(format!("{err}").contains("not unique"));
    }

    #[test]
    fn zset_composite_merge_key() {
        // Merge key over both columns: distinct (region,total) live rows are
        // distinct keys, not a uniqueness violation.
        let out = collapse_changelog(
            &zset_batch(&[("east", 10, 1), ("east", 20, 1)]),
            &keys(&["region", "total"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 2);
        assert_eq!(col_str(&out, "_op"), vec!["U", "U"]);
    }

    #[test]
    fn cdc_dedup_keeps_last_arrival() {
        // id=1 inserted then updated within an epoch → one U with the last value.
        let out = collapse_changelog(
            &cdc_batch(&[(1, 10.0, "I"), (1, 15.0, "U"), (2, 20.0, "I")]),
            &keys(&["id"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 2);
        assert_eq!(col_i64(&out, "id"), vec![1, 2]);
        assert_eq!(col_str(&out, "_op"), vec!["U", "U"]);
        let vidx = out.schema().index_of("value").unwrap();
        let vals = out
            .column(vidx)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (vals.value(0) - 15.0).abs() < f64::EPSILON,
            "last value wins"
        );
    }

    #[test]
    fn cdc_delete_is_preserved() {
        let out = collapse_changelog(
            &cdc_batch(&[(1, 10.0, "I"), (1, 10.0, "D")]),
            &keys(&["id"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 1);
        assert_eq!(col_str(&out, "_op"), vec!["D"]);
    }

    #[test]
    fn cdc_update_before_normalizes_to_delete_and_after_to_upsert() {
        // U- alone → D; the U+ after-image → U.
        let out_before =
            collapse_changelog(&cdc_batch(&[(1, 10.0, "U-")]), &keys(&["id"])).unwrap();
        assert_eq!(col_str(&out_before, "_op"), vec!["D"]);
        let out_after = collapse_changelog(&cdc_batch(&[(1, 10.0, "U+")]), &keys(&["id"])).unwrap();
        assert_eq!(col_str(&out_after, "_op"), vec!["U"]);
    }

    #[test]
    fn cdc_no_op_column_treated_as_upsert() {
        // Plain MV (no _op, no __weight) → all upserts, deduped by key.
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 1, 2])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();
        let out = collapse_changelog(&batch, &keys(&["id"])).unwrap();
        assert_eq!(out.num_rows(), 2);
        assert!(out.schema().index_of("_op").is_ok());
        assert_eq!(col_str(&out, "_op"), vec!["U", "U"]);
    }

    #[test]
    fn cdc_strips_ts_ms_and_emits_single_op() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("_ts_ms", DataType::Int64, false),
            Field::new("_op", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(Int64Array::from(vec![100, 200])),
                Arc::new(StringArray::from(vec!["I", "U"])),
            ],
        )
        .unwrap();
        let out = collapse_changelog(&batch, &keys(&["id"])).unwrap();
        assert!(out.schema().index_of("_ts_ms").is_err(), "_ts_ms stripped");
        assert_eq!(
            out.schema()
                .fields()
                .iter()
                .filter(|f| f.name() == "_op")
                .count(),
            1,
            "exactly one _op column"
        );
    }

    #[test]
    fn empty_batch_yields_empty_with_op_column() {
        let out = collapse_changelog(&zset_batch(&[]), &keys(&["region"])).unwrap();
        assert_eq!(out.num_rows(), 0);
        assert!(out.schema().index_of(WEIGHT_COLUMN).is_err());
        assert!(out.schema().index_of("_op").is_ok());
    }

    #[test]
    fn empty_merge_key_errors() {
        let err = collapse_changelog(&zset_batch(&[("east", 10, 1)]), &[]).unwrap_err();
        assert!(matches!(err, ConnectorError::ConfigurationError(_)));
    }

    #[test]
    fn missing_merge_key_column_errors() {
        let err =
            collapse_changelog(&zset_batch(&[("east", 10, 1)]), &keys(&["nope"])).unwrap_err();
        assert!(matches!(err, ConnectorError::ConfigurationError(_)));
        assert!(format!("{err}").contains("not present"));
    }

    #[test]
    fn reserved_metadata_merge_key_errors() {
        // Metadata columns are rejected as merge keys even when the batch
        // actually contains them.
        let err = collapse_changelog(&zset_batch(&[("east", 10, 1)]), &keys(&[WEIGHT_COLUMN]))
            .unwrap_err();
        assert!(
            matches!(err, ConnectorError::ConfigurationError(_)),
            "expected ConfigurationError, got {err:?}"
        );
        assert!(format!("{err}").contains("reserved"));

        let err = collapse_changelog(&cdc_batch(&[(1, 10.0, "I")]), &keys(&["_op"])).unwrap_err();
        assert!(
            matches!(err, ConnectorError::ConfigurationError(_)),
            "expected ConfigurationError, got {err:?}"
        );
        assert!(format!("{err}").contains("reserved"));
    }
}
596}