1use std::io::Cursor;
34
35use rustc_hash::FxHashMap;
36
37use arrow::ipc;
38use arrow::record_batch::RecordBatch;
39use bytes::Bytes;
40use serde::{Deserialize, Serialize};
41
42use super::CrossPartitionAggregateStore;
43
/// Errors produced by the two-phase aggregation pipeline.
#[derive(Debug, thiserror::Error)]
pub enum TwoPhaseError {
    /// Wraps any Arrow failure (IPC encode/decode, batch construction).
    #[error("arrow error: {0}")]
    Arrow(#[from] arrow::error::ArrowError),

    /// JSON or HLL (de)serialization failed; carries the underlying message.
    #[error("serialization error: {0}")]
    Serde(String),

    /// A partial aggregate carried a different number of per-function states
    /// than the merging aggregator was configured with.
    #[error("function count mismatch: expected {expected}, got {actual}")]
    FunctionCountMismatch {
        // Number of states the aggregator expects (one per function).
        expected: usize,
        // Number of states actually present in the offending partial.
        actual: usize,
    },

    /// The named aggregate function has no two-phase decomposition.
    #[error("unsupported two-phase function: {0}")]
    UnsupportedFunction(String),
}
70
/// Aggregate functions that can be split into a per-partition partial phase
/// and a cross-partition merge phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TwoPhaseKind {
    /// COUNT — partial counts are summed on merge.
    Count,
    /// SUM — partial sums are summed on merge.
    Sum,
    /// AVG — tracked as (sum, count) so the merged mean is weighted.
    Avg,
    /// MIN — merge keeps the global minimum.
    Min,
    /// MAX — merge keeps the global maximum.
    Max,
    /// Approximate distinct count backed by a HyperLogLog sketch.
    ApproxDistinct,
}
89
90impl TwoPhaseKind {
91 #[must_use]
103 pub fn from_name(name: &str) -> Option<Self> {
104 match name.to_ascii_uppercase().as_str() {
105 "COUNT" => Some(Self::Count),
106 "SUM" => Some(Self::Sum),
107 "AVG" => Some(Self::Avg),
108 "MIN" => Some(Self::Min),
109 "MAX" => Some(Self::Max),
110 "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(Self::ApproxDistinct),
111 _ => None,
112 }
113 }
114}
115
/// Per-partition partial result for one aggregate function.
///
/// Partials from different partitions are combined with `merge` and
/// collapsed to a scalar with `finalize`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PartialState {
    /// Running row count.
    Count(i64),
    /// Running sum.
    Sum(f64),
    /// Sum and count tracked separately so merged averages are weighted.
    Avg {
        // Sum of observed values.
        sum: f64,
        // Number of observed values.
        count: i64,
    },
    /// Smallest value seen; `None` until any value has been observed.
    Min(Option<f64>),
    /// Largest value seen; `None` until any value has been observed.
    Max(Option<f64>),
    /// Serialized HyperLogLog sketch (see `HllSketch::to_bytes`).
    ApproxDistinct(Vec<u8>),
}
142
143impl PartialState {
144 #[must_use]
146 pub fn empty(kind: TwoPhaseKind) -> Self {
147 match kind {
148 TwoPhaseKind::Count => Self::Count(0),
149 TwoPhaseKind::Sum => Self::Sum(0.0),
150 TwoPhaseKind::Avg => Self::Avg { sum: 0.0, count: 0 },
151 TwoPhaseKind::Min => Self::Min(None),
152 TwoPhaseKind::Max => Self::Max(None),
153 TwoPhaseKind::ApproxDistinct => Self::ApproxDistinct(HllSketch::new().to_bytes()),
154 }
155 }
156
157 pub fn merge(&mut self, other: &Self) {
161 match (self, other) {
162 (Self::Count(a), Self::Count(b)) => *a += b,
163 (Self::Sum(a), Self::Sum(b)) => *a += b,
164 (Self::Avg { sum: s1, count: c1 }, Self::Avg { sum: s2, count: c2 }) => {
165 *s1 += s2;
166 *c1 += c2;
167 }
168 (Self::Min(a), Self::Min(b)) => {
169 *a = match (*a, *b) {
170 (None, v) | (v, None) => v,
171 (Some(x), Some(y)) => Some(x.min(y)),
172 };
173 }
174 (Self::Max(a), Self::Max(b)) => {
175 *a = match (*a, *b) {
176 (None, v) | (v, None) => v,
177 (Some(x), Some(y)) => Some(x.max(y)),
178 };
179 }
180 (Self::ApproxDistinct(a), Self::ApproxDistinct(b)) => {
181 if let (Ok(mut hll_a), Ok(hll_b)) =
182 (HllSketch::from_bytes(a), HllSketch::from_bytes(b))
183 {
184 hll_a.merge(&hll_b);
185 *a = hll_a.to_bytes();
186 }
187 }
188 _ => {} }
190 }
191
192 #[must_use]
194 #[allow(clippy::cast_precision_loss)]
195 pub fn finalize(&self) -> f64 {
196 match self {
197 Self::Count(n) => *n as f64,
198 Self::Sum(s) => *s,
199 Self::Avg { sum, count } => {
200 if *count > 0 {
201 sum / (*count as f64)
202 } else {
203 0.0
204 }
205 }
206 Self::Min(v) | Self::Max(v) => v.unwrap_or(f64::NAN),
207 Self::ApproxDistinct(bytes) => {
208 HllSketch::from_bytes(bytes).map_or(0.0, |h| h.estimate())
209 }
210 }
211 }
212}
213
/// A partition's partial aggregation results for a single group key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PartialAggregate {
    /// Encoded GROUP BY key these states belong to.
    pub group_key: Vec<u8>,
    /// Partition that produced this partial.
    pub partition_id: u32,
    /// One state per aggregate function, in function order; the merging
    /// aggregator requires the length to match its configured kinds.
    pub states: Vec<PartialState>,
    /// Watermark, in milliseconds, reported by the producing partition.
    /// NOTE(review): semantics inferred from the name — confirm with producer.
    pub watermark_ms: i64,
    /// Emission epoch — presumably correlates partials from the same round;
    /// not interpreted anywhere in this module.
    pub epoch: u64,
}
233
/// Phase-two merger: combines per-partition [`PartialAggregate`]s that share
/// a group key into one final state vector per group.
pub struct MergeAggregator {
    // Aggregate function kinds, in the same order as `PartialAggregate::states`.
    kinds: Vec<TwoPhaseKind>,
}
244
245impl MergeAggregator {
246 #[must_use]
248 pub fn new(kinds: Vec<TwoPhaseKind>) -> Self {
249 Self { kinds }
250 }
251
252 pub fn merge_group(
259 &self,
260 partials: &[&PartialAggregate],
261 ) -> Result<Vec<PartialState>, TwoPhaseError> {
262 let mut merged: Vec<PartialState> =
263 self.kinds.iter().map(|k| PartialState::empty(*k)).collect();
264
265 for partial in partials {
266 if partial.states.len() != self.kinds.len() {
267 return Err(TwoPhaseError::FunctionCountMismatch {
268 expected: self.kinds.len(),
269 actual: partial.states.len(),
270 });
271 }
272 for (target, source) in merged.iter_mut().zip(&partial.states) {
273 target.merge(source);
274 }
275 }
276
277 Ok(merged)
278 }
279
280 pub fn merge_all(
290 &self,
291 partials: &[PartialAggregate],
292 ) -> Result<FxHashMap<Vec<u8>, Vec<PartialState>>, TwoPhaseError> {
293 let mut by_group: FxHashMap<&[u8], Vec<&PartialAggregate>> = FxHashMap::default();
294 for partial in partials {
295 by_group
296 .entry(&partial.group_key)
297 .or_default()
298 .push(partial);
299 }
300
301 let mut result =
302 FxHashMap::with_capacity_and_hasher(by_group.len(), rustc_hash::FxBuildHasher);
303 for (key, group_partials) in by_group {
304 let merged = self.merge_group(&group_partials)?;
305 result.insert(key.to_vec(), merged);
306 }
307
308 Ok(result)
309 }
310
311 #[must_use]
313 pub fn finalize(states: &[PartialState]) -> Vec<f64> {
314 states.iter().map(PartialState::finalize).collect()
315 }
316
317 #[must_use]
319 pub fn num_functions(&self) -> usize {
320 self.kinds.len()
321 }
322
323 #[must_use]
325 pub fn kinds(&self) -> &[TwoPhaseKind] {
326 &self.kinds
327 }
328}
329
330pub fn publish_partials(
341 store: &CrossPartitionAggregateStore,
342 partials: &[PartialAggregate],
343) -> Result<(), TwoPhaseError> {
344 for partial in partials {
345 let serialized =
346 serde_json::to_vec(partial).map_err(|e| TwoPhaseError::Serde(e.to_string()))?;
347 store.publish(
348 Bytes::copy_from_slice(&partial.group_key),
349 partial.partition_id,
350 Bytes::from(serialized),
351 );
352 }
353 Ok(())
354}
355
356#[must_use]
359pub fn collect_partials(
360 store: &CrossPartitionAggregateStore,
361 group_key: &[u8],
362) -> Vec<PartialAggregate> {
363 store
364 .collect_partials(group_key)
365 .into_iter()
366 .filter_map(|(_pid, bytes)| serde_json::from_slice(&bytes).ok())
367 .collect()
368}
369
370pub fn encode_batch_to_ipc(batch: &RecordBatch) -> Result<Vec<u8>, TwoPhaseError> {
378 let mut buf = Vec::new();
379 {
380 let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &batch.schema())?;
381 writer.write(batch)?;
382 writer.finish()?;
383 }
384 Ok(buf)
385}
386
387pub fn decode_batch_from_ipc(bytes: &[u8]) -> Result<RecordBatch, TwoPhaseError> {
394 let cursor = Cursor::new(bytes);
395 let mut reader = ipc::reader::FileReader::try_new(cursor, None)?;
396 reader
397 .next()
398 .ok_or_else(|| TwoPhaseError::Serde("empty IPC stream".into()))?
399 .map_err(TwoPhaseError::Arrow)
400}
401
402pub fn serialize_partials(partials: &[PartialAggregate]) -> Result<Vec<u8>, TwoPhaseError> {
408 serde_json::to_vec(partials).map_err(|e| TwoPhaseError::Serde(e.to_string()))
409}
410
411pub fn deserialize_partials(bytes: &[u8]) -> Result<Vec<PartialAggregate>, TwoPhaseError> {
417 serde_json::from_slice(bytes).map_err(|e| TwoPhaseError::Serde(e.to_string()))
418}
419
420#[must_use]
437pub fn can_use_two_phase(function_names: &[&str]) -> bool {
438 !function_names.is_empty()
439 && function_names
440 .iter()
441 .all(|n| TwoPhaseKind::from_name(n).is_some())
442}
443
/// Minimal HyperLogLog sketch used for APPROX_DISTINCT partial states.
///
/// Holds `2^precision` rank registers; callers feed pre-computed 64-bit
/// hashes via `add_hash`, and sketches of equal precision can be merged
/// losslessly by taking per-register maxima.
#[derive(Debug, Clone)]
pub struct HllSketch {
    // One 8-bit rank register per bucket (`2^precision` entries).
    registers: Vec<u8>,
    // Bucket-index bit width; valid range is 4..=18 (enforced at construction).
    precision: u8,
}

