// laminar_core/aggregation/cross_partition.rs
//! Cross-partition aggregate store backed by `papaya::HashMap`.
//!
//! In a partition-parallel system, each partition computes partial aggregates
//! independently. The [`CrossPartitionAggregateStore`] provides a lock-free
//! concurrent hash map where partitions publish their partial aggregates
//! and readers can merge them on demand.
//!
//! ## Design
//!
//! - Each partition writes its partial aggregate under `(group_key, partition_id)`
//! - Readers iterate all partitions for a given group key and merge
//! - The underlying `papaya::HashMap` is lock-free and scales with readers
//!
//! ## Thread Safety
//!
//! The store is `Send + Sync`. Writers use `guard()` + `insert()`;
//! readers use `guard()` + `get()`. No external locking is required.

use bytes::Bytes;

21/// A composite key combining a group key with a partition identifier.
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23struct CompositeKey {
24    /// The aggregation group key (serialized).
25    group_key: Bytes,
26    /// The partition that produced this partial aggregate.
27    partition_id: u32,
28}
29
30/// Lock-free concurrent store for cross-partition partial aggregates.
31///
32/// Each partition publishes serialized partial aggregates under its
33/// `partition_id`. Readers merge partials for a given group key to
34/// produce the final aggregate.
35///
36/// ## Performance
37///
38/// - Write (publish partial): single `papaya::HashMap::insert` — lock-free
39/// - Read (get partial): single `papaya::HashMap::get` — lock-free
40/// - Merge: iterate known partitions, collect partials, caller merges
41pub struct CrossPartitionAggregateStore {
42    /// Lock-free concurrent map: `(group_key, partition_id) -> partial_aggregate`.
43    map: papaya::HashMap<CompositeKey, Bytes>,
44    /// Total number of partitions (fixed at creation).
45    num_partitions: u32,
46}
47
48impl CrossPartitionAggregateStore {
49    /// Create a new store for the given number of partitions.
50    #[must_use]
51    pub fn new(num_partitions: u32) -> Self {
52        Self {
53            map: papaya::HashMap::new(),
54            num_partitions,
55        }
56    }
57
58    /// Publish a partial aggregate from a partition.
59    ///
60    /// Overwrites any previous partial for this `(group_key, partition_id)`.
61    pub fn publish(&self, group_key: Bytes, partition_id: u32, partial: Bytes) {
62        let key = CompositeKey {
63            group_key,
64            partition_id,
65        };
66        let guard = self.map.guard();
67        self.map.insert(key, partial, &guard);
68    }
69
70    /// Get the partial aggregate for a specific partition.
71    #[must_use]
72    pub fn get_partial(&self, group_key: &[u8], partition_id: u32) -> Option<Bytes> {
73        let key = CompositeKey {
74            group_key: Bytes::copy_from_slice(group_key),
75            partition_id,
76        };
77        let guard = self.map.guard();
78        self.map.get(&key, &guard).cloned()
79    }
80
81    /// Collect all partial aggregates for a group key across all partitions.
82    ///
83    /// Returns a vector of `(partition_id, partial_bytes)` for all
84    /// partitions that have published a partial for this key.
85    #[must_use]
86    pub fn collect_partials(&self, group_key: &[u8]) -> Vec<(u32, Bytes)> {
87        let guard = self.map.guard();
88        let mut result = Vec::new();
89        // Single allocation; clone() inside the loop is a ref-count bump (~2ns)
90        let group_bytes = Bytes::copy_from_slice(group_key);
91        for partition_id in 0..self.num_partitions {
92            let key = CompositeKey {
93                group_key: group_bytes.clone(),
94                partition_id,
95            };
96            if let Some(partial) = self.map.get(&key, &guard) {
97                result.push((partition_id, partial.clone()));
98            }
99        }
100        result
101    }
102
103    /// Remove all partials for a group key.
104    pub fn remove_group(&self, group_key: &[u8]) {
105        let guard = self.map.guard();
106        let group_bytes = Bytes::copy_from_slice(group_key);
107        for partition_id in 0..self.num_partitions {
108            let key = CompositeKey {
109                group_key: group_bytes.clone(),
110                partition_id,
111            };
112            self.map.remove(&key, &guard);
113        }
114    }
115
116    /// Total number of partitions.
117    #[must_use]
118    pub fn num_partitions(&self) -> u32 {
119        self.num_partitions
120    }
121
122    /// Number of entries in the map (across all partitions and groups).
123    #[must_use]
124    pub fn len(&self) -> usize {
125        self.map.len()
126    }
127
128    /// Whether the store has no entries.
129    #[must_use]
130    pub fn is_empty(&self) -> bool {
131        self.len() == 0
132    }
133
134    /// Snapshot all partial aggregates for checkpointing.
135    ///
136    /// Each entry is serialized as:
137    /// - Key: `group_key_len(4 bytes LE) + group_key + partition_id(4 bytes LE)`
138    /// - Value: raw partial aggregate bytes
139    ///
140    /// The `num_partitions` is stored as a sentinel entry with an empty key
141    /// and value containing the partition count as 4 bytes LE.
142    #[must_use]
143    pub fn snapshot(&self) -> Vec<(Vec<u8>, Vec<u8>)> {
144        let guard = self.map.guard();
145        let mut entries = Vec::new();
146
147        // Sentinel entry for num_partitions
148        entries.push((Vec::new(), self.num_partitions.to_le_bytes().to_vec()));
149
150        for (key, value) in self.map.iter(&guard) {
151            #[allow(clippy::cast_possible_truncation)] // group keys are always < 4 GiB
152            let group_len = key.group_key.len() as u32;
153            let mut serialized_key = Vec::with_capacity(4 + key.group_key.len() + 4);
154            serialized_key.extend_from_slice(&group_len.to_le_bytes());
155            serialized_key.extend_from_slice(&key.group_key);
156            serialized_key.extend_from_slice(&key.partition_id.to_le_bytes());
157            entries.push((serialized_key, value.to_vec()));
158        }
159
160        entries
161    }
162
163    /// Restore partial aggregates from a checkpoint snapshot.
164    ///
165    /// Clears the current state and inserts all entries from the snapshot.
166    ///
167    /// # Panics
168    ///
169    /// Panics if a non-sentinel entry has a key shorter than the encoded
170    /// length prefix (corrupted snapshot). Malformed entries with incorrect
171    /// total length are silently skipped.
172    pub fn restore(&self, snapshot: &[(Vec<u8>, Vec<u8>)]) {
173        let guard = self.map.guard();
174        self.map.clear(&guard);
175
176        for (key_bytes, value_bytes) in snapshot {
177            // Skip sentinel entry (empty key = num_partitions metadata)
178            if key_bytes.is_empty() {
179                continue;
180            }
181            if key_bytes.len() < 8 {
182                continue; // malformed entry
183            }
184
185            let group_len = u32::from_le_bytes(key_bytes[..4].try_into().unwrap()) as usize;
186            if key_bytes.len() < 4 + group_len + 4 {
187                continue; // malformed entry
188            }
189
190            let group_key = Bytes::copy_from_slice(&key_bytes[4..4 + group_len]);
191            let partition_id = u32::from_le_bytes(
192                key_bytes[4 + group_len..4 + group_len + 4]
193                    .try_into()
194                    .unwrap(),
195            );
196
197            let composite = CompositeKey {
198                group_key,
199                partition_id,
200            };
201            self.map
202                .insert(composite, Bytes::copy_from_slice(value_bytes), &guard);
203        }
204    }
205}
206
// papaya::HashMap<K, V> is Send + Sync when K and V are Send + Sync.
// CompositeKey contains Bytes (Send + Sync) and u32 (Send + Sync),
// so the auto-derived impls apply. No manual unsafe needed.

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_publish_and_get() {
        let store = CrossPartitionAggregateStore::new(4);

        store.publish(
            Bytes::from_static(b"group1"),
            0,
            Bytes::from_static(b"partial_0"),
        );

        // The published partial is readable under the same (group, partition).
        assert_eq!(
            store.get_partial(b"group1", 0),
            Some(Bytes::from_static(b"partial_0"))
        );
        // A partition that never published yields nothing.
        assert!(store.get_partial(b"group1", 1).is_none());
        // An unknown group yields nothing.
        assert!(store.get_partial(b"group2", 0).is_none());
    }

    #[test]
    fn test_overwrite_partial() {
        let store = CrossPartitionAggregateStore::new(2);

        for value in [&b"v1"[..], &b"v2"[..]] {
            store.publish(Bytes::from_static(b"key"), 0, Bytes::copy_from_slice(value));
        }

        // Last write wins for the same (group, partition).
        assert_eq!(store.get_partial(b"key", 0), Some(Bytes::from_static(b"v2")));
    }

    #[test]
    fn test_collect_partials() {
        let store = CrossPartitionAggregateStore::new(3);

        store.publish(Bytes::from_static(b"g"), 0, Bytes::from_static(b"p0"));
        store.publish(Bytes::from_static(b"g"), 2, Bytes::from_static(b"p2"));

        // Partition 1 never published, so exactly two partials come back.
        let collected = store.collect_partials(b"g");
        assert_eq!(collected.len(), 2);

        let published: Vec<u32> = collected.iter().map(|&(id, _)| id).collect();
        assert!(published.contains(&0));
        assert!(published.contains(&2));
    }

    #[test]
    fn test_remove_group() {
        let store = CrossPartitionAggregateStore::new(2);

        store.publish(Bytes::from_static(b"g1"), 0, Bytes::from_static(b"a"));
        store.publish(Bytes::from_static(b"g1"), 1, Bytes::from_static(b"b"));
        store.publish(Bytes::from_static(b"g2"), 0, Bytes::from_static(b"c"));
        assert_eq!(store.len(), 3);

        store.remove_group(b"g1");

        // Every partition of g1 is gone; g2 is untouched.
        for partition in 0..2 {
            assert!(store.get_partial(b"g1", partition).is_none());
        }
        assert_eq!(store.get_partial(b"g2", 0), Some(Bytes::from_static(b"c")));
    }

    #[test]
    fn test_empty_store() {
        let fresh = CrossPartitionAggregateStore::new(4);

        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
        assert_eq!(fresh.num_partitions(), 4);
    }

    #[test]
    fn test_snapshot_and_restore() {
        let store = CrossPartitionAggregateStore::new(3);

        let published: [(&[u8], u32, &[u8]); 3] =
            [(b"g1", 0, b"p0"), (b"g1", 1, b"p1"), (b"g2", 2, b"p2")];
        for &(group, partition, value) in &published {
            store.publish(
                Bytes::copy_from_slice(group),
                partition,
                Bytes::copy_from_slice(value),
            );
        }

        // Three data entries plus the num_partitions sentinel.
        let snapshot = store.snapshot();
        assert_eq!(snapshot.len(), 4);

        // Restore into a fresh store and verify every partial round-trips.
        let restored = CrossPartitionAggregateStore::new(3);
        restored.restore(&snapshot);

        assert_eq!(restored.len(), 3);
        for &(group, partition, value) in &published {
            assert_eq!(
                restored.get_partial(group, partition),
                Some(Bytes::copy_from_slice(value))
            );
        }
    }

    #[test]
    fn test_restore_clears_existing() {
        let store = CrossPartitionAggregateStore::new(2);

        store.publish(Bytes::from_static(b"old"), 0, Bytes::from_static(b"v"));
        assert_eq!(store.len(), 1);

        // A snapshot containing only the sentinel wipes all existing state.
        store.restore(&[(Vec::new(), 2u32.to_le_bytes().to_vec())]);

        assert!(store.is_empty());
        assert!(store.get_partial(b"old", 0).is_none());
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let store = Arc::new(CrossPartitionAggregateStore::new(4));

        // One writer thread per partition, each publishing 100 groups.
        let writers: Vec<_> = (0..4u32)
            .map(|partition| {
                let store = Arc::clone(&store);
                thread::spawn(move || {
                    for i in 0..100u32 {
                        store.publish(
                            Bytes::from(format!("group_{i}")),
                            partition,
                            Bytes::from(format!("p{partition}_v{i}")),
                        );
                    }
                })
            })
            .collect();

        for writer in writers {
            writer.join().unwrap();
        }

        // Every partition should have published for group_50.
        assert_eq!(store.collect_partials(b"group_50").len(), 4);
    }
}
363}