laminar_core/shuffle/
barrier_tracker.rs1use parking_lot::Mutex;
6use rustc_hash::{FxHashMap, FxHashSet};
7
8use crate::checkpoint::barrier::CheckpointBarrier;
9
10#[derive(Debug)]
12struct Pending {
13 barrier: CheckpointBarrier,
14 seen: FxHashSet<usize>,
15}
16
17pub struct BarrierTracker {
20 inputs: usize,
21 state: Mutex<FxHashMap<u64, Pending>>,
22}
23
24impl std::fmt::Debug for BarrierTracker {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("BarrierTracker")
27 .field("inputs", &self.inputs)
28 .field("pending_epochs", &self.state.lock().len())
29 .finish()
30 }
31}
32
33impl BarrierTracker {
34 #[must_use]
41 pub fn new(inputs: usize) -> Self {
42 assert!(inputs > 0, "BarrierTracker needs at least one input");
43 Self {
44 inputs,
45 state: Mutex::new(FxHashMap::default()),
46 }
47 }
48
49 pub fn observe(
60 &self,
61 from_input: usize,
62 barrier: CheckpointBarrier,
63 ) -> Option<CheckpointBarrier> {
64 assert!(
65 from_input < self.inputs,
66 "input {from_input} >= inputs {}",
67 self.inputs,
68 );
69 let mut state = self.state.lock();
70 let entry = state
71 .entry(barrier.checkpoint_id)
72 .or_insert_with(|| Pending {
73 barrier,
74 seen: FxHashSet::default(),
75 });
76 entry.seen.insert(from_input);
77 if entry.seen.len() == self.inputs {
78 let out = entry.barrier;
79 state.remove(&barrier.checkpoint_id);
80 Some(out)
81 } else {
82 None
83 }
84 }
85
86 #[must_use]
88 pub fn pending(&self) -> usize {
89 self.state.lock().len()
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint::barrier::flags;

    /// Builds a full-snapshot barrier for checkpoint `cp` at `epoch`.
    fn b(cp: u64, epoch: u64) -> CheckpointBarrier {
        CheckpointBarrier {
            checkpoint_id: cp,
            epoch,
            flags: flags::FULL_SNAPSHOT,
        }
    }

    #[test]
    fn aligns_when_every_input_observed() {
        let t = BarrierTracker::new(3);
        assert!(t.observe(0, b(1, 1)).is_none());
        assert!(t.observe(1, b(1, 1)).is_none());
        let fired = t.observe(2, b(1, 1)).expect("aligned");
        assert_eq!(fired.checkpoint_id, 1);
        assert_eq!(t.pending(), 0, "state cleaned up post-alignment");
    }

    #[test]
    fn duplicate_observation_is_idempotent() {
        let t = BarrierTracker::new(2);
        assert!(t.observe(0, b(5, 2)).is_none());
        assert!(t.observe(0, b(5, 2)).is_none(), "repeat for input 0 no-op");
        let fired = t.observe(1, b(5, 2)).expect("aligned");
        assert_eq!(fired.checkpoint_id, 5);
    }

    #[test]
    fn independent_checkpoints_align_independently() {
        let t = BarrierTracker::new(2);
        assert!(t.observe(0, b(10, 4)).is_none());
        assert!(t.observe(0, b(11, 5)).is_none());
        assert_eq!(t.pending(), 2);
        assert_eq!(t.observe(1, b(10, 4)).unwrap().checkpoint_id, 10);
        assert_eq!(t.observe(1, b(11, 5)).unwrap().checkpoint_id, 11);
        assert_eq!(t.pending(), 0);
    }

    /// Boundary case: with a single input, the first observation completes
    /// alignment and leaves no residual state.
    #[test]
    fn single_input_aligns_immediately() {
        let t = BarrierTracker::new(1);
        let fired = t.observe(0, b(7, 3)).expect("aligned on first observation");
        assert_eq!(fired.checkpoint_id, 7);
        assert_eq!(fired.epoch, 3);
        assert_eq!(t.pending(), 0);
    }

    /// A later checkpoint may complete before an earlier one; each keeps its
    /// own bookkeeping.
    #[test]
    fn later_checkpoint_can_align_first() {
        let t = BarrierTracker::new(2);
        assert!(t.observe(0, b(20, 8)).is_none());
        assert!(t.observe(0, b(21, 9)).is_none());
        assert_eq!(t.observe(1, b(21, 9)).unwrap().checkpoint_id, 21);
        assert_eq!(t.pending(), 1, "checkpoint 20 still waiting");
        assert_eq!(t.observe(1, b(20, 8)).unwrap().checkpoint_id, 20);
        assert_eq!(t.pending(), 0);
    }

    #[test]
    #[should_panic(expected = "input 9 >= inputs 2")]
    fn observe_rejects_out_of_range_input() {
        let t = BarrierTracker::new(2);
        let _ = t.observe(9, b(1, 1));
    }

    #[test]
    #[should_panic(expected = "at least one input")]
    fn zero_inputs_rejected() {
        let _ = BarrierTracker::new(0);
    }
}