1use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use parking_lot::Mutex;
9use rustc_hash::{FxHashMap, FxHashSet};
10use serde::{Deserialize, Serialize};
11
12use crate::cluster::discovery::NodeId;
13
/// KV key under which a barrier announcement is published as JSON
/// (written by `BarrierCoordinator::announce`, read by `observe`).
pub const ANNOUNCEMENT_KEY: &str = "control:barrier";

/// KV key under which each node publishes its barrier ack as JSON
/// (written by `BarrierCoordinator::ack`, scanned by `wait_for_quorum`).
pub const ACK_KEY: &str = "control:barrier-ack";
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22pub enum Phase {
23 Prepare,
25 Commit,
27 Abort,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
33pub struct BarrierAnnouncement {
34 pub epoch: u64,
36 pub checkpoint_id: u64,
38 pub phase: Phase,
40 pub flags: u64,
42 #[serde(default)]
55 pub min_watermark_ms: Option<i64>,
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
60pub struct BarrierAck {
61 pub epoch: u64,
63 pub ok: bool,
65 pub error: Option<String>,
67 #[serde(default)]
76 pub local_watermark_ms: Option<i64>,
77}
78
79#[derive(Debug, Clone, PartialEq, Eq)]
81pub enum QuorumOutcome {
82 Reached {
84 acks: Vec<NodeId>,
86 min_follower_watermark_ms: Option<i64>,
92 },
93 TimedOut {
95 got: Vec<NodeId>,
97 missing: Vec<NodeId>,
99 },
100 Failed {
102 failures: Vec<(NodeId, String)>,
104 },
105}
106
107#[async_trait]
109pub trait ClusterKv: Send + Sync + 'static {
110 async fn write(&self, key: &str, value: String);
112 async fn read_from(&self, who: NodeId, key: &str) -> Option<String>;
114 async fn scan(&self, key: &str) -> Vec<(NodeId, String)>;
116}
117
118#[derive(Debug)]
120pub struct InMemoryKv {
121 local_id: NodeId,
122 state: Mutex<FxHashMap<(NodeId, String), String>>,
123}
124
125impl InMemoryKv {
126 #[must_use]
128 pub fn new(local_id: NodeId) -> Self {
129 Self {
130 local_id,
131 state: Mutex::new(FxHashMap::default()),
132 }
133 }
134
135 pub fn seed(&self, peer: NodeId, key: &str, value: String) {
137 self.state.lock().insert((peer, key.to_string()), value);
138 }
139}
140
141#[async_trait]
142impl ClusterKv for InMemoryKv {
143 async fn write(&self, key: &str, value: String) {
144 self.state
145 .lock()
146 .insert((self.local_id, key.to_string()), value);
147 }
148
149 async fn read_from(&self, who: NodeId, key: &str) -> Option<String> {
150 self.state.lock().get(&(who, key.to_string())).cloned()
151 }
152
153 async fn scan(&self, key: &str) -> Vec<(NodeId, String)> {
154 self.state
155 .lock()
156 .iter()
157 .filter(|((_, k), _)| k == key)
158 .map(|((n, _), v)| (*n, v.clone()))
159 .collect()
160 }
161}
162
163pub struct BarrierCoordinator {
165 kv: Arc<dyn ClusterKv>,
166}
167
168impl std::fmt::Debug for BarrierCoordinator {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 f.debug_struct("BarrierCoordinator").finish_non_exhaustive()
171 }
172}
173
174impl BarrierCoordinator {
175 #[must_use]
177 pub fn new(kv: Arc<dyn ClusterKv>) -> Self {
178 Self { kv }
179 }
180
181 pub async fn announce(&self, ann: &BarrierAnnouncement) -> Result<(), String> {
186 let json = serde_json::to_string(ann).map_err(|e| e.to_string())?;
187 self.kv.write(ANNOUNCEMENT_KEY, json).await;
188 Ok(())
189 }
190
191 pub async fn observe(&self, leader: NodeId) -> Result<Option<BarrierAnnouncement>, String> {
196 match self.kv.read_from(leader, ANNOUNCEMENT_KEY).await {
197 Some(json) => serde_json::from_str(&json)
198 .map(Some)
199 .map_err(|e| e.to_string()),
200 None => Ok(None),
201 }
202 }
203
204 pub async fn ack(&self, ack: &BarrierAck) -> Result<(), String> {
209 let json = serde_json::to_string(ack).map_err(|e| e.to_string())?;
210 self.kv.write(ACK_KEY, json).await;
211 Ok(())
212 }
213
214 pub async fn wait_for_quorum(
216 &self,
217 epoch: u64,
218 expected: &[NodeId],
219 deadline: Duration,
220 ) -> QuorumOutcome {
221 let start = Instant::now();
222 let expected_set: FxHashSet<NodeId> = expected.iter().copied().collect();
223 let mut successful: Vec<NodeId> = Vec::new();
224 let mut failures: Vec<(NodeId, String)> = Vec::new();
225 let mut min_follower_wm: Option<i64>;
226
227 loop {
228 successful.clear();
229 failures.clear();
230 min_follower_wm = None;
231
232 for (from, json) in self.kv.scan(ACK_KEY).await {
233 if !expected_set.contains(&from) {
234 continue;
235 }
236 let Ok(ack) = serde_json::from_str::<BarrierAck>(&json) else {
237 continue;
238 };
239 if ack.epoch != epoch {
240 continue;
241 }
242 if ack.ok {
243 successful.push(from);
244 if let Some(wm) = ack.local_watermark_ms {
245 min_follower_wm = Some(match min_follower_wm {
246 Some(cur) => cur.min(wm),
247 None => wm,
248 });
249 }
250 } else {
251 failures.push((from, ack.error.unwrap_or_default()));
252 }
253 }
254
255 if !failures.is_empty() {
256 return QuorumOutcome::Failed { failures };
257 }
258 if successful.len() == expected.len() {
259 return QuorumOutcome::Reached {
260 acks: successful,
261 min_follower_watermark_ms: min_follower_wm,
262 };
263 }
264 if start.elapsed() >= deadline {
265 let got: FxHashSet<NodeId> = successful.iter().copied().collect();
266 let missing: Vec<NodeId> = expected
267 .iter()
268 .copied()
269 .filter(|n| !got.contains(n))
270 .collect();
271 return QuorumOutcome::TimedOut {
272 got: successful,
273 missing,
274 };
275 }
276 tokio::time::sleep(Duration::from_millis(50)).await;
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shared in-memory store whose local node is `id`.
    fn kv(id: NodeId) -> Arc<InMemoryKv> {
        let store = InMemoryKv::new(id);
        Arc::new(store)
    }

    /// Serializes a watermark-less ack for `epoch`. An `error` of `None`
    /// yields a successful ack; `Some(msg)` yields a failed one carrying `msg`.
    fn ack_json(epoch: u64, error: Option<&str>) -> String {
        let ack = BarrierAck {
            epoch,
            ok: error.is_none(),
            error: error.map(String::from),
            local_watermark_ms: None,
        };
        serde_json::to_string(&ack).unwrap()
    }

    // An announcement written through the coordinator can be read back from
    // the same node.
    #[tokio::test]
    async fn leader_announces_follower_observes() {
        let coord = BarrierCoordinator::new(kv(NodeId(1)));
        let announcement = BarrierAnnouncement {
            epoch: 5,
            checkpoint_id: 42,
            phase: Phase::Prepare,
            flags: 0,
            min_watermark_ms: None,
        };
        coord.announce(&announcement).await.unwrap();

        let got = coord.observe(NodeId(1)).await.unwrap().unwrap();
        assert_eq!(got.epoch, 5);
        assert_eq!(got.checkpoint_id, 42);
    }

    // With nothing published, `observe` reports `None` rather than an error.
    #[tokio::test]
    async fn observe_returns_none_when_leader_silent() {
        let coord = BarrierCoordinator::new(kv(NodeId(1)));
        let seen = coord.observe(NodeId(1)).await.unwrap();
        assert!(seen.is_none());
    }

    // Both expected followers have acked the epoch: quorum is reached.
    #[tokio::test]
    async fn quorum_reached_when_all_ack_success() {
        let store = kv(NodeId(1));
        for follower in [NodeId(2), NodeId(3)] {
            store.seed(follower, ACK_KEY, ack_json(7, None));
        }

        let coord = BarrierCoordinator::new(store);
        let outcome = coord
            .wait_for_quorum(7, &[NodeId(2), NodeId(3)], Duration::from_millis(200))
            .await;
        match outcome {
            QuorumOutcome::Reached {
                mut acks,
                min_follower_watermark_ms,
            } => {
                // Scan order is unspecified, so normalize before comparing.
                acks.sort_by_key(|n| n.0);
                assert_eq!(acks, vec![NodeId(2), NodeId(3)]);
                assert_eq!(
                    min_follower_watermark_ms, None,
                    "no follower reported a watermark — min is None"
                );
            }
            other => panic!("expected Reached, got {other:?}"),
        }
    }

    // Only one of two expected followers acked: the wait times out and names
    // the silent node.
    #[tokio::test]
    async fn quorum_timeout_when_follower_silent() {
        let store = kv(NodeId(1));
        store.seed(NodeId(2), ACK_KEY, ack_json(8, None));

        let coord = BarrierCoordinator::new(store);
        let outcome = coord
            .wait_for_quorum(8, &[NodeId(2), NodeId(3)], Duration::from_millis(150))
            .await;
        match outcome {
            QuorumOutcome::TimedOut { got, missing } => {
                assert_eq!(got, vec![NodeId(2)]);
                assert_eq!(missing, vec![NodeId(3)]);
            }
            other => panic!("expected TimedOut, got {other:?}"),
        }
    }

    // A single failed ack ends the wait with `Failed`, long before the
    // two-second deadline.
    #[tokio::test]
    async fn quorum_fails_fast_on_reported_error() {
        let store = kv(NodeId(1));
        store.seed(NodeId(2), ACK_KEY, ack_json(9, None));
        store.seed(
            NodeId(3),
            ACK_KEY,
            ack_json(9, Some("state snapshot failed: disk full")),
        );

        let coord = BarrierCoordinator::new(store);
        let outcome = coord
            .wait_for_quorum(9, &[NodeId(2), NodeId(3)], Duration::from_secs(2))
            .await;
        match outcome {
            QuorumOutcome::Failed { failures } => {
                assert_eq!(failures.len(), 1);
                let (node, message) = &failures[0];
                assert_eq!(*node, NodeId(3));
                assert!(message.contains("disk full"));
            }
            other => panic!("expected Failed, got {other:?}"),
        }
    }

    // An ack left over from epoch 9 must not count towards epoch 10.
    #[tokio::test]
    async fn wrong_epoch_ack_is_ignored() {
        let store = kv(NodeId(1));
        store.seed(NodeId(2), ACK_KEY, ack_json(9, None));

        let coord = BarrierCoordinator::new(store);
        let outcome = coord
            .wait_for_quorum(10, &[NodeId(2)], Duration::from_millis(100))
            .await;
        assert!(
            matches!(outcome, QuorumOutcome::TimedOut { .. }),
            "stale-epoch ack must not satisfy quorum"
        );
    }
}