// laminar_core/aggregation/two_phase.rs
//! Two-phase cross-partition aggregation.
//!
//! # Architecture
//!
//! ```text
//! Partition 0: COUNT(*) GROUP BY symbol → partial(AAPL: 500)
//! Partition 1: COUNT(*) GROUP BY symbol → partial(AAPL: 300)  → Ring 2 merge → final(AAPL: 800)
//! Partition 2: COUNT(*) GROUP BY symbol → partial(AAPL: 200)
//! ```
//!
//! Phase 1 (Partial): Each partition computes local aggregates and produces
//! `PartialAggregate` entries — one per group key per partition.
//!
//! Phase 2 (Merge): A coordinator collects partials from all partitions,
//! merges them via `MergeAggregator`, and produces final results.
//!
//! ## Supported Functions
//!
//! | Function | Partial State | Merge Logic |
//! |----------|--------------|-------------|
//! | COUNT | `count: i64` | sum of counts |
//! | SUM | `sum: f64` | sum of sums |
//! | AVG | `sum: f64, count: i64` | total\_sum / total\_count |
//! | MIN | `min: Option<f64>` | min of mins |
//! | MAX | `max: Option<f64>` | max of maxes |
//! | APPROX\_DISTINCT | HLL sketch bytes | HLL union |
//!
//! ## Arrow IPC
//!
//! Partial results can be shipped between nodes as Arrow IPC-encoded
//! `RecordBatch`es via `encode_batch_to_ipc` / `decode_batch_from_ipc`.
32
33use std::io::Cursor;
34
35use rustc_hash::FxHashMap;
36
37use arrow::ipc;
38use arrow::record_batch::RecordBatch;
39use bytes::Bytes;
40use serde::{Deserialize, Serialize};
41
42use super::CrossPartitionAggregateStore;
43
// ── Errors ──────────────────────────────────────────────────────────

/// Errors from two-phase aggregation.
#[derive(Debug, thiserror::Error)]
pub enum TwoPhaseError {
    /// Arrow error during IPC encoding or decoding.
    #[error("arrow error: {0}")]
    Arrow(#[from] arrow::error::ArrowError),

    /// Serialization or deserialization error.
    ///
    /// Carried as a pre-rendered message string; the underlying
    /// `serde_json` error is not retained.
    #[error("serialization error: {0}")]
    Serde(String),

    /// Mismatched function count during merge.
    ///
    /// Produced when a partial carries a different number of states than
    /// the merging aggregator was configured with.
    #[error("function count mismatch: expected {expected}, got {actual}")]
    FunctionCountMismatch {
        /// Expected number of functions.
        expected: usize,
        /// Actual number of functions found.
        actual: usize,
    },

    /// Unknown aggregate function name.
    #[error("unsupported two-phase function: {0}")]
    UnsupportedFunction(String),
}
70
// ── TwoPhaseKind ────────────────────────────────────────────────────

/// Aggregate function kind that supports two-phase execution.
///
/// Each kind has a corresponding [`PartialState`] variant carrying its
/// mergeable intermediate representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TwoPhaseKind {
    /// COUNT aggregate.
    Count,
    /// SUM aggregate.
    Sum,
    /// AVG aggregate (carried as sum + count).
    Avg,
    /// MIN aggregate.
    Min,
    /// MAX aggregate.
    Max,
    /// Approximate distinct count via `HyperLogLog`.
    ApproxDistinct,
}
89
90impl TwoPhaseKind {
91    /// Resolve a function kind from its SQL name.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use laminar_core::aggregation::two_phase::TwoPhaseKind;
97    ///
98    /// assert_eq!(TwoPhaseKind::from_name("COUNT"), Some(TwoPhaseKind::Count));
99    /// assert_eq!(TwoPhaseKind::from_name("avg"), Some(TwoPhaseKind::Avg));
100    /// assert_eq!(TwoPhaseKind::from_name("MEDIAN"), None);
101    /// ```
102    #[must_use]
103    pub fn from_name(name: &str) -> Option<Self> {
104        match name.to_ascii_uppercase().as_str() {
105            "COUNT" => Some(Self::Count),
106            "SUM" => Some(Self::Sum),
107            "AVG" => Some(Self::Avg),
108            "MIN" => Some(Self::Min),
109            "MAX" => Some(Self::Max),
110            "APPROX_COUNT_DISTINCT" | "APPROX_DISTINCT" => Some(Self::ApproxDistinct),
111            _ => None,
112        }
113    }
114}
115
// ── PartialState ────────────────────────────────────────────────────

