1use std::sync::Arc;
5
6use arrow_array::{Array, ArrayRef, Int64Array, RecordBatch, StringArray, UInt32Array};
7use arrow_row::{RowConverter, Rows, SortField};
8use arrow_schema::{DataType, Field, Schema};
9
10use laminar_core::changelog::WEIGHT_COLUMN;
11
12use crate::error::ConnectorError;
13
14pub fn collapse_changelog(
42 batch: &RecordBatch,
43 merge_key: &[String],
44) -> Result<RecordBatch, ConnectorError> {
45 if merge_key.is_empty() {
46 return Err(ConnectorError::ConfigurationError(
47 "changelog collapse requires at least one merge key column".into(),
48 ));
49 }
50 let schema = batch.schema();
51 for k in merge_key {
52 if is_metadata_column(k) {
53 return Err(ConnectorError::ConfigurationError(format!(
54 "merge key column '{k}' is reserved changelog metadata and cannot be a merge key"
55 )));
56 }
57 if schema.index_of(k).is_err() {
58 return Err(ConnectorError::ConfigurationError(format!(
59 "merge key column '{k}' is not present in the changelog output schema"
60 )));
61 }
62 }
63
64 if let Ok(weight_idx) = schema.index_of(WEIGHT_COLUMN) {
65 collapse_zset(batch, merge_key, weight_idx)
66 } else {
67 collapse_cdc(batch, merge_key)
68 }
69}
70
71fn rows_over(batch: &RecordBatch, indices: &[usize]) -> Result<Rows, ConnectorError> {
73 let schema = batch.schema();
74 let fields: Vec<SortField> = indices
75 .iter()
76 .map(|&i| SortField::new(schema.field(i).data_type().clone()))
77 .collect();
78 let arrays: Vec<ArrayRef> = indices.iter().map(|&i| batch.column(i).clone()).collect();
79 let converter = RowConverter::new(fields)
80 .map_err(|e| ConnectorError::Internal(format!("row converter: {e}")))?;
81 converter
82 .convert_columns(&arrays)
83 .map_err(|e| ConnectorError::Internal(format!("convert columns to rows: {e}")))
84}
85
86fn index_of_all(batch: &RecordBatch, columns: &[String]) -> Vec<usize> {
88 let schema = batch.schema();
89 columns
90 .iter()
91 .map(|name| schema.index_of(name).expect("merge key columns validated"))
92 .collect()
93}
94
95fn is_metadata_column(name: &str) -> bool {
98 name == "_op" || name == "_ts_ms" || name == WEIGHT_COLUMN
99}
100
101fn collapse_zset(
103 batch: &RecordBatch,
104 merge_key: &[String],
105 weight_idx: usize,
106) -> Result<RecordBatch, ConnectorError> {
107 let num_rows = batch.num_rows();
108 let weights = batch
109 .column(weight_idx)
110 .as_any()
111 .downcast_ref::<Int64Array>()
112 .ok_or_else(|| ConnectorError::Internal(format!("{WEIGHT_COLUMN} column is not Int64")))?;
113
114 let schema = batch.schema();
116 let user_indices: Vec<usize> = (0..batch.num_columns())
117 .filter(|&i| !is_metadata_column(schema.field(i).name()))
118 .collect();
119
120 let full_rows = rows_over(batch, &user_indices)?;
123 let mut order: Vec<usize> = (0..num_rows).collect();
124 order.sort_unstable_by(|&a, &b| full_rows.row(a).cmp(&full_rows.row(b)));
125
126 let mut survivors: Vec<(usize, i64)> = Vec::new();
127 let mut i = 0;
128 while i < order.len() {
129 let rep = order[i];
130 let rep_row = full_rows.row(rep);
131 let mut sum = 0i64;
132 let mut j = i;
133 while j < order.len() && full_rows.row(order[j]) == rep_row {
134 sum += weights.value(order[j]);
135 j += 1;
136 }
137 if sum != 0 {
138 survivors.push((rep, sum));
139 }
140 i = j;
141 }
142
143 if survivors.is_empty() {
144 return build_output(batch, &user_indices, &[], &[]);
145 }
146
147 let key_rows = rows_over(batch, &index_of_all(batch, merge_key))?;
149 survivors
150 .sort_unstable_by(|&(a, _), &(b, _)| key_rows.row(a).cmp(&key_rows.row(b)).then(a.cmp(&b)));
151
152 let mut selected: Vec<usize> = Vec::new();
153 let mut ops: Vec<&str> = Vec::new();
154 let mut g = 0;
155 while g < survivors.len() {
156 let key = key_rows.row(survivors[g].0);
157 let mut live: Option<usize> = None;
158 let mut live_count = 0usize;
159 let mut first_negative: Option<usize> = None;
160 let mut h = g;
161 while h < survivors.len() && key_rows.row(survivors[h].0) == key {
162 let (idx, weight) = survivors[h];
163 if weight > 0 {
164 live_count += 1;
165 if live.is_none() {
166 live = Some(idx);
167 }
168 } else if first_negative.is_none() {
169 first_negative = Some(idx);
170 }
171 h += 1;
172 }
173 if live_count > 1 {
174 return Err(ConnectorError::ConfigurationError(format!(
175 "changelog collapse: merge.key.columns {merge_key:?} is not unique — {live_count} \
176 distinct live rows share one key in a single epoch; declare a merge key that is \
177 unique over the materialized-view output"
178 )));
179 }
180 if let Some(idx) = live {
181 selected.push(idx);
182 ops.push("U");
183 } else if let Some(idx) = first_negative {
184 selected.push(idx);
185 ops.push("D");
186 }
187 g = h;
188 }
189
190 build_output(batch, &user_indices, &selected, &ops)
191}
192
193fn collapse_cdc(batch: &RecordBatch, merge_key: &[String]) -> Result<RecordBatch, ConnectorError> {
195 let num_rows = batch.num_rows();
196 let schema = batch.schema();
197 let op_values = match schema.index_of("_op") {
198 Ok(idx) => Some(
199 batch
200 .column(idx)
201 .as_any()
202 .downcast_ref::<StringArray>()
203 .ok_or_else(|| ConnectorError::Internal("_op column is not Utf8".into()))?,
204 ),
205 Err(_) => None,
206 };
207
208 let key_rows = rows_over(batch, &index_of_all(batch, merge_key))?;
210 let mut order: Vec<usize> = (0..num_rows).collect();
211 order.sort_unstable_by(|&a, &b| key_rows.row(a).cmp(&key_rows.row(b)).then(a.cmp(&b)));
212
213 let mut selected: Vec<usize> = Vec::new();
214 let mut i = 0;
215 while i < order.len() {
216 let key = key_rows.row(order[i]);
217 let mut last = order[i];
218 let mut j = i;
219 while j < order.len() && key_rows.row(order[j]) == key {
220 last = order[j]; j += 1;
222 }
223 selected.push(last);
224 i = j;
225 }
226 selected.sort_unstable();
228
229 let ops: Vec<&str> = selected
232 .iter()
233 .map(|&idx| match op_values {
234 Some(values) if !values.is_null(idx) => {
235 if matches!(values.value(idx), "D" | "U-") {
236 "D"
237 } else {
238 "U"
239 }
240 }
241 _ => "U",
242 })
243 .collect();
244
245 let user_indices: Vec<usize> = (0..batch.num_columns())
247 .filter(|&i| !is_metadata_column(schema.field(i).name()))
248 .collect();
249
250 build_output(batch, &user_indices, &selected, &ops)
251}
252
253fn build_output(
256 batch: &RecordBatch,
257 user_indices: &[usize],
258 selected: &[usize],
259 ops: &[&str],
260) -> Result<RecordBatch, ConnectorError> {
261 debug_assert_eq!(selected.len(), ops.len());
262 let schema = batch.schema();
263 #[allow(clippy::cast_possible_truncation)]
265 let take_idx = UInt32Array::from(selected.iter().map(|&i| i as u32).collect::<Vec<_>>());
266
267 let mut fields: Vec<Field> = Vec::with_capacity(user_indices.len() + 1);
268 let mut columns: Vec<ArrayRef> = Vec::with_capacity(user_indices.len() + 1);
269 for &idx in user_indices {
270 let taken = arrow_select::take::take(batch.column(idx), &take_idx, None)
271 .map_err(|e| ConnectorError::Internal(format!("take column: {e}")))?;
272 fields.push(schema.field(idx).as_ref().clone());
273 columns.push(taken);
274 }
275 fields.push(Field::new("_op", DataType::Utf8, false));
276 columns.push(Arc::new(StringArray::from(ops.to_vec())));
277
278 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
279 .map_err(|e| ConnectorError::Internal(format!("build collapsed batch: {e}")))
280}
281
#[cfg(test)]
#[allow(clippy::too_many_lines)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};

    /// Builds an owned merge-key list from string literals.
    fn keys(cols: &[&str]) -> Vec<String> {
        cols.iter().map(|s| (*s).to_string()).collect()
    }

    /// Builds a Z-set batch with columns `region`, `total` and the weight column.
    fn zset_batch(rows: &[(&str, i64, i64)]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("region", DataType::Utf8, false),
            Field::new("total", DataType::Int64, false),
            Field::new(WEIGHT_COLUMN, DataType::Int64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    rows.iter().map(|r| r.0).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(
                    rows.iter().map(|r| r.1).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(
                    rows.iter().map(|r| r.2).collect::<Vec<_>>(),
                )),
            ],
        )
        .unwrap()
    }

    /// Builds a CDC batch with columns `id`, `value` and `_op`.
    fn cdc_batch(rows: &[(i64, f64, &str)]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, false),
            Field::new("_op", DataType::Utf8, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(
                    rows.iter().map(|r| r.0).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(
                    rows.iter().map(|r| r.1).collect::<Vec<_>>(),
                )),
                Arc::new(StringArray::from(
                    rows.iter().map(|r| r.2).collect::<Vec<_>>(),
                )),
            ],
        )
        .unwrap()
    }

    fn col_str(batch: &RecordBatch, name: &str) -> Vec<String> {
        let idx = batch.schema().index_of(name).unwrap();
        let arr = batch
            .column(idx)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
    }

    fn col_i64(batch: &RecordBatch, name: &str) -> Vec<i64> {
        let idx = batch.schema().index_of(name).unwrap();
        let arr = batch
            .column(idx)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        (0..arr.len()).map(|i| arr.value(i)).collect()
    }

    fn col_f64(batch: &RecordBatch, name: &str) -> Vec<f64> {
        let idx = batch.schema().index_of(name).unwrap();
        let arr = batch
            .column(idx)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        (0..arr.len()).map(|i| arr.value(i)).collect()
    }

    /// Zips the region/total/op columns and sorts them so assertions are order-independent.
    fn sorted_pairs(
        regions: &[String],
        values: &[i64],
        ops: &[String],
    ) -> Vec<(String, i64, String)> {
        let mut v: Vec<_> = regions
            .iter()
            .zip(values)
            .zip(ops)
            .map(|((r, t), o)| (r.clone(), *t, o.clone()))
            .collect();
        v.sort();
        v
    }

    #[test]
    fn zset_multi_update_per_key_keeps_final_value() {
        let out = collapse_changelog(
            &zset_batch(&[
                ("east", 10, -1),
                ("east", 20, 1),
                ("east", 20, -1),
                ("east", 35, 1),
            ]),
            &keys(&["region"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 1);
        assert_eq!(col_str(&out, "_op"), vec!["U"]);
        assert_eq!(col_i64(&out, "total"), vec![35]);
    }

    #[test]
    fn zset_insert_then_delete_within_epoch_is_noop() {
        let out = collapse_changelog(
            &zset_batch(&[("east", 10, 1), ("east", 10, -1)]),
            &keys(&["region"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 0);
        assert!(out.schema().index_of(WEIGHT_COLUMN).is_err());
        assert!(out.schema().index_of("_op").is_ok());
    }

    #[test]
    fn zset_multiple_keys_mixed_ops() {
        let out = collapse_changelog(
            &zset_batch(&[
                ("east", 10, -1),
                ("east", 30, 1),
                ("west", 5, -1),
                ("north", 99, 1),
            ]),
            &keys(&["region"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 3);
        let got = sorted_pairs(
            &col_str(&out, "region"),
            &col_i64(&out, "total"),
            &col_str(&out, "_op"),
        );
        assert_eq!(
            got,
            vec![
                ("east".into(), 30, "U".into()),
                ("north".into(), 99, "U".into()),
                ("west".into(), 5, "D".into()),
            ]
        );
    }

    #[test]
    fn zset_higher_multiplicity_is_single_live_row() {
        let out = collapse_changelog(&zset_batch(&[("east", 10, 3)]), &keys(&["region"])).unwrap();
        assert_eq!(out.num_rows(), 1);
        assert_eq!(col_str(&out, "_op"), vec!["U"]);
    }

    #[test]
    fn zset_identical_rows_are_consolidated() {
        // Two copies of the same row net to weight 2: one live row, not a uniqueness error.
        let out = collapse_changelog(
            &zset_batch(&[("east", 10, 1), ("east", 10, 1)]),
            &keys(&["region"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 1);
        assert_eq!(col_i64(&out, "total"), vec![10]);
        assert_eq!(col_str(&out, "_op"), vec!["U"]);
    }

    #[test]
    fn zset_non_unique_merge_key_errors() {
        let err = collapse_changelog(
            &zset_batch(&[("east", 10, 1), ("east", 20, 1)]),
            &keys(&["region"]),
        )
        .unwrap_err();
        assert!(
            matches!(err, ConnectorError::ConfigurationError(_)),
            "expected ConfigurationError, got {err:?}"
        );
        assert!(format!("{err}").contains("not unique"));
    }

    #[test]
    fn zset_composite_merge_key() {
        let out = collapse_changelog(
            &zset_batch(&[("east", 10, 1), ("east", 20, 1)]),
            &keys(&["region", "total"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 2);
        assert_eq!(col_str(&out, "_op"), vec!["U", "U"]);
    }

    #[test]
    fn cdc_dedup_keeps_last_arrival() {
        let out = collapse_changelog(
            &cdc_batch(&[(1, 10.0, "I"), (1, 15.0, "U"), (2, 20.0, "I")]),
            &keys(&["id"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 2);
        assert_eq!(col_i64(&out, "id"), vec![1, 2]);
        assert_eq!(col_str(&out, "_op"), vec!["U", "U"]);
        let vals = col_f64(&out, "value");
        assert!((vals[0] - 15.0).abs() < f64::EPSILON, "last value wins");
    }

    #[test]
    fn cdc_output_follows_arrival_of_winning_event() {
        // id 1's winning event (position 2) arrives after id 2's (position 1).
        let out = collapse_changelog(
            &cdc_batch(&[(1, 1.0, "I"), (2, 2.0, "I"), (1, 3.0, "U")]),
            &keys(&["id"]),
        )
        .unwrap();
        assert_eq!(col_i64(&out, "id"), vec![2, 1]);
        let vals = col_f64(&out, "value");
        assert!((vals[0] - 2.0).abs() < f64::EPSILON);
        assert!((vals[1] - 3.0).abs() < f64::EPSILON);
        assert_eq!(col_str(&out, "_op"), vec!["U", "U"]);
    }

    #[test]
    fn cdc_delete_is_preserved() {
        let out = collapse_changelog(
            &cdc_batch(&[(1, 10.0, "I"), (1, 10.0, "D")]),
            &keys(&["id"]),
        )
        .unwrap();
        assert_eq!(out.num_rows(), 1);
        assert_eq!(col_str(&out, "_op"), vec!["D"]);
    }

    #[test]
    fn cdc_update_before_normalizes_to_delete_and_after_to_upsert() {
        let out_before =
            collapse_changelog(&cdc_batch(&[(1, 10.0, "U-")]), &keys(&["id"])).unwrap();
        assert_eq!(col_str(&out_before, "_op"), vec!["D"]);
        let out_after = collapse_changelog(&cdc_batch(&[(1, 10.0, "U+")]), &keys(&["id"])).unwrap();
        assert_eq!(col_str(&out_after, "_op"), vec!["U"]);
    }

    #[test]
    fn cdc_no_op_column_treated_as_upsert() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 1, 2])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();
        let out = collapse_changelog(&batch, &keys(&["id"])).unwrap();
        assert_eq!(out.num_rows(), 2);
        assert!(out.schema().index_of("_op").is_ok());
        assert_eq!(col_str(&out, "_op"), vec!["U", "U"]);
    }

    #[test]
    fn cdc_strips_ts_ms_and_emits_single_op() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("_ts_ms", DataType::Int64, false),
            Field::new("_op", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(Int64Array::from(vec![100, 200])),
                Arc::new(StringArray::from(vec!["I", "U"])),
            ],
        )
        .unwrap();
        let out = collapse_changelog(&batch, &keys(&["id"])).unwrap();
        assert!(out.schema().index_of("_ts_ms").is_err(), "_ts_ms stripped");
        assert_eq!(
            out.schema()
                .fields()
                .iter()
                .filter(|f| f.name() == "_op")
                .count(),
            1,
            "exactly one _op column"
        );
    }

    #[test]
    fn empty_batch_yields_empty_with_op_column() {
        let out = collapse_changelog(&zset_batch(&[]), &keys(&["region"])).unwrap();
        assert_eq!(out.num_rows(), 0);
        assert!(out.schema().index_of(WEIGHT_COLUMN).is_err());
        assert!(out.schema().index_of("_op").is_ok());
    }

    #[test]
    fn empty_merge_key_errors() {
        let err = collapse_changelog(&zset_batch(&[("east", 10, 1)]), &[]).unwrap_err();
        assert!(matches!(err, ConnectorError::ConfigurationError(_)));
    }

    #[test]
    fn missing_merge_key_column_errors() {
        let err =
            collapse_changelog(&zset_batch(&[("east", 10, 1)]), &keys(&["nope"])).unwrap_err();
        assert!(matches!(err, ConnectorError::ConfigurationError(_)));
        assert!(format!("{err}").contains("not present"));
    }

    #[test]
    fn reserved_metadata_merge_key_errors() {
        // The reserved-name check runs before the presence check, so even names absent
        // from the batch are reported as reserved.
        for &reserved in &["_op", "_ts_ms", WEIGHT_COLUMN] {
            let err = collapse_changelog(&zset_batch(&[("east", 10, 1)]), &keys(&[reserved]))
                .unwrap_err();
            assert!(
                matches!(err, ConnectorError::ConfigurationError(_)),
                "expected ConfigurationError for '{reserved}', got {err:?}"
            );
            assert!(format!("{err}").contains("reserved"));
        }
    }
}