laminar_core/aggregation/
cross_partition.rs1use bytes::Bytes;
20
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23struct CompositeKey {
24 group_key: Bytes,
26 partition_id: u32,
28}
29
30pub struct CrossPartitionAggregateStore {
42 map: papaya::HashMap<CompositeKey, Bytes>,
44 num_partitions: u32,
46}
47
48impl CrossPartitionAggregateStore {
49 #[must_use]
51 pub fn new(num_partitions: u32) -> Self {
52 Self {
53 map: papaya::HashMap::new(),
54 num_partitions,
55 }
56 }
57
58 pub fn publish(&self, group_key: Bytes, partition_id: u32, partial: Bytes) {
62 let key = CompositeKey {
63 group_key,
64 partition_id,
65 };
66 let guard = self.map.guard();
67 self.map.insert(key, partial, &guard);
68 }
69
70 #[must_use]
72 pub fn get_partial(&self, group_key: &[u8], partition_id: u32) -> Option<Bytes> {
73 let key = CompositeKey {
74 group_key: Bytes::copy_from_slice(group_key),
75 partition_id,
76 };
77 let guard = self.map.guard();
78 self.map.get(&key, &guard).cloned()
79 }
80
81 #[must_use]
86 pub fn collect_partials(&self, group_key: &[u8]) -> Vec<(u32, Bytes)> {
87 let guard = self.map.guard();
88 let mut result = Vec::new();
89 let group_bytes = Bytes::copy_from_slice(group_key);
91 for partition_id in 0..self.num_partitions {
92 let key = CompositeKey {
93 group_key: group_bytes.clone(),
94 partition_id,
95 };
96 if let Some(partial) = self.map.get(&key, &guard) {
97 result.push((partition_id, partial.clone()));
98 }
99 }
100 result
101 }
102
103 pub fn remove_group(&self, group_key: &[u8]) {
105 let guard = self.map.guard();
106 let group_bytes = Bytes::copy_from_slice(group_key);
107 for partition_id in 0..self.num_partitions {
108 let key = CompositeKey {
109 group_key: group_bytes.clone(),
110 partition_id,
111 };
112 self.map.remove(&key, &guard);
113 }
114 }
115
116 #[must_use]
118 pub fn num_partitions(&self) -> u32 {
119 self.num_partitions
120 }
121
122 #[must_use]
124 pub fn len(&self) -> usize {
125 self.map.len()
126 }
127
128 #[must_use]
130 pub fn is_empty(&self) -> bool {
131 self.len() == 0
132 }
133
134 #[must_use]
143 pub fn snapshot(&self) -> Vec<(Vec<u8>, Vec<u8>)> {
144 let guard = self.map.guard();
145 let mut entries = Vec::new();
146
147 entries.push((Vec::new(), self.num_partitions.to_le_bytes().to_vec()));
149
150 for (key, value) in self.map.iter(&guard) {
151 #[allow(clippy::cast_possible_truncation)] let group_len = key.group_key.len() as u32;
153 let mut serialized_key = Vec::with_capacity(4 + key.group_key.len() + 4);
154 serialized_key.extend_from_slice(&group_len.to_le_bytes());
155 serialized_key.extend_from_slice(&key.group_key);
156 serialized_key.extend_from_slice(&key.partition_id.to_le_bytes());
157 entries.push((serialized_key, value.to_vec()));
158 }
159
160 entries
161 }
162
163 pub fn restore(&self, snapshot: &[(Vec<u8>, Vec<u8>)]) {
173 let guard = self.map.guard();
174 self.map.clear(&guard);
175
176 for (key_bytes, value_bytes) in snapshot {
177 if key_bytes.is_empty() {
179 continue;
180 }
181 if key_bytes.len() < 8 {
182 continue; }
184
185 let group_len = u32::from_le_bytes(key_bytes[..4].try_into().unwrap()) as usize;
186 if key_bytes.len() < 4 + group_len + 4 {
187 continue; }
189
190 let group_key = Bytes::copy_from_slice(&key_bytes[4..4 + group_len]);
191 let partition_id = u32::from_le_bytes(
192 key_bytes[4 + group_len..4 + group_len + 4]
193 .try_into()
194 .unwrap(),
195 );
196
197 let composite = CompositeKey {
198 group_key,
199 partition_id,
200 };
201 self.map
202 .insert(composite, Bytes::copy_from_slice(value_bytes), &guard);
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// A published partial is visible only under its own group and partition.
    #[test]
    fn test_publish_and_get() {
        let store = CrossPartitionAggregateStore::new(4);

        store.publish(
            Bytes::from_static(b"group1"),
            0,
            Bytes::from_static(b"partial_0"),
        );

        let result = store.get_partial(b"group1", 0);
        assert_eq!(result, Some(Bytes::from_static(b"partial_0")));

        // Same group, different partition.
        assert!(store.get_partial(b"group1", 1).is_none());

        // Different group, same partition.
        assert!(store.get_partial(b"group2", 0).is_none());
    }

    /// Publishing twice under the same key keeps only the latest value.
    #[test]
    fn test_overwrite_partial() {
        let store = CrossPartitionAggregateStore::new(2);

        store.publish(Bytes::from_static(b"key"), 0, Bytes::from_static(b"v1"));
        store.publish(Bytes::from_static(b"key"), 0, Bytes::from_static(b"v2"));

        assert_eq!(
            store.get_partial(b"key", 0),
            Some(Bytes::from_static(b"v2"))
        );
        assert_eq!(store.len(), 1);
    }

    /// Only partitions that published are returned, in ascending id order.
    #[test]
    fn test_collect_partials() {
        let store = CrossPartitionAggregateStore::new(3);

        store.publish(Bytes::from_static(b"g"), 0, Bytes::from_static(b"p0"));
        store.publish(Bytes::from_static(b"g"), 2, Bytes::from_static(b"p2"));

        // Partition 1 never published for this group.
        let partials = store.collect_partials(b"g");
        assert_eq!(
            partials,
            vec![
                (0, Bytes::from_static(b"p0")),
                (2, Bytes::from_static(b"p2")),
            ]
        );
    }

    /// Removing a group drops all of its partitions and nothing else.
    #[test]
    fn test_remove_group() {
        let store = CrossPartitionAggregateStore::new(2);

        store.publish(Bytes::from_static(b"g1"), 0, Bytes::from_static(b"a"));
        store.publish(Bytes::from_static(b"g1"), 1, Bytes::from_static(b"b"));
        store.publish(Bytes::from_static(b"g2"), 0, Bytes::from_static(b"c"));

        assert_eq!(store.len(), 3);

        store.remove_group(b"g1");

        assert!(store.get_partial(b"g1", 0).is_none());
        assert!(store.get_partial(b"g1", 1).is_none());
        assert_eq!(store.get_partial(b"g2", 0), Some(Bytes::from_static(b"c")));
        assert_eq!(store.len(), 1);
    }

    /// A fresh store is empty and reports its configured partition count.
    #[test]
    fn test_empty_store() {
        let store = CrossPartitionAggregateStore::new(4);
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert_eq!(store.num_partitions(), 4);
    }

    /// A snapshot restored into a second store reproduces every entry.
    #[test]
    fn test_snapshot_and_restore() {
        let store = CrossPartitionAggregateStore::new(3);

        store.publish(Bytes::from_static(b"g1"), 0, Bytes::from_static(b"p0"));
        store.publish(Bytes::from_static(b"g1"), 1, Bytes::from_static(b"p1"));
        store.publish(Bytes::from_static(b"g2"), 2, Bytes::from_static(b"p2"));

        let snapshot = store.snapshot();
        // Three data entries plus the metadata entry.
        assert_eq!(snapshot.len(), 4);
        assert!(snapshot.contains(&(Vec::new(), 3u32.to_le_bytes().to_vec())));

        let store2 = CrossPartitionAggregateStore::new(3);
        store2.restore(&snapshot);

        assert_eq!(store2.len(), 3);
        assert_eq!(
            store2.get_partial(b"g1", 0),
            Some(Bytes::from_static(b"p0"))
        );
        assert_eq!(
            store2.get_partial(b"g1", 1),
            Some(Bytes::from_static(b"p1"))
        );
        assert_eq!(
            store2.get_partial(b"g2", 2),
            Some(Bytes::from_static(b"p2"))
        );
    }

    /// Restoring discards whatever the store held before.
    #[test]
    fn test_restore_clears_existing() {
        let store = CrossPartitionAggregateStore::new(2);

        store.publish(Bytes::from_static(b"old"), 0, Bytes::from_static(b"v"));
        assert_eq!(store.len(), 1);

        // Snapshot containing only the metadata entry.
        let empty_snapshot = vec![(Vec::new(), 2u32.to_le_bytes().to_vec())];
        store.restore(&empty_snapshot);

        assert!(store.is_empty());
        assert!(store.get_partial(b"old", 0).is_none());
    }

    /// Undecodable keys are skipped without affecting well-formed entries.
    #[test]
    fn test_restore_skips_malformed_keys() {
        let store = CrossPartitionAggregateStore::new(2);

        // Well-formed key: [len = 1]["g"][partition = 1].
        let mut valid_key = 1u32.to_le_bytes().to_vec();
        valid_key.push(b'g');
        valid_key.extend_from_slice(&1u32.to_le_bytes());

        // Declares a 100-byte group key but carries only 4 bytes after the prefix.
        let mut overlong_key = 100u32.to_le_bytes().to_vec();
        overlong_key.extend_from_slice(&[0u8; 4]);

        let snapshot = vec![
            (Vec::new(), 2u32.to_le_bytes().to_vec()),
            (vec![1, 2, 3], b"short".to_vec()),
            (overlong_key, b"bad".to_vec()),
            (valid_key, b"ok".to_vec()),
        ];
        store.restore(&snapshot);

        assert_eq!(store.len(), 1);
        assert_eq!(store.get_partial(b"g", 1), Some(Bytes::from_static(b"ok")));
    }

    /// Concurrent publishers on distinct partitions all land in the store.
    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let store = Arc::new(CrossPartitionAggregateStore::new(4));
        let mut handles = vec![];

        // One writer thread per partition, each publishing 100 groups.
        for partition in 0..4u32 {
            let store = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                for i in 0..100u32 {
                    let group = format!("group_{i}");
                    let value = format!("p{partition}_v{i}");
                    store.publish(Bytes::from(group), partition, Bytes::from(value));
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // Every partition published a partial for each group.
        let partials = store.collect_partials(b"group_50");
        assert_eq!(partials.len(), 4);
        assert_eq!(store.len(), 400);
    }
}