/// Intermediate partial state for a single aggregate function.
///
/// Each variant carries the minimum state needed for the merge step
/// to produce a correct final result.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PartialState {
    /// Partial count.
    Count(i64),
    /// Partial sum.
    Sum(f64),
    /// Partial average as (running sum, running count).
    Avg {
        /// Running sum.
        sum: f64,
        /// Running count.
        count: i64,
    },
    /// Partial minimum. `None` means no value has been observed yet.
    Min(Option<f64>),
    /// Partial maximum. `None` means no value has been observed yet.
    Max(Option<f64>),
    /// `HyperLogLog` sketch bytes for approximate distinct count.
    ApproxDistinct(Vec<u8>),
}
142
143impl PartialState {
144    /// Create an empty state for the given function kind.
145    #[must_use]
146    pub fn empty(kind: TwoPhaseKind) -> Self {
147        match kind {
148            TwoPhaseKind::Count => Self::Count(0),
149            TwoPhaseKind::Sum => Self::Sum(0.0),
150            TwoPhaseKind::Avg => Self::Avg { sum: 0.0, count: 0 },
151            TwoPhaseKind::Min => Self::Min(None),
152            TwoPhaseKind::Max => Self::Max(None),
153            TwoPhaseKind::ApproxDistinct => Self::ApproxDistinct(HllSketch::new().to_bytes()),
154        }
155    }
156
157    /// Merge another partial state into this one (in place).
158    ///
159    /// Type-mismatched merges are silently ignored.
160    pub fn merge(&mut self, other: &Self) {
161        match (self, other) {
162            (Self::Count(a), Self::Count(b)) => *a += b,
163            (Self::Sum(a), Self::Sum(b)) => *a += b,
164            (Self::Avg { sum: s1, count: c1 }, Self::Avg { sum: s2, count: c2 }) => {
165                *s1 += s2;
166                *c1 += c2;
167            }
168            (Self::Min(a), Self::Min(b)) => {
169                *a = match (*a, *b) {
170                    (None, v) | (v, None) => v,
171                    (Some(x), Some(y)) => Some(x.min(y)),
172                };
173            }
174            (Self::Max(a), Self::Max(b)) => {
175                *a = match (*a, *b) {
176                    (None, v) | (v, None) => v,
177                    (Some(x), Some(y)) => Some(x.max(y)),
178                };
179            }
180            (Self::ApproxDistinct(a), Self::ApproxDistinct(b)) => {
181                if let (Ok(mut hll_a), Ok(hll_b)) =
182                    (HllSketch::from_bytes(a), HllSketch::from_bytes(b))
183                {
184                    hll_a.merge(&hll_b);
185                    *a = hll_a.to_bytes();
186                }
187            }
188            _ => {} // type mismatch — no-op
189        }
190    }
191
192    /// Finalize the partial state to an `f64` result.
193    #[must_use]
194    #[allow(clippy::cast_precision_loss)]
195    pub fn finalize(&self) -> f64 {
196        match self {
197            Self::Count(n) => *n as f64,
198            Self::Sum(s) => *s,
199            Self::Avg { sum, count } => {
200                if *count > 0 {
201                    sum / (*count as f64)
202                } else {
203                    0.0
204                }
205            }
206            Self::Min(v) | Self::Max(v) => v.unwrap_or(f64::NAN),
207            Self::ApproxDistinct(bytes) => {
208                HllSketch::from_bytes(bytes).map_or(0.0, |h| h.estimate())
209            }
210        }
211    }
212}
213
// ── PartialAggregate ────────────────────────────────────────────────

/// A partial aggregate entry for one group from one partition.
///
/// Contains the serialized group key, source partition ID,
/// one [`PartialState`] per aggregate function, and metadata.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PartialAggregate {
    /// Serialized group key.
    pub group_key: Vec<u8>,
    /// Source partition.
    pub partition_id: u32,
    /// One partial state per aggregate function, positionally aligned with
    /// the kinds configured on the merging `MergeAggregator`.
    pub states: Vec<PartialState>,
    /// Watermark at computation time (millis since epoch).
    pub watermark_ms: i64,
    /// Epoch at computation time.
    pub epoch: u64,
}
233
// ── MergeAggregator ─────────────────────────────────────────────────