/// Default precision: 2^8 = 256 registers.
const DEFAULT_HLL_PRECISION: u8 = 8;
461
462impl HllSketch {
463 #[must_use]
465 pub fn new() -> Self {
466 Self::with_precision(DEFAULT_HLL_PRECISION)
467 }
468
469 #[must_use]
477 pub fn with_precision(precision: u8) -> Self {
478 assert!(
479 (4..=18).contains(&precision),
480 "HLL precision must be 4..=18, got {precision}"
481 );
482 let num_registers = 1usize << precision;
483 Self {
484 registers: vec![0; num_registers],
485 precision,
486 }
487 }
488
489 #[allow(clippy::cast_possible_truncation)]
491 pub fn add_hash(&mut self, hash: u64) {
492 let p = u32::from(self.precision);
493 let idx = (hash >> (64 - p)) as usize;
494 let w = hash << p;
495 let rho = if w == 0 {
496 64 - p + 1
497 } else {
498 w.leading_zeros() + 1
499 } as u8;
500 if rho > self.registers[idx] {
501 self.registers[idx] = rho;
502 }
503 }
504
505 pub fn merge(&mut self, other: &Self) {
511 assert_eq!(
512 self.precision, other.precision,
513 "HLL precision mismatch: {} vs {}",
514 self.precision, other.precision
515 );
516 for (a, &b) in self.registers.iter_mut().zip(&other.registers) {
517 *a = (*a).max(b);
518 }
519 }
520
521 #[must_use]
523 #[allow(clippy::cast_precision_loss)]
524 pub fn estimate(&self) -> f64 {
525 let m = self.registers.len() as f64;
526 let alpha = match self.precision {
527 4 => 0.673,
528 5 => 0.697,
529 6 => 0.709,
530 _ => 0.7213 / (1.0 + 1.079 / m),
531 };
532
533 let harmonic_sum: f64 = self
534 .registers
535 .iter()
536 .map(|&r| 2.0_f64.powi(-i32::from(r)))
537 .sum();
538
539 let raw_estimate = alpha * m * m / harmonic_sum;
540
541 if raw_estimate <= 2.5 * m {
543 #[allow(clippy::naive_bytecount)] let zeros = self.registers.iter().filter(|&&r| r == 0).count() as f64;
545 if zeros > 0.0 {
546 m * (m / zeros).ln()
547 } else {
548 raw_estimate
549 }
550 } else {
551 raw_estimate
552 }
553 }
554
555 #[must_use]
557 pub fn to_bytes(&self) -> Vec<u8> {
558 let mut buf = Vec::with_capacity(1 + self.registers.len());
559 buf.push(self.precision);
560 buf.extend_from_slice(&self.registers);
561 buf
562 }
563
564 pub fn from_bytes(bytes: &[u8]) -> Result<Self, TwoPhaseError> {
570 if bytes.is_empty() {
571 return Err(TwoPhaseError::Serde("empty HLL bytes".into()));
572 }
573 let precision = bytes[0];
574 if !(4..=18).contains(&precision) {
575 return Err(TwoPhaseError::Serde(format!(
576 "HLL precision {precision} out of range 4..=18"
577 )));
578 }
579 let expected_len = 1 + (1usize << precision);
580 if bytes.len() != expected_len {
581 return Err(TwoPhaseError::Serde(format!(
582 "HLL bytes length mismatch: expected {expected_len}, got {}",
583 bytes.len()
584 )));
585 }
586 Ok(Self {
587 precision,
588 registers: bytes[1..].to_vec(),
589 })
590 }
591
592 #[must_use]
594 pub fn num_registers(&self) -> usize {
595 self.registers.len()
596 }
597
598 #[must_use]
600 pub fn precision(&self) -> u8 {
601 self.precision
602 }
603}
604
605impl Default for HllSketch {
606 fn default() -> Self {
607 Self::new()
608 }
609}
610
/// Bridges [`PartialState`] to the gossip-layer `AggregateState`
/// representation when the `delta` feature is enabled.
#[cfg(feature = "delta")]
mod delta_bridge {
    use super::PartialState;
    use crate::aggregation::gossip_aggregates::AggregateState;