/// Combines partial aggregates from multiple partitions into final results.
///
/// The merge step runs on Ring 2 (control plane). It groups partials by
/// group key and merges each group's partial states into a single combined
/// state per function.
pub struct MergeAggregator {
    // Function kinds, positionally aligned with each partial's `states`.
    kinds: Vec<TwoPhaseKind>,
}
244
245impl MergeAggregator {
246    /// Create a new merge aggregator for the given function kinds.
247    #[must_use]
248    pub fn new(kinds: Vec<TwoPhaseKind>) -> Self {
249        Self { kinds }
250    }
251
252    /// Merge partials for a single group key.
253    ///
254    /// # Errors
255    ///
256    /// Returns [`TwoPhaseError::FunctionCountMismatch`] if any partial
257    /// has a different number of states than expected.
258    pub fn merge_group(
259        &self,
260        partials: &[&PartialAggregate],
261    ) -> Result<Vec<PartialState>, TwoPhaseError> {
262        let mut merged: Vec<PartialState> =
263            self.kinds.iter().map(|k| PartialState::empty(*k)).collect();
264
265        for partial in partials {
266            if partial.states.len() != self.kinds.len() {
267                return Err(TwoPhaseError::FunctionCountMismatch {
268                    expected: self.kinds.len(),
269                    actual: partial.states.len(),
270                });
271            }
272            for (target, source) in merged.iter_mut().zip(&partial.states) {
273                target.merge(source);
274            }
275        }
276
277        Ok(merged)
278    }
279
280    /// Merge all partials into per-group final states.
281    ///
282    /// Groups the partials by `group_key`, merges each group, and returns
283    /// a map of `group_key → merged_states`.
284    ///
285    /// # Errors
286    ///
287    /// Returns [`TwoPhaseError::FunctionCountMismatch`] if any partial
288    /// has a mismatched number of states.
289    pub fn merge_all(
290        &self,
291        partials: &[PartialAggregate],
292    ) -> Result<FxHashMap<Vec<u8>, Vec<PartialState>>, TwoPhaseError> {
293        let mut by_group: FxHashMap<&[u8], Vec<&PartialAggregate>> = FxHashMap::default();
294        for partial in partials {
295            by_group
296                .entry(&partial.group_key)
297                .or_default()
298                .push(partial);
299        }
300
301        let mut result =
302            FxHashMap::with_capacity_and_hasher(by_group.len(), rustc_hash::FxBuildHasher);
303        for (key, group_partials) in by_group {
304            let merged = self.merge_group(&group_partials)?;
305            result.insert(key.to_vec(), merged);
306        }
307
308        Ok(result)
309    }
310
311    /// Finalize a vector of merged partial states to `f64` results.
312    #[must_use]
313    pub fn finalize(states: &[PartialState]) -> Vec<f64> {
314        states.iter().map(PartialState::finalize).collect()
315    }
316
317    /// Number of aggregate functions.
318    #[must_use]
319    pub fn num_functions(&self) -> usize {
320        self.kinds.len()
321    }
322
323    /// The function kinds in order.
324    #[must_use]
325    pub fn kinds(&self) -> &[TwoPhaseKind] {
326        &self.kinds
327    }
328}
329
330// ── Store Integration ───────────────────────────────────────────────
331
332/// Publish partial aggregates to a [`CrossPartitionAggregateStore`].
333///
334/// Each partial is serialized to JSON and published under its
335/// `(group_key, partition_id)`.
336///
337/// # Errors
338///
339/// Returns [`TwoPhaseError::Serde`] if serialization fails.
340pub fn publish_partials(
341    store: &CrossPartitionAggregateStore,
342    partials: &[PartialAggregate],
343) -> Result<(), TwoPhaseError> {
344    for partial in partials {
345        let serialized =
346            serde_json::to_vec(partial).map_err(|e| TwoPhaseError::Serde(e.to_string()))?;
347        store.publish(
348            Bytes::copy_from_slice(&partial.group_key),
349            partial.partition_id,
350            Bytes::from(serialized),
351        );
352    }
353    Ok(())
354}
355
356/// Collect and deserialize partials for a group key from a
357/// [`CrossPartitionAggregateStore`].
358#[must_use]
359pub fn collect_partials(
360    store: &CrossPartitionAggregateStore,
361    group_key: &[u8],
362) -> Vec<PartialAggregate> {
363    store
364        .collect_partials(group_key)
365        .into_iter()
366        .filter_map(|(_pid, bytes)| serde_json::from_slice(&bytes).ok())
367        .collect()
368}
369
370// ── Arrow IPC ───────────────────────────────────────────────────────
371
372/// Encode a `RecordBatch` to Arrow IPC file format bytes.
373///
374/// # Errors
375///
376/// Returns [`TwoPhaseError::Arrow`] on encoding failure.
377pub fn encode_batch_to_ipc(batch: &RecordBatch) -> Result<Vec<u8>, TwoPhaseError> {
378    let mut buf = Vec::new();
379    {
380        let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &batch.schema())?;
381        writer.write(batch)?;
382        writer.finish()?;
383    }
384    Ok(buf)
385}
386
387/// Decode a `RecordBatch` from Arrow IPC file format bytes.
388///
389/// # Errors
390///
391/// Returns [`TwoPhaseError::Arrow`] on decoding failure, or
392/// [`TwoPhaseError::Serde`] if the IPC stream contains no batches.
393pub fn decode_batch_from_ipc(bytes: &[u8]) -> Result<RecordBatch, TwoPhaseError> {
394    let cursor = Cursor::new(bytes);
395    let mut reader = ipc::reader::FileReader::try_new(cursor, None)?;
396    reader
397        .next()
398        .ok_or_else(|| TwoPhaseError::Serde("empty IPC stream".into()))?
399        .map_err(TwoPhaseError::Arrow)
400}
401
402/// Serialize partial aggregates to JSON bytes.
403///
404/// # Errors
405///
406/// Returns [`TwoPhaseError::Serde`] on serialization failure.
407pub fn serialize_partials(partials: &[PartialAggregate]) -> Result<Vec<u8>, TwoPhaseError> {
408    serde_json::to_vec(partials).map_err(|e| TwoPhaseError::Serde(e.to_string()))
409}
410
411/// Deserialize partial aggregates from JSON bytes.
412///
413/// # Errors
414///
415/// Returns [`TwoPhaseError::Serde`] on deserialization failure.
416pub fn deserialize_partials(bytes: &[u8]) -> Result<Vec<PartialAggregate>, TwoPhaseError> {
417    serde_json::from_slice(bytes).map_err(|e| TwoPhaseError::Serde(e.to_string()))
418}
419
420// ── Detection ───────────────────────────────────────────────────────
421
422/// Check if all functions in a query support two-phase execution.
423///
424/// Returns `false` for empty input or if any function is not two-phase
425/// compatible.
426///
427/// # Examples
428///
429/// ```
430/// use laminar_core::aggregation::two_phase::can_use_two_phase;
431///
432/// assert!(can_use_two_phase(&["COUNT", "SUM", "AVG"]));
433/// assert!(!can_use_two_phase(&["COUNT", "MEDIAN"]));
434/// assert!(!can_use_two_phase(&[]));
435/// ```
436#[must_use]
437pub fn can_use_two_phase(function_names: &[&str]) -> bool {
438    !function_names.is_empty()
439        && function_names
440            .iter()
441            .all(|n| TwoPhaseKind::from_name(n).is_some())
442}
443
// ── HyperLogLog Sketch ─────────────────────────────────────────────

/// Minimal `HyperLogLog` sketch for approximate distinct counting.
///
/// Uses `2^precision` registers. Default precision is 8 (256 registers,
/// ~256 bytes of memory, ~2% standard error).
///
/// ## Wire Format
///
/// `[precision: u8][registers: u8 * (2^precision)]`
#[derive(Debug, Clone)]
pub struct HllSketch {
    // One 8-bit register per bucket; each holds the maximum rho observed
    // for that bucket (see `add_hash`).
    registers: Vec<u8>,
    // log2 of the register count; constructors enforce 4..=18.
    precision: u8,
}

/// Default precision: 256 registers, ~2% standard error.
const DEFAULT_HLL_PRECISION: u8 = 8;
461
impl HllSketch {
    /// Create a new empty sketch with default precision (8).
    #[must_use]
    pub fn new() -> Self {
        Self::with_precision(DEFAULT_HLL_PRECISION)
    }

    /// Create a new empty sketch with the given precision.
    ///
    /// Allocates `2^precision` registers.
    ///
    /// # Panics
    ///
    /// Panics if `precision` is not in the range 4..=18.
    #[must_use]
    pub fn with_precision(precision: u8) -> Self {
        assert!(
            (4..=18).contains(&precision),
            "HLL precision must be 4..=18, got {precision}"
        );
        let num_registers = 1usize << precision;
        Self {
            registers: vec![0; num_registers],
            precision,
        }
    }

    /// Add a pre-hashed 64-bit value to the sketch.
    ///
    /// The top `precision` bits select the register; the register keeps the
    /// maximum over all observed values of `rho` — the 1-based position of
    /// the first set bit in the remaining hash bits.
    #[allow(clippy::cast_possible_truncation)]
    pub fn add_hash(&mut self, hash: u64) {
        let p = u32::from(self.precision);
        // Register index: top `p` bits of the hash.
        let idx = (hash >> (64 - p)) as usize;
        // Remaining 64 - p bits, shifted to the top of the word.
        let w = hash << p;
        // rho = leading-zero count + 1; an all-zero remainder maps to the
        // maximum attainable value, 64 - p + 1. Fits in u8 for p >= 4.
        let rho = if w == 0 {
            64 - p + 1
        } else {
            w.leading_zeros() + 1
        } as u8;
        if rho > self.registers[idx] {
            self.registers[idx] = rho;
        }
    }

    /// Merge another sketch into this one (HLL union).
    ///
    /// Union is the element-wise maximum of registers, so merging is
    /// commutative, associative, and idempotent.
    ///
    /// # Panics
    ///
    /// Panics if the two sketches have different precisions.
    pub fn merge(&mut self, other: &Self) {
        assert_eq!(
            self.precision, other.precision,
            "HLL precision mismatch: {} vs {}",
            self.precision, other.precision
        );
        for (a, &b) in self.registers.iter_mut().zip(&other.registers) {
            *a = (*a).max(b);
        }
    }

    /// Estimate the cardinality (number of distinct elements).
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn estimate(&self) -> f64 {
        let m = self.registers.len() as f64;
        // Standard HyperLogLog bias-correction constants (alpha_m); the
        // closed form applies for m >= 128.
        let alpha = match self.precision {
            4 => 0.673,
            5 => 0.697,
            6 => 0.709,
            _ => 0.7213 / (1.0 + 1.079 / m),
        };

        // Harmonic mean of 2^(-register) across all registers.
        let harmonic_sum: f64 = self
            .registers
            .iter()
            .map(|&r| 2.0_f64.powi(-i32::from(r)))
            .sum();

        let raw_estimate = alpha * m * m / harmonic_sum;

        // Small-range correction (linear counting): when the raw estimate
        // is small and some registers are still zero, counting empty
        // registers gives a better estimate than the raw HLL formula.
        if raw_estimate <= 2.5 * m {
            #[allow(clippy::naive_bytecount)] // few registers; no bytecount dep needed
            let zeros = self.registers.iter().filter(|&&r| r == 0).count() as f64;
            if zeros > 0.0 {
                m * (m / zeros).ln()
            } else {
                raw_estimate
            }
        } else {
            raw_estimate
        }
    }

    /// Serialize the sketch to bytes.
    ///
    /// Layout: `[precision: u8][registers: u8 * (2^precision)]`.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(1 + self.registers.len());
        buf.push(self.precision);
        buf.extend_from_slice(&self.registers);
        buf
    }

    /// Deserialize a sketch from bytes.
    ///
    /// # Errors
    ///
    /// Returns [`TwoPhaseError::Serde`] if the bytes are malformed: empty
    /// input, out-of-range precision, or a length that does not match
    /// `1 + 2^precision`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, TwoPhaseError> {
        if bytes.is_empty() {
            return Err(TwoPhaseError::Serde("empty HLL bytes".into()));
        }
        let precision = bytes[0];
        if !(4..=18).contains(&precision) {
            return Err(TwoPhaseError::Serde(format!(
                "HLL precision {precision} out of range 4..=18"
            )));
        }
        let expected_len = 1 + (1usize << precision);
        if bytes.len() != expected_len {
            return Err(TwoPhaseError::Serde(format!(
                "HLL bytes length mismatch: expected {expected_len}, got {}",
                bytes.len()
            )));
        }
        Ok(Self {
            precision,
            registers: bytes[1..].to_vec(),
        })
    }

    /// Number of registers (`2^precision`).
    #[must_use]
    pub fn num_registers(&self) -> usize {
        self.registers.len()
    }

    /// Precision (log2 of register count).
    #[must_use]
    pub fn precision(&self) -> u8 {
        self.precision
    }
}
604
impl Default for HllSketch {
    /// Equivalent to [`HllSketch::new`] (precision 8).
    fn default() -> Self {
        Self::new()
    }
}
610
// ── Conversion: PartialState ↔ AggregateState ──────────────────────

#[cfg(feature = "delta")]
mod delta_bridge {
    use super::PartialState;
    use crate::aggregation::gossip_aggregates::AggregateState;

    impl PartialState {
        /// Convert to the gossip [`AggregateState`] for cluster replication.
        ///
        /// Empty `Min`/`Max` (`None`) are encoded as `NaN`, since
        /// `AggregateState` has no empty variant for them.
        #[must_use]
        pub fn to_aggregate_state(&self) -> AggregateState {
            match self {
                Self::Count(n) => AggregateState::Count(*n),
                Self::Sum(s) => AggregateState::Sum(*s),
                Self::Avg { sum, count } => AggregateState::Avg {
                    sum: *sum,
                    count: *count,
                },
                Self::Min(v) => AggregateState::Min(v.unwrap_or(f64::NAN)),
                Self::Max(v) => AggregateState::Max(v.unwrap_or(f64::NAN)),
                Self::ApproxDistinct(bytes) => AggregateState::Custom(bytes.clone()),
            }
        }