    impl PartialState {
        /// Converts to the gossip representation.
        ///
        /// Empty `Min`/`Max` states are encoded as `NaN`, since
        /// `AggregateState` has no explicit empty marker.
        #[must_use]
        pub fn to_aggregate_state(&self) -> AggregateState {
            match self {
                Self::Count(n) => AggregateState::Count(*n),
                Self::Sum(s) => AggregateState::Sum(*s),
                Self::Avg { sum, count } => AggregateState::Avg {
                    sum: *sum,
                    count: *count,
                },
                Self::Min(v) => AggregateState::Min(v.unwrap_or(f64::NAN)),
                Self::Max(v) => AggregateState::Max(v.unwrap_or(f64::NAN)),
                Self::ApproxDistinct(bytes) => AggregateState::Custom(bytes.clone()),
            }
        }

        /// Inverse of [`PartialState::to_aggregate_state`].
        ///
        /// `NaN` maps back to the empty (`None`) state, so a genuine `NaN`
        /// min/max value does not round-trip. NOTE(review): assumed acceptable
        /// because `NaN` serves as the empty sentinel here — confirm.
        #[must_use]
        pub fn from_aggregate_state(state: &AggregateState) -> Self {
            match state {
                AggregateState::Count(n) => Self::Count(*n),
                AggregateState::Sum(s) => Self::Sum(*s),
                AggregateState::Avg { sum, count } => Self::Avg {
                    sum: *sum,
                    count: *count,
                },
                AggregateState::Min(v) => Self::Min(if v.is_nan() { None } else { Some(*v) }),
                AggregateState::Max(v) => Self::Max(if v.is_nan() { None } else { Some(*v) }),
                AggregateState::Custom(bytes) => Self::ApproxDistinct(bytes.clone()),
            }
        }
    }
}
652
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Shorthand constructor for `PartialAggregate` fixtures.
    fn partial(
        key: &[u8],
        partition_id: u32,
        states: Vec<PartialState>,
        watermark_ms: i64,
        epoch: u64,
    ) -> PartialAggregate {
        PartialAggregate {
            group_key: key.to_vec(),
            partition_id,
            states,
            watermark_ms,
            epoch,
        }
    }

    // ---- function-name parsing ----

    #[test]
    fn test_kind_from_name() {
        assert_eq!(TwoPhaseKind::from_name("COUNT"), Some(TwoPhaseKind::Count));
        assert_eq!(TwoPhaseKind::from_name("sum"), Some(TwoPhaseKind::Sum));
        assert_eq!(TwoPhaseKind::from_name("Avg"), Some(TwoPhaseKind::Avg));
        assert_eq!(TwoPhaseKind::from_name("MIN"), Some(TwoPhaseKind::Min));
        assert_eq!(TwoPhaseKind::from_name("max"), Some(TwoPhaseKind::Max));
        assert_eq!(
            TwoPhaseKind::from_name("APPROX_COUNT_DISTINCT"),
            Some(TwoPhaseKind::ApproxDistinct)
        );
        assert_eq!(
            TwoPhaseKind::from_name("APPROX_DISTINCT"),
            Some(TwoPhaseKind::ApproxDistinct)
        );
        assert_eq!(TwoPhaseKind::from_name("MEDIAN"), None);
        assert_eq!(TwoPhaseKind::from_name(""), None);
    }

    #[test]
    fn test_can_use_two_phase() {
        assert!(can_use_two_phase(&["COUNT", "SUM"]));
        assert!(can_use_two_phase(&["AVG"]));
        assert!(can_use_two_phase(&["MIN", "MAX", "COUNT"]));
        assert!(!can_use_two_phase(&["COUNT", "MEDIAN"]));
        assert!(!can_use_two_phase(&[]));
    }

    // ---- PartialState merge semantics ----

    #[test]
    fn test_merge_count() {
        let mut state = PartialState::Count(10);
        state.merge(&PartialState::Count(5));
        assert_eq!(state, PartialState::Count(15));
    }

    #[test]
    fn test_merge_sum() {
        let mut state = PartialState::Sum(1.5);
        state.merge(&PartialState::Sum(2.5));
        assert_eq!(state, PartialState::Sum(4.0));
    }

    #[test]
    fn test_merge_avg() {
        let mut state = PartialState::Avg {
            sum: 10.0,
            count: 2,
        };
        state.merge(&PartialState::Avg {
            sum: 20.0,
            count: 3,
        });
        match state {
            PartialState::Avg { sum, count } => {
                assert!((sum - 30.0).abs() < f64::EPSILON);
                assert_eq!(count, 5);
            }
            _ => panic!("expected Avg"),
        }
    }

    #[test]
    fn test_merge_min() {
        // Both sides present: keep the smaller value.
        let mut state = PartialState::Min(Some(10.0));
        state.merge(&PartialState::Min(Some(5.0)));
        assert_eq!(state, PartialState::Min(Some(5.0)));

        // Empty target adopts the incoming value.
        let mut state = PartialState::Min(None);
        state.merge(&PartialState::Min(Some(3.0)));
        assert_eq!(state, PartialState::Min(Some(3.0)));

        // Empty source leaves the target untouched.
        let mut state = PartialState::Min(Some(7.0));
        state.merge(&PartialState::Min(None));
        assert_eq!(state, PartialState::Min(Some(7.0)));

        // Both empty stays empty.
        let mut state = PartialState::Min(None);
        state.merge(&PartialState::Min(None));
        assert_eq!(state, PartialState::Min(None));
    }

    #[test]
    fn test_merge_max() {
        let mut state = PartialState::Max(Some(5.0));
        state.merge(&PartialState::Max(Some(10.0)));
        assert_eq!(state, PartialState::Max(Some(10.0)));

        let mut state = PartialState::Max(None);
        state.merge(&PartialState::Max(Some(3.0)));
        assert_eq!(state, PartialState::Max(Some(3.0)));
    }

    #[test]
    fn test_merge_type_mismatch_noop() {
        // Cross-variant merges are defined to be silent no-ops.
        let mut state = PartialState::Count(10);
        state.merge(&PartialState::Sum(5.0));
        assert_eq!(state, PartialState::Count(10));
    }

    // ---- PartialState finalize ----

    #[test]
    fn test_finalize_count() {
        assert!((PartialState::Count(42).finalize() - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_finalize_avg() {
        let avg = PartialState::Avg {
            sum: 10.0,
            count: 4,
        };
        assert!((avg.finalize() - 2.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_finalize_avg_zero_count() {
        let avg = PartialState::Avg { sum: 0.0, count: 0 };
        assert!((avg.finalize()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_finalize_min_none() {
        assert!(PartialState::Min(None).finalize().is_nan());
    }

    #[test]
    fn test_finalize_min_some() {
        assert!((PartialState::Min(Some(3.25)).finalize() - 3.25).abs() < f64::EPSILON);
    }

    // ---- MergeAggregator ----

    #[test]
    fn test_merge_single_group_three_partitions() {
        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count, TwoPhaseKind::Sum]);

        let partials = vec![
            partial(
                b"AAPL",
                0,
                vec![PartialState::Count(500), PartialState::Sum(75000.0)],
                1000,
                1,
            ),
            partial(
                b"AAPL",
                1,
                vec![PartialState::Count(300), PartialState::Sum(45000.0)],
                1000,
                1,
            ),
            partial(
                b"AAPL",
                2,
                vec![PartialState::Count(200), PartialState::Sum(30000.0)],
                1000,
                1,
            ),
        ];

        let result = aggregator.merge_all(&partials).unwrap();
        assert_eq!(result.len(), 1);

        let merged = &result[b"AAPL".as_ref()];
        assert_eq!(merged[0], PartialState::Count(1000));
        assert_eq!(merged[1], PartialState::Sum(150_000.0));

        let finals = MergeAggregator::finalize(merged);
        assert!((finals[0] - 1000.0).abs() < f64::EPSILON);
        assert!((finals[1] - 150_000.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_merge_multi_group() {
        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count]);

        let partials = vec![
            partial(b"AAPL", 0, vec![PartialState::Count(10)], 1000, 1),
            partial(b"GOOG", 0, vec![PartialState::Count(20)], 1000, 1),
            partial(b"AAPL", 1, vec![PartialState::Count(30)], 1000, 1),
        ];

        let result = aggregator.merge_all(&partials).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[b"AAPL".as_ref()][0], PartialState::Count(40));
        assert_eq!(result[b"GOOG".as_ref()][0], PartialState::Count(20));
    }

    #[test]
    fn test_merge_avg_weighted() {
        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Avg]);

        // (60 + 90) / (3 + 2) = 30 — a weighted mean, not a mean of means.
        let partials = vec![
            partial(
                b"g1",
                0,
                vec![PartialState::Avg {
                    sum: 60.0,
                    count: 3,
                }],
                1000,
                1,
            ),
            partial(
                b"g1",
                1,
                vec![PartialState::Avg {
                    sum: 90.0,
                    count: 2,
                }],
                1000,
                1,
            ),
        ];

        let result = aggregator.merge_all(&partials).unwrap();
        let finals = MergeAggregator::finalize(&result[b"g1".as_ref()]);
        assert!((finals[0] - 30.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_merge_min_max_global() {
        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Min, TwoPhaseKind::Max]);

        let partials = vec![
            partial(
                b"g",
                0,
                vec![PartialState::Min(Some(10.0)), PartialState::Max(Some(90.0))],
                1000,
                1,
            ),
            partial(
                b"g",
                1,
                vec![PartialState::Min(Some(5.0)), PartialState::Max(Some(100.0))],
                1000,
                1,
            ),
            partial(
                b"g",
                2,
                vec![PartialState::Min(Some(15.0)), PartialState::Max(Some(80.0))],
                1000,
                1,
            ),
        ];

        let result = aggregator.merge_all(&partials).unwrap();
        let merged = &result[b"g".as_ref()];
        assert_eq!(merged[0], PartialState::Min(Some(5.0)));
        assert_eq!(merged[1], PartialState::Max(Some(100.0)));
    }

    #[test]
    fn test_merge_empty_partials() {
        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count]);
        let result = aggregator.merge_all(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_merge_function_count_mismatch() {
        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count, TwoPhaseKind::Sum]);

        // One state where the aggregator expects two.
        let bad = partial(b"g", 0, vec![PartialState::Count(1)], 0, 0);

        let refs: Vec<&PartialAggregate> = vec![&bad];
        let err = aggregator.merge_group(&refs).unwrap_err();
        match err {
            TwoPhaseError::FunctionCountMismatch {
                expected: 2,
                actual: 1,
            } => {}
            other => panic!("expected FunctionCountMismatch, got {other:?}"),
        }
    }

    // ---- Arrow IPC round-trips ----

    #[test]
    fn test_arrow_ipc_roundtrip() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("count", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["AAPL", "GOOG"])),
                Arc::new(Float64Array::from(vec![1000.0, 500.0])),
            ],
        )
        .unwrap();

        let ipc_bytes = encode_batch_to_ipc(&batch).unwrap();
        assert!(!ipc_bytes.is_empty());

        let decoded = decode_batch_from_ipc(&ipc_bytes).unwrap();
        assert_eq!(decoded.num_rows(), 2);
        assert_eq!(decoded.num_columns(), 2);
        assert_eq!(decoded.schema(), batch.schema());
    }

    #[test]
    fn test_ipc_decode_invalid() {
        assert!(decode_batch_from_ipc(b"not valid ipc").is_err());
    }

    // ---- JSON (de)serialization ----

    #[test]
    fn test_serialize_deserialize_partials() {
        let partials = vec![
            partial(
                b"AAPL",
                0,
                vec![PartialState::Count(42), PartialState::Sum(100.5)],
                1000,
                5,
            ),
            partial(
                b"GOOG",
                1,
                vec![PartialState::Count(10), PartialState::Sum(50.0)],
                1000,
                5,
            ),
        ];

        let bytes = serialize_partials(&partials).unwrap();
        let decoded = deserialize_partials(&bytes).unwrap();
        assert_eq!(decoded, partials);
    }

    #[test]
    fn test_deserialize_invalid() {
        assert!(deserialize_partials(b"not json").is_err());
    }

    // ---- cross-partition store integration ----

    #[test]
    fn test_store_publish_collect_roundtrip() {
        let store = CrossPartitionAggregateStore::new(3);

        let partials = vec![
            partial(b"AAPL", 0, vec![PartialState::Count(100)], 1000, 1),
            partial(b"AAPL", 1, vec![PartialState::Count(200)], 1000, 1),
        ];

        publish_partials(&store, &partials).unwrap();

        let collected = collect_partials(&store, b"AAPL");
        assert_eq!(collected.len(), 2);

        let total: i64 = collected
            .iter()
            .map(|p| match &p.states[0] {
                PartialState::Count(n) => *n,
                _ => panic!("expected Count"),
            })
            .sum();
        assert_eq!(total, 300);
    }

    #[test]
    fn test_three_partition_full_pipeline() {
        let store = CrossPartitionAggregateStore::new(3);
        let aggregator = MergeAggregator::new(vec![
            TwoPhaseKind::Count,
            TwoPhaseKind::Sum,
            TwoPhaseKind::Avg,
        ]);

        // Partition p contributes count (p+1)*100 and sum (p+1)*1000.
        for pid in 0..3u32 {
            let count = (i64::from(pid) + 1) * 100;
            let sum = (f64::from(pid) + 1.0) * 1000.0;
            let states = vec![
                PartialState::Count(count),
                PartialState::Sum(sum),
                PartialState::Avg { sum, count },
            ];
            publish_partials(&store, &[partial(b"AAPL", pid, states, 2000, 5)]).unwrap();
        }

        let collected = collect_partials(&store, b"AAPL");
        assert_eq!(collected.len(), 3);

        let result = aggregator.merge_all(&collected).unwrap();
        let merged = &result[b"AAPL".as_ref()];

        // 100 + 200 + 300 and 1000 + 2000 + 3000.
        assert_eq!(merged[0], PartialState::Count(600));
        assert_eq!(merged[1], PartialState::Sum(6000.0));

        // Weighted average: 6000 / 600 = 10.
        let finals = MergeAggregator::finalize(merged);
        assert!((finals[2] - 10.0).abs() < f64::EPSILON);
    }

    // ---- HyperLogLog ----

    /// Cheap avalanche mixer so sequential integers spread across registers.
    fn test_hash(x: u64) -> u64 {
        let mut h = x.wrapping_mul(0x517c_c1b7_2722_0a95);
        h ^= h >> 33;
        h = h.wrapping_mul(0xff51_afd7_ed55_8ccd);
        h ^= h >> 33;
        h
    }

    #[test]
    fn test_hll_basic() {
        let mut hll = HllSketch::new();
        assert_eq!(hll.num_registers(), 256);
        assert_eq!(hll.precision(), 8);

        // A fresh sketch should estimate (close to) zero.
        assert!(hll.estimate() < 1.0);

        for i in 0..1000u64 {
            hll.add_hash(test_hash(i));
        }

        // 1000 distinct values at precision 8: allow generous error bounds.
        let est = hll.estimate();
        assert!(est > 700.0, "estimate {est} too low");
        assert!(est < 1400.0, "estimate {est} too high");
    }

    #[test]
    fn test_hll_merge() {
        let mut hll_a = HllSketch::new();
        let mut hll_b = HllSketch::new();

        // Disjoint halves of the same 1000-element universe.
        for i in 0..500u64 {
            hll_a.add_hash(test_hash(i));
        }
        for i in 500..1000u64 {
            hll_b.add_hash(test_hash(i));
        }

        let est_a = hll_a.estimate();
        let est_b = hll_b.estimate();

        hll_a.merge(&hll_b);
        let est_merged = hll_a.estimate();

        assert!(est_merged > est_a, "merged should be >= individual");
        assert!(est_merged > est_b, "merged should be >= individual");
        assert!(est_merged > 700.0, "merged {est_merged} too low");
        assert!(est_merged < 1400.0, "merged {est_merged} too high");
    }

    #[test]
    fn test_hll_serialization_roundtrip() {
        let mut hll = HllSketch::new();
        for i in 0..100u64 {
            hll.add_hash(i * 12345);
        }

        // One precision byte followed by 256 registers.
        let bytes = hll.to_bytes();
        assert_eq!(bytes.len(), 1 + 256);

        let restored = HllSketch::from_bytes(&bytes).unwrap();
        assert_eq!(restored.precision(), hll.precision());
        assert!(
            (restored.estimate() - hll.estimate()).abs() < f64::EPSILON,
            "estimate should be identical after roundtrip"
        );
    }

    #[test]
    fn test_hll_from_bytes_invalid() {
        // Empty input.
        assert!(HllSketch::from_bytes(&[]).is_err());
        // Precision below the supported range.
        assert!(HllSketch::from_bytes(&[3]).is_err());
        // Truncated register payload.
        assert!(HllSketch::from_bytes(&[8, 1, 2]).is_err());
    }

    #[test]
    fn test_hll_merge_partial_state() {
        let mut hll_a = HllSketch::new();
        let mut hll_b = HllSketch::new();

        for i in 0..500u64 {
            hll_a.add_hash(test_hash(i));
        }
        for i in 500..1000u64 {
            hll_b.add_hash(test_hash(i));
        }

        let mut state_a = PartialState::ApproxDistinct(hll_a.to_bytes());
        let state_b = PartialState::ApproxDistinct(hll_b.to_bytes());

        state_a.merge(&state_b);

        let est = state_a.finalize();
        assert!(est > 700.0, "HLL partial merge estimate {est} too low");
        assert!(est < 1400.0, "HLL partial merge estimate {est} too high");
    }

    // ---- identities and accessors ----

    #[test]
    fn test_partial_state_empty() {
        assert_eq!(
            PartialState::empty(TwoPhaseKind::Count),
            PartialState::Count(0)
        );
        assert_eq!(
            PartialState::empty(TwoPhaseKind::Sum),
            PartialState::Sum(0.0)
        );
        assert_eq!(
            PartialState::empty(TwoPhaseKind::Avg),
            PartialState::Avg { sum: 0.0, count: 0 }
        );
        assert_eq!(
            PartialState::empty(TwoPhaseKind::Min),
            PartialState::Min(None)
        );
        assert_eq!(
            PartialState::empty(TwoPhaseKind::Max),
            PartialState::Max(None)
        );

        // The ApproxDistinct identity is an empty (all-zero) sketch.
        let empty_hll = PartialState::empty(TwoPhaseKind::ApproxDistinct);
        match &empty_hll {
            PartialState::ApproxDistinct(bytes) => {
                let sketch = HllSketch::from_bytes(bytes).unwrap();
                assert!(sketch.estimate() < 1.0);
            }
            _ => panic!("expected ApproxDistinct"),
        }
    }

    #[test]
    fn test_merge_aggregator_accessors() {
        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count, TwoPhaseKind::Avg]);
        assert_eq!(aggregator.num_functions(), 2);
        assert_eq!(
            aggregator.kinds(),
            &[TwoPhaseKind::Count, TwoPhaseKind::Avg]
        );
    }

    #[test]
    fn test_partial_aggregate_json_roundtrip() {
        let pa = partial(
            b"test",
            42,
            vec![
                PartialState::Count(100),
                PartialState::Avg {
                    sum: 500.0,
                    count: 10,
                },
                PartialState::Min(Some(1.5)),
                PartialState::Max(None),
            ],
            5000,
            10,
        );

        let json = serde_json::to_string(&pa).unwrap();
        let back: PartialAggregate = serde_json::from_str(&json).unwrap();
        assert_eq!(back, pa);
    }
}