        /// Convert from a gossip [`AggregateState`].
        ///
        /// `NaN` in `Min`/`Max` decodes back to `None` (the empty state).
        /// NOTE(review): NaN is a reserved sentinel here — a genuine
        /// `Min(Some(NAN))` does not round-trip; confirm upstream never
        /// produces NaN aggregate values.
        #[must_use]
        pub fn from_aggregate_state(state: &AggregateState) -> Self {
            match state {
                AggregateState::Count(n) => Self::Count(*n),
                AggregateState::Sum(s) => Self::Sum(*s),
                AggregateState::Avg { sum, count } => Self::Avg {
                    sum: *sum,
                    count: *count,
                },
                AggregateState::Min(v) => Self::Min(if v.is_nan() { None } else { Some(*v) }),
                AggregateState::Max(v) => Self::Max(if v.is_nan() { None } else { Some(*v) }),
                AggregateState::Custom(bytes) => Self::ApproxDistinct(bytes.clone()),
            }
        }
    }
}
652
653#[cfg(test)]
654mod tests {
655    use super::*;
656    use arrow::array::{Float64Array, StringArray};
657    use arrow::datatypes::{DataType, Field, Schema};
658    use std::sync::Arc;
659
660    // ── TwoPhaseKind ────────────────────────────────────────────────
661
662    #[test]
663    fn test_kind_from_name() {
664        assert_eq!(TwoPhaseKind::from_name("COUNT"), Some(TwoPhaseKind::Count));
665        assert_eq!(TwoPhaseKind::from_name("sum"), Some(TwoPhaseKind::Sum));
666        assert_eq!(TwoPhaseKind::from_name("Avg"), Some(TwoPhaseKind::Avg));
667        assert_eq!(TwoPhaseKind::from_name("MIN"), Some(TwoPhaseKind::Min));
668        assert_eq!(TwoPhaseKind::from_name("max"), Some(TwoPhaseKind::Max));
669        assert_eq!(
670            TwoPhaseKind::from_name("APPROX_COUNT_DISTINCT"),
671            Some(TwoPhaseKind::ApproxDistinct)
672        );
673        assert_eq!(
674            TwoPhaseKind::from_name("APPROX_DISTINCT"),
675            Some(TwoPhaseKind::ApproxDistinct)
676        );
677        assert_eq!(TwoPhaseKind::from_name("MEDIAN"), None);
678        assert_eq!(TwoPhaseKind::from_name(""), None);
679    }
680
681    #[test]
682    fn test_can_use_two_phase() {
683        assert!(can_use_two_phase(&["COUNT", "SUM"]));
684        assert!(can_use_two_phase(&["AVG"]));
685        assert!(can_use_two_phase(&["MIN", "MAX", "COUNT"]));
686        assert!(!can_use_two_phase(&["COUNT", "MEDIAN"]));
687        assert!(!can_use_two_phase(&[]));
688    }
689
690    // ── PartialState merge ──────────────────────────────────────────
691
692    #[test]
693    fn test_merge_count() {
694        let mut a = PartialState::Count(10);
695        a.merge(&PartialState::Count(5));
696        assert_eq!(a, PartialState::Count(15));
697    }
698
699    #[test]
700    fn test_merge_sum() {
701        let mut a = PartialState::Sum(1.5);
702        a.merge(&PartialState::Sum(2.5));
703        assert_eq!(a, PartialState::Sum(4.0));
704    }
705
706    #[test]
707    fn test_merge_avg() {
708        let mut a = PartialState::Avg {
709            sum: 10.0,
710            count: 2,
711        };
712        a.merge(&PartialState::Avg {
713            sum: 20.0,
714            count: 3,
715        });
716        match a {
717            PartialState::Avg { sum, count } => {
718                assert!((sum - 30.0).abs() < f64::EPSILON);
719                assert_eq!(count, 5);
720            }
721            _ => panic!("expected Avg"),
722        }
723    }
724
725    #[test]
726    fn test_merge_min() {
727        let mut a = PartialState::Min(Some(10.0));
728        a.merge(&PartialState::Min(Some(5.0)));
729        assert_eq!(a, PartialState::Min(Some(5.0)));
730
731        // None + Some = Some
732        let mut b = PartialState::Min(None);
733        b.merge(&PartialState::Min(Some(3.0)));
734        assert_eq!(b, PartialState::Min(Some(3.0)));
735
736        // Some + None = Some
737        let mut c = PartialState::Min(Some(7.0));
738        c.merge(&PartialState::Min(None));
739        assert_eq!(c, PartialState::Min(Some(7.0)));
740
741        // None + None = None
742        let mut d = PartialState::Min(None);
743        d.merge(&PartialState::Min(None));
744        assert_eq!(d, PartialState::Min(None));
745    }
746
747    #[test]
748    fn test_merge_max() {
749        let mut a = PartialState::Max(Some(5.0));
750        a.merge(&PartialState::Max(Some(10.0)));
751        assert_eq!(a, PartialState::Max(Some(10.0)));
752
753        let mut b = PartialState::Max(None);
754        b.merge(&PartialState::Max(Some(3.0)));
755        assert_eq!(b, PartialState::Max(Some(3.0)));
756    }
757
758    #[test]
759    fn test_merge_type_mismatch_noop() {
760        let mut a = PartialState::Count(10);
761        a.merge(&PartialState::Sum(5.0));
762        assert_eq!(a, PartialState::Count(10));
763    }
764
765    // ── PartialState finalize ───────────────────────────────────────
766
767    #[test]
768    fn test_finalize_count() {
769        assert!((PartialState::Count(42).finalize() - 42.0).abs() < f64::EPSILON);
770    }
771
772    #[test]
773    fn test_finalize_avg() {
774        let avg = PartialState::Avg {
775            sum: 10.0,
776            count: 4,
777        };
778        assert!((avg.finalize() - 2.5).abs() < f64::EPSILON);
779    }
780
781    #[test]
782    fn test_finalize_avg_zero_count() {
783        let avg = PartialState::Avg { sum: 0.0, count: 0 };
784        assert!((avg.finalize()).abs() < f64::EPSILON);
785    }
786
787    #[test]
788    fn test_finalize_min_none() {
789        assert!(PartialState::Min(None).finalize().is_nan());
790    }
791
792    #[test]
793    fn test_finalize_min_some() {
794        assert!((PartialState::Min(Some(3.25)).finalize() - 3.25).abs() < f64::EPSILON);
795    }
796
797    // ── MergeAggregator ─────────────────────────────────────────────
798
799    #[test]
800    fn test_merge_single_group_three_partitions() {
801        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count, TwoPhaseKind::Sum]);
802
803        let partials = vec![
804            PartialAggregate {
805                group_key: b"AAPL".to_vec(),
806                partition_id: 0,
807                states: vec![PartialState::Count(500), PartialState::Sum(75000.0)],
808                watermark_ms: 1000,
809                epoch: 1,
810            },
811            PartialAggregate {
812                group_key: b"AAPL".to_vec(),
813                partition_id: 1,
814                states: vec![PartialState::Count(300), PartialState::Sum(45000.0)],
815                watermark_ms: 1000,
816                epoch: 1,
817            },
818            PartialAggregate {
819                group_key: b"AAPL".to_vec(),
820                partition_id: 2,
821                states: vec![PartialState::Count(200), PartialState::Sum(30000.0)],
822                watermark_ms: 1000,
823                epoch: 1,
824            },
825        ];
826
827        let result = aggregator.merge_all(&partials).unwrap();
828        assert_eq!(result.len(), 1);
829
830        let merged = &result[b"AAPL".as_ref()];
831        assert_eq!(merged[0], PartialState::Count(1000));
832        assert_eq!(merged[1], PartialState::Sum(150_000.0));
833
834        let finals = MergeAggregator::finalize(merged);
835        assert!((finals[0] - 1000.0).abs() < f64::EPSILON);
836        assert!((finals[1] - 150_000.0).abs() < f64::EPSILON);
837    }
838
839    #[test]
840    fn test_merge_multi_group() {
841        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count]);
842
843        let partials = vec![
844            PartialAggregate {
845                group_key: b"AAPL".to_vec(),
846                partition_id: 0,
847                states: vec![PartialState::Count(10)],
848                watermark_ms: 1000,
849                epoch: 1,
850            },
851            PartialAggregate {
852                group_key: b"GOOG".to_vec(),
853                partition_id: 0,
854                states: vec![PartialState::Count(20)],
855                watermark_ms: 1000,
856                epoch: 1,
857            },
858            PartialAggregate {
859                group_key: b"AAPL".to_vec(),
860                partition_id: 1,
861                states: vec![PartialState::Count(30)],
862                watermark_ms: 1000,
863                epoch: 1,
864            },
865        ];
866
867        let result = aggregator.merge_all(&partials).unwrap();
868        assert_eq!(result.len(), 2);
869        assert_eq!(result[b"AAPL".as_ref()][0], PartialState::Count(40));
870        assert_eq!(result[b"GOOG".as_ref()][0], PartialState::Count(20));
871    }
872
873    #[test]
874    fn test_merge_avg_weighted() {
875        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Avg]);
876
877        // Partition 0: avg of [10, 20, 30] = sum=60, count=3
878        // Partition 1: avg of [40, 50] = sum=90, count=2
879        // Correct weighted avg = (60+90) / (3+2) = 150/5 = 30
880        let partials = vec![
881            PartialAggregate {
882                group_key: b"g1".to_vec(),
883                partition_id: 0,
884                states: vec![PartialState::Avg {
885                    sum: 60.0,
886                    count: 3,
887                }],
888                watermark_ms: 1000,
889                epoch: 1,
890            },
891            PartialAggregate {
892                group_key: b"g1".to_vec(),
893                partition_id: 1,
894                states: vec![PartialState::Avg {
895                    sum: 90.0,
896                    count: 2,
897                }],
898                watermark_ms: 1000,
899                epoch: 1,
900            },
901        ];
902
903        let result = aggregator.merge_all(&partials).unwrap();
904        let finals = MergeAggregator::finalize(&result[b"g1".as_ref()]);
905        assert!((finals[0] - 30.0).abs() < f64::EPSILON);
906    }
907
908    #[test]
909    fn test_merge_min_max_global() {
910        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Min, TwoPhaseKind::Max]);
911
912        let partials = vec![
913            PartialAggregate {
914                group_key: b"g".to_vec(),
915                partition_id: 0,
916                states: vec![PartialState::Min(Some(10.0)), PartialState::Max(Some(90.0))],
917                watermark_ms: 1000,
918                epoch: 1,
919            },
920            PartialAggregate {
921                group_key: b"g".to_vec(),
922                partition_id: 1,
923                states: vec![PartialState::Min(Some(5.0)), PartialState::Max(Some(100.0))],
924                watermark_ms: 1000,
925                epoch: 1,
926            },
927            PartialAggregate {
928                group_key: b"g".to_vec(),
929                partition_id: 2,
930                states: vec![PartialState::Min(Some(15.0)), PartialState::Max(Some(80.0))],
931                watermark_ms: 1000,
932                epoch: 1,
933            },
934        ];
935
936        let result = aggregator.merge_all(&partials).unwrap();
937        let merged = &result[b"g".as_ref()];
938        assert_eq!(merged[0], PartialState::Min(Some(5.0)));
939        assert_eq!(merged[1], PartialState::Max(Some(100.0)));
940    }
941
942    #[test]
943    fn test_merge_empty_partials() {
944        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count]);
945        let result = aggregator.merge_all(&[]).unwrap();
946        assert!(result.is_empty());
947    }
948
949    #[test]
950    fn test_merge_function_count_mismatch() {
951        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count, TwoPhaseKind::Sum]);
952
953        let bad = PartialAggregate {
954            group_key: b"g".to_vec(),
955            partition_id: 0,
956            states: vec![PartialState::Count(1)], // only 1, expected 2
957            watermark_ms: 0,
958            epoch: 0,
959        };
960
961        let refs: Vec<&PartialAggregate> = vec![&bad];
962        let err = aggregator.merge_group(&refs).unwrap_err();
963        match err {
964            TwoPhaseError::FunctionCountMismatch {
965                expected: 2,
966                actual: 1,
967            } => {}
968            other => panic!("expected FunctionCountMismatch, got {other:?}"),
969        }
970    }
971
972    // ── Arrow IPC ───────────────────────────────────────────────────
973
974    #[test]
975    fn test_arrow_ipc_roundtrip() {
976        let schema = Arc::new(Schema::new(vec![
977            Field::new("symbol", DataType::Utf8, false),
978            Field::new("count", DataType::Float64, false),
979        ]));
980        let batch = RecordBatch::try_new(
981            schema,
982            vec![
983                Arc::new(StringArray::from(vec!["AAPL", "GOOG"])),
984                Arc::new(Float64Array::from(vec![1000.0, 500.0])),
985            ],
986        )
987        .unwrap();
988
989        let ipc_bytes = encode_batch_to_ipc(&batch).unwrap();
990        assert!(!ipc_bytes.is_empty());
991
992        let decoded = decode_batch_from_ipc(&ipc_bytes).unwrap();
993        assert_eq!(decoded.num_rows(), 2);
994        assert_eq!(decoded.num_columns(), 2);
995        assert_eq!(decoded.schema(), batch.schema());
996    }
997
998    #[test]
999    fn test_ipc_decode_invalid() {
1000        let result = decode_batch_from_ipc(b"not valid ipc");
1001        assert!(result.is_err());
1002    }
1003
1004    // ── Serialization ───────────────────────────────────────────────
1005
1006    #[test]
1007    fn test_serialize_deserialize_partials() {
1008        let partials = vec![
1009            PartialAggregate {
1010                group_key: b"AAPL".to_vec(),
1011                partition_id: 0,
1012                states: vec![PartialState::Count(42), PartialState::Sum(100.5)],
1013                watermark_ms: 1000,
1014                epoch: 5,
1015            },
1016            PartialAggregate {
1017                group_key: b"GOOG".to_vec(),
1018                partition_id: 1,
1019                states: vec![PartialState::Count(10), PartialState::Sum(50.0)],
1020                watermark_ms: 1000,
1021                epoch: 5,
1022            },
1023        ];
1024
1025        let bytes = serialize_partials(&partials).unwrap();
1026        let decoded = deserialize_partials(&bytes).unwrap();
1027        assert_eq!(decoded, partials);
1028    }
1029
1030    #[test]
1031    fn test_deserialize_invalid() {
1032        let result = deserialize_partials(b"not json");
1033        assert!(result.is_err());
1034    }
1035
1036    // ── Store Integration ───────────────────────────────────────────
1037
1038    #[test]
1039    fn test_store_publish_collect_roundtrip() {
1040        let store = CrossPartitionAggregateStore::new(3);
1041
1042        let partials = vec![
1043            PartialAggregate {
1044                group_key: b"AAPL".to_vec(),
1045                partition_id: 0,
1046                states: vec![PartialState::Count(100)],
1047                watermark_ms: 1000,
1048                epoch: 1,
1049            },
1050            PartialAggregate {
1051                group_key: b"AAPL".to_vec(),
1052                partition_id: 1,
1053                states: vec![PartialState::Count(200)],
1054                watermark_ms: 1000,
1055                epoch: 1,
1056            },
1057        ];
1058
1059        publish_partials(&store, &partials).unwrap();
1060
1061        let collected = collect_partials(&store, b"AAPL");
1062        assert_eq!(collected.len(), 2);
1063
1064        // Verify content
1065        let total: i64 = collected
1066            .iter()
1067            .map(|p| match &p.states[0] {
1068                PartialState::Count(n) => *n,
1069                _ => panic!("expected Count"),
1070            })
1071            .sum();
1072        assert_eq!(total, 300);
1073    }
1074
1075    // ── Full Pipeline ───────────────────────────────────────────────
1076
1077    #[test]
1078    fn test_three_partition_full_pipeline() {
1079        let store = CrossPartitionAggregateStore::new(3);
1080        let aggregator = MergeAggregator::new(vec![
1081            TwoPhaseKind::Count,
1082            TwoPhaseKind::Sum,
1083            TwoPhaseKind::Avg,
1084        ]);
1085
1086        // Phase 1: each partition publishes partials
1087        for pid in 0..3u32 {
1088            let count = (i64::from(pid) + 1) * 100; // 100, 200, 300
1089            let sum = (f64::from(pid) + 1.0) * 1000.0; // 1000, 2000, 3000
1090            let partial = PartialAggregate {
1091                group_key: b"AAPL".to_vec(),
1092                partition_id: pid,
1093                states: vec![
1094                    PartialState::Count(count),
1095                    PartialState::Sum(sum),
1096                    PartialState::Avg { sum, count },
1097                ],
1098                watermark_ms: 2000,
1099                epoch: 5,
1100            };
1101            publish_partials(&store, &[partial]).unwrap();
1102        }
1103
1104        // Phase 2: collect and merge
1105        let collected = collect_partials(&store, b"AAPL");
1106        assert_eq!(collected.len(), 3);
1107
1108        let result = aggregator.merge_all(&collected).unwrap();
1109        let merged = &result[b"AAPL".as_ref()];
1110
1111        // COUNT: 100 + 200 + 300 = 600
1112        assert_eq!(merged[0], PartialState::Count(600));
1113
1114        // SUM: 1000 + 2000 + 3000 = 6000
1115        assert_eq!(merged[1], PartialState::Sum(6000.0));
1116
1117        // AVG: (1000+2000+3000) / (100+200+300) = 6000/600 = 10.0
1118        let finals = MergeAggregator::finalize(merged);
1119        assert!((finals[2] - 10.0).abs() < f64::EPSILON);
1120    }
1121
1122    // ── HLL Sketch ──────────────────────────────────────────────────
1123
1124    /// Bit mixer that distributes sequential integers across all 64 bits.
1125    fn test_hash(x: u64) -> u64 {
1126        let mut h = x.wrapping_mul(0x517c_c1b7_2722_0a95);
1127        h ^= h >> 33;
1128        h = h.wrapping_mul(0xff51_afd7_ed55_8ccd);
1129        h ^= h >> 33;
1130        h
1131    }
1132
1133    #[test]
1134    fn test_hll_basic() {
1135        let mut hll = HllSketch::new();
1136        assert_eq!(hll.num_registers(), 256);
1137        assert_eq!(hll.precision(), 8);
1138
1139        // Empty sketch should estimate ~0
1140        assert!(hll.estimate() < 1.0);
1141
1142        // Add some values
1143        for i in 0..1000u64 {
1144            hll.add_hash(test_hash(i));
1145        }
1146
1147        let est = hll.estimate();
1148        // HLL with precision=8 has ~6.5% standard error
1149        // 1000 ± 300 is a generous bound (~3-4 sigma)
1150        assert!(est > 700.0, "estimate {est} too low");
1151        assert!(est < 1400.0, "estimate {est} too high");
1152    }
1153
1154    #[test]
1155    fn test_hll_merge() {
1156        let mut hll_a = HllSketch::new();
1157        let mut hll_b = HllSketch::new();
1158
1159        // Add distinct sets
1160        for i in 0..500u64 {
1161            hll_a.add_hash(test_hash(i));
1162        }
1163        for i in 500..1000u64 {
1164            hll_b.add_hash(test_hash(i));
1165        }
1166
1167        let est_a = hll_a.estimate();
1168        let est_b = hll_b.estimate();
1169
1170        hll_a.merge(&hll_b);
1171        let est_merged = hll_a.estimate();
1172
1173        // Merged should be roughly sum of distinct elements
1174        assert!(est_merged > est_a, "merged should be >= individual");
1175        assert!(est_merged > est_b, "merged should be >= individual");
1176        // Should be approximately 1000
1177        assert!(est_merged > 700.0, "merged {est_merged} too low");
1178        assert!(est_merged < 1400.0, "merged {est_merged} too high");
1179    }
1180
1181    #[test]
1182    fn test_hll_serialization_roundtrip() {
1183        let mut hll = HllSketch::new();
1184        for i in 0..100u64 {
1185            hll.add_hash(i * 12345);
1186        }
1187
1188        let bytes = hll.to_bytes();
1189        assert_eq!(bytes.len(), 1 + 256); // precision byte + 256 registers
1190
1191        let restored = HllSketch::from_bytes(&bytes).unwrap();
1192        assert_eq!(restored.precision(), hll.precision());
1193        assert!(
1194            (restored.estimate() - hll.estimate()).abs() < f64::EPSILON,
1195            "estimate should be identical after roundtrip"
1196        );
1197    }
1198
1199    #[test]
1200    fn test_hll_from_bytes_invalid() {
1201        assert!(HllSketch::from_bytes(&[]).is_err());
1202        assert!(HllSketch::from_bytes(&[3]).is_err()); // precision 3 < minimum 4
1203        assert!(HllSketch::from_bytes(&[8, 1, 2]).is_err()); // wrong length
1204    }
1205
1206    #[test]
1207    fn test_hll_merge_partial_state() {
1208        let mut hll_a = HllSketch::new();
1209        let mut hll_b = HllSketch::new();
1210
1211        for i in 0..500u64 {
1212            hll_a.add_hash(test_hash(i));
1213        }
1214        for i in 500..1000u64 {
1215            hll_b.add_hash(test_hash(i));
1216        }
1217
1218        let mut state_a = PartialState::ApproxDistinct(hll_a.to_bytes());
1219        let state_b = PartialState::ApproxDistinct(hll_b.to_bytes());
1220
1221        state_a.merge(&state_b);
1222
1223        let est = state_a.finalize();
1224        assert!(est > 700.0, "HLL partial merge estimate {est} too low");
1225        assert!(est < 1400.0, "HLL partial merge estimate {est} too high");
1226    }
1227
1228    // ── PartialState empty + kind ───────────────────────────────────
1229
1230    #[test]
1231    fn test_partial_state_empty() {
1232        assert_eq!(
1233            PartialState::empty(TwoPhaseKind::Count),
1234            PartialState::Count(0)
1235        );
1236        assert_eq!(
1237            PartialState::empty(TwoPhaseKind::Sum),
1238            PartialState::Sum(0.0)
1239        );
1240        assert_eq!(
1241            PartialState::empty(TwoPhaseKind::Avg),
1242            PartialState::Avg { sum: 0.0, count: 0 }
1243        );
1244        assert_eq!(
1245            PartialState::empty(TwoPhaseKind::Min),
1246            PartialState::Min(None)
1247        );
1248        assert_eq!(
1249            PartialState::empty(TwoPhaseKind::Max),
1250            PartialState::Max(None)
1251        );
1252
1253        // ApproxDistinct empty should be a valid HLL
1254        let empty_hll = PartialState::empty(TwoPhaseKind::ApproxDistinct);
1255        match &empty_hll {
1256            PartialState::ApproxDistinct(bytes) => {
1257                let sketch = HllSketch::from_bytes(bytes).unwrap();
1258                assert!(sketch.estimate() < 1.0);
1259            }
1260            _ => panic!("expected ApproxDistinct"),
1261        }
1262    }
1263
1264    // ── MergeAggregator accessor ────────────────────────────────────
1265
1266    #[test]
1267    fn test_merge_aggregator_accessors() {
1268        let aggregator = MergeAggregator::new(vec![TwoPhaseKind::Count, TwoPhaseKind::Avg]);
1269        assert_eq!(aggregator.num_functions(), 2);
1270        assert_eq!(
1271            aggregator.kinds(),
1272            &[TwoPhaseKind::Count, TwoPhaseKind::Avg]
1273        );
1274    }
1275
1276    // ── PartialAggregate serde ──────────────────────────────────────
1277
1278    #[test]
1279    fn test_partial_aggregate_json_roundtrip() {
1280        let pa = PartialAggregate {
1281            group_key: b"test".to_vec(),
1282            partition_id: 42,
1283            states: vec![
1284                PartialState::Count(100),
1285                PartialState::Avg {
1286                    sum: 500.0,
1287                    count: 10,
1288                },
1289                PartialState::Min(Some(1.5)),
1290                PartialState::Max(None),
1291            ],
1292            watermark_ms: 5000,
1293            epoch: 10,
1294        };
1295
1296        let json = serde_json::to_string(&pa).unwrap();
1297        let back: PartialAggregate = serde_json::from_str(&json).unwrap();
1298        assert_eq!(back, pa);
1299    }
1300}