1use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use parking_lot::Mutex;
9use rustc_hash::{FxHashMap, FxHashSet};
10use serde::{Deserialize, Serialize};
11
12use crate::cluster::discovery::NodeId;
13#[cfg(feature = "cluster")]
14use crate::cluster::discovery::{NodeInfo, NodeState};
15#[cfg(feature = "cluster")]
16use tokio::sync::watch;
17
/// KV key under which [`BarrierCoordinator::announce`] publishes the
/// JSON-encoded [`BarrierAnnouncement`].
pub const ANNOUNCEMENT_KEY: &str = "control:barrier";

/// KV key under which [`BarrierCoordinator::ack`] publishes the JSON-encoded
/// [`BarrierAck`] when no gRPC server is running on this node.
pub const ACK_KEY: &str = "control:barrier-ack";

/// KV key under which each node publishes the address of its barrier gRPC
/// server (written by `start_server`, read when dialing peers).
#[cfg(feature = "cluster")]
pub const BARRIER_ADDR_KEY: &str = "barrier:addr";
27
/// Phase carried by a [`BarrierAnnouncement`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Phase {
    /// Followers are asked to acknowledge the epoch (see [`BarrierAck`]).
    Prepare,
    /// Follow-up to a successful prepare; may carry `min_watermark_ms`.
    /// A failed Aligned RPC is only logged by `announce`.
    Aligned,
    /// The epoch is committed; may carry `min_watermark_ms`.
    Commit,
    /// The epoch is abandoned.
    Abort,
}
43
/// Message announcing a barrier phase for one epoch.
///
/// Serialized as JSON under [`ANNOUNCEMENT_KEY`], and rebuilt from gRPC
/// requests on the receiving side when the `cluster` feature is enabled.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BarrierAnnouncement {
    /// Epoch this announcement belongs to.
    pub epoch: u64,
    /// Identifier of the checkpoint associated with the epoch.
    pub checkpoint_id: u64,
    /// Phase being announced.
    pub phase: Phase,
    /// Flag bits; this module copies them through without interpreting them.
    pub flags: u64,
    /// Optional watermark in milliseconds. Missing in the serialized form
    /// deserializes as `None` (`serde(default)`).
    #[serde(default)]
    pub min_watermark_ms: Option<i64>,
}
70
/// A follower's reply to a Prepare for a given epoch.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BarrierAck {
    /// Epoch being acknowledged.
    pub epoch: u64,
    /// `true` on success; `false` makes `wait_for_quorum` report a failure.
    pub ok: bool,
    /// Error text accompanying a negative ack, if any.
    pub error: Option<String>,
    /// Optional follower-local watermark in milliseconds. Missing in the
    /// serialized form deserializes as `None` (`serde(default)`).
    #[serde(default)]
    pub local_watermark_ms: Option<i64>,
}
91
/// Result of [`BarrierCoordinator::wait_for_quorum`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuorumOutcome {
    /// Every expected node acknowledged successfully.
    Reached {
        /// Nodes that acknowledged.
        acks: Vec<NodeId>,
        /// Smallest `local_watermark_ms` reported by any ack, or `None` when
        /// no ack carried a watermark.
        min_follower_watermark_ms: Option<i64>,
    },
    /// The deadline elapsed (or peers were unreachable) before all expected
    /// nodes acknowledged.
    TimedOut {
        /// Nodes that did acknowledge successfully.
        got: Vec<NodeId>,
        /// Expected nodes that did not.
        missing: Vec<NodeId>,
    },
    /// At least one node reported a failure.
    Failed {
        /// Failing nodes paired with their error text.
        failures: Vec<(NodeId, String)>,
    },
}
119
/// Per-node key/value view of the cluster used by the barrier protocol.
///
/// The semantics described below are those of [`InMemoryKv`], the only
/// implementation visible in this file; other backends are presumably
/// equivalent — TODO confirm.
#[async_trait]
pub trait ClusterKv: Send + Sync + 'static {
    /// Stores `value` under `key` on behalf of the local node.
    async fn write(&self, key: &str, value: String);
    /// Returns the value node `who` stored under `key`, if any.
    async fn read_from(&self, who: NodeId, key: &str) -> Option<String>;
    /// Returns `(node, value)` for every node that has a value under `key`.
    async fn scan(&self, key: &str) -> Vec<(NodeId, String)>;
    /// Returns `(node, key, value)` for every entry whose key starts with
    /// `prefix`.
    async fn scan_prefix(&self, prefix: &str) -> Vec<(NodeId, String, String)>;
    /// Capability flag; defaults to `true`.
    // NOTE(review): the meaning of "subscription routing" is not visible in
    // this file — confirm against callers before relying on it.
    fn supports_subscription_routing(&self) -> bool {
        true
    }
}
137
/// In-process [`ClusterKv`] backed by a mutex-guarded hash map keyed by
/// `(node, key)`.
#[derive(Debug)]
pub struct InMemoryKv {
    /// Node that `write` calls are attributed to.
    local_id: NodeId,
    /// All stored values, keyed by `(owner node, key)`.
    state: Mutex<FxHashMap<(NodeId, String), String>>,
}
144
145impl InMemoryKv {
146 #[must_use]
148 pub fn new(local_id: NodeId) -> Self {
149 Self {
150 local_id,
151 state: Mutex::new(FxHashMap::default()),
152 }
153 }
154
155 pub fn seed(&self, peer: NodeId, key: &str, value: String) {
157 self.state.lock().insert((peer, key.to_string()), value);
158 }
159}
160
161#[async_trait]
162impl ClusterKv for InMemoryKv {
163 async fn write(&self, key: &str, value: String) {
164 self.state
165 .lock()
166 .insert((self.local_id, key.to_string()), value);
167 }
168
169 async fn read_from(&self, who: NodeId, key: &str) -> Option<String> {
170 self.state.lock().get(&(who, key.to_string())).cloned()
171 }
172
173 async fn scan(&self, key: &str) -> Vec<(NodeId, String)> {
174 self.state
175 .lock()
176 .iter()
177 .filter(|((_, k), _)| k == key)
178 .map(|((n, _), v)| (*n, v.clone()))
179 .collect()
180 }
181
182 async fn scan_prefix(&self, prefix: &str) -> Vec<(NodeId, String, String)> {
183 self.state
184 .lock()
185 .iter()
186 .filter(|((_, k), _)| k.starts_with(prefix))
187 .map(|((n, k), v)| (*n, k.clone(), v.clone()))
188 .collect()
189 }
190}
191
/// Bindings for the `laminar.barrier.v1` protobuf package, pulled in through
/// `tonic::include_proto!`.
// The lint allowances below apply to the included code, which this crate
// does not author by hand.
#[cfg(feature = "cluster")]
#[allow(
    clippy::doc_markdown,
    clippy::default_trait_access,
    clippy::missing_const_for_fn,
    clippy::must_use_candidate,
    clippy::too_many_lines,
    missing_docs
)]
pub(crate) mod barrier_v1 {
    tonic::include_proto!("laminar.barrier.v1");
}
204
/// Channel flavor of the announcement queue between the gRPC handlers and
/// the relay task spawned by `start_server`.
#[cfg(feature = "cluster")]
type BarrierFlavor = crossfire::mpsc::Array<BarrierAnnouncement>;

/// Shared cache of gRPC clients, one per peer node. Entries are inserted by
/// `get_barrier_client` and removed when an RPC to that peer fails.
#[cfg(feature = "cluster")]
type BarrierClientPool = Arc<
    parking_lot::Mutex<
        FxHashMap<
            NodeId,
            barrier_v1::barrier_sync_client::BarrierSyncClient<tonic::transport::Channel>,
        >,
    >,
>;
219
/// Everything `BarrierCoordinator` keeps once its gRPC server is running.
#[cfg(feature = "cluster")]
struct GrpcState {
    /// Most recent announcement forwarded by the relay task (`None` until
    /// the first one arrives).
    latest_rx: watch::Receiver<Option<BarrierAnnouncement>>,
    /// Sender side of the handler-to-relay queue. Never read from here.
    // NOTE(review): presumably retained so the queue stays open for the
    // lifetime of this state — confirm.
    #[allow(dead_code)]
    incoming_tx: crossfire::MAsyncTx<BarrierFlavor>,
    /// Per-epoch waiters registered by in-flight Prepare RPCs; `ack`
    /// completes them.
    pending_acks: Arc<parking_lot::Mutex<FxHashMap<u64, tokio::sync::oneshot::Sender<BarrierAck>>>>,
    /// Acks recorded by `ack`, keyed by epoch, so a later Prepare RPC can be
    /// answered immediately. Pruned on Commit/Abort.
    completed_acks: Arc<parking_lot::Mutex<FxHashMap<u64, BarrierAck>>>,
    /// Cached outbound clients, keyed by peer.
    clients: BarrierClientPool,
    /// Task serving the gRPC endpoints; aborted when the coordinator drops.
    server_handle: Arc<parking_lot::Mutex<Option<tokio::task::JoinHandle<()>>>>,
    /// Task copying queued announcements into `latest_rx`; aborted when the
    /// coordinator drops.
    relay_handle: Arc<parking_lot::Mutex<Option<tokio::task::JoinHandle<()>>>>,
    /// Address published under `BARRIER_ADDR_KEY`; also used to recognise
    /// this node's own entry when scanning that key.
    advertise_addr: String,
}
237
/// Optional leader-election input: this node's id plus a watch on the
/// current member list. `None` until `set_leader_election` is called.
#[cfg(feature = "cluster")]
type ActiveLeaderState = Option<(NodeId, watch::Receiver<Vec<NodeInfo>>)>;
240
/// Server-side handler state for the `BarrierSync` gRPC service.
#[cfg(feature = "cluster")]
struct GrpcBarrierServer {
    /// KV used to discover live nodes when no leader-election state is set.
    kv: Arc<dyn ClusterKv>,
    /// Queue on which received announcements are handed to the relay task.
    incoming_tx: crossfire::MAsyncTx<BarrierFlavor>,
    /// Per-epoch waiters for in-flight Prepare RPCs (shared with `GrpcState`).
    pending_acks: Arc<parking_lot::Mutex<FxHashMap<u64, tokio::sync::oneshot::Sender<BarrierAck>>>>,
    /// Already-recorded acks keyed by epoch (shared with `GrpcState`).
    completed_acks: Arc<parking_lot::Mutex<FxHashMap<u64, BarrierAck>>>,
    /// Leader-election input shared with the owning `BarrierCoordinator`.
    leader_election: Arc<parking_lot::Mutex<ActiveLeaderState>>,
}
249
#[cfg(feature = "cluster")]
impl GrpcBarrierServer {
    /// Checks that the `x-leader-id` request header names the node this
    /// server currently considers the leader.
    ///
    /// The leader is `super::leader_of` applied to either the active members
    /// from the election watch plus this instance (when election state is
    /// set), or every node that published `BARRIER_ADDR_KEY` otherwise.
    ///
    /// # Errors
    /// `PermissionDenied` when the header is missing, is not a decimal
    /// `u64`, or names a different node.
    async fn validate_leader(
        &self,
        metadata: &tonic::metadata::MetadataMap,
    ) -> Result<(), tonic::Status> {
        let invalid = || tonic::Status::permission_denied("Invalid leader identity");

        let Some(raw_header) = metadata.get("x-leader-id") else {
            return Err(tonic::Status::permission_denied("Missing leader identity"));
        };
        let claimed = raw_header
            .to_str()
            .ok()
            .and_then(|text| text.parse::<u64>().ok())
            .map(NodeId)
            .ok_or_else(invalid)?;

        // Snapshot the election state so the lock is not held while scanning.
        let snapshot = self.leader_election.lock().clone();

        let candidates: Vec<NodeId> = match snapshot {
            Some((instance_id, members_rx)) => {
                let members = members_rx.borrow();
                let ids: Vec<NodeId> = members
                    .iter()
                    .filter(|m| matches!(m.state, NodeState::Active))
                    .map(|m| m.id)
                    .chain(std::iter::once(instance_id))
                    .collect();
                ids
            }
            None => {
                let published = self.kv.scan(BARRIER_ADDR_KEY).await;
                published.into_iter().map(|(id, _)| id).collect()
            }
        };

        if Some(claimed) == super::leader_of(&candidates) {
            Ok(())
        } else {
            Err(tonic::Status::permission_denied(
                "Sender is not the observed leader",
            ))
        }
    }
}
296
297#[cfg(feature = "cluster")]
298#[tonic::async_trait]
299impl barrier_v1::barrier_sync_server::BarrierSync for GrpcBarrierServer {
300 async fn prepare(
301 &self,
302 request: tonic::Request<barrier_v1::PrepareRequest>,
303 ) -> Result<tonic::Response<barrier_v1::Ack>, tonic::Status> {
304 let validation_res = self.validate_leader(request.metadata()).await;
305 let req = request.into_inner();
306
307 {
308 let mut completed = self.completed_acks.lock();
309 if let Some(ack) = completed.remove(&req.epoch) {
310 validation_res?;
311 return Ok(tonic::Response::new(barrier_v1::Ack {
312 epoch: ack.epoch,
313 ok: ack.ok,
314 error: ack.error,
315 local_watermark_ms: ack.local_watermark_ms,
316 }));
317 }
318 }
319
320 let (tx, rx) = tokio::sync::oneshot::channel::<BarrierAck>();
321
322 {
323 let mut guard = self.pending_acks.lock();
324 guard.insert(req.epoch, tx);
325 }
326
327 if let Err(status) = validation_res {
328 let mut guard = self.pending_acks.lock();
329 guard.remove(&req.epoch);
330 return Err(status);
331 }
332
333 let ann = BarrierAnnouncement {
334 epoch: req.epoch,
335 checkpoint_id: req.checkpoint_id,
336 phase: Phase::Prepare,
337 flags: req.flags,
338 min_watermark_ms: None,
339 };
340
341 if self.incoming_tx.send(ann).await.is_err() {
342 let mut guard = self.pending_acks.lock();
343 guard.remove(&req.epoch);
344 return Err(tonic::Status::aborted("Follower coordinator shutdown"));
345 }
346
347 match tokio::time::timeout(Duration::from_secs(30), rx).await {
348 Ok(Ok(ack)) => Ok(tonic::Response::new(barrier_v1::Ack {
349 epoch: ack.epoch,
350 ok: ack.ok,
351 error: ack.error,
352 local_watermark_ms: ack.local_watermark_ms,
353 })),
354 Ok(Err(_)) => Err(tonic::Status::internal("Ack sender dropped")),
355 Err(_) => {
356 let mut guard = self.pending_acks.lock();
357 guard.remove(&req.epoch);
358 Err(tonic::Status::deadline_exceeded(
359 "Follower checkpoint prepare timed out",
360 ))
361 }
362 }
363 }
364
365 async fn aligned(
366 &self,
367 request: tonic::Request<barrier_v1::AlignedRequest>,
368 ) -> Result<tonic::Response<barrier_v1::Ack>, tonic::Status> {
369 self.validate_leader(request.metadata()).await?;
370 let req = request.into_inner();
371
372 let ann = BarrierAnnouncement {
376 epoch: req.epoch,
377 checkpoint_id: req.checkpoint_id,
378 phase: Phase::Aligned,
379 flags: req.flags,
380 min_watermark_ms: req.min_watermark_ms,
381 };
382 if self.incoming_tx.send(ann).await.is_err() {
383 return Err(tonic::Status::aborted("Follower coordinator shutdown"));
384 }
385 Ok(tonic::Response::new(barrier_v1::Ack {
386 epoch: req.epoch,
387 ok: true,
388 error: None,
389 local_watermark_ms: None,
390 }))
391 }
392
393 async fn commit(
394 &self,
395 request: tonic::Request<barrier_v1::CommitRequest>,
396 ) -> Result<tonic::Response<barrier_v1::Ack>, tonic::Status> {
397 self.validate_leader(request.metadata()).await?;
398 let req = request.into_inner();
399
400 {
401 let mut completed = self.completed_acks.lock();
402 completed.remove(&req.epoch);
403 completed.retain(|&epoch, _| epoch >= req.epoch);
404 }
405
406 let ann = BarrierAnnouncement {
407 epoch: req.epoch,
408 checkpoint_id: req.checkpoint_id,
409 phase: Phase::Commit,
410 flags: req.flags,
411 min_watermark_ms: req.min_watermark_ms,
412 };
413 if self.incoming_tx.send(ann).await.is_err() {
414 return Err(tonic::Status::aborted("Follower coordinator shutdown"));
415 }
416 Ok(tonic::Response::new(barrier_v1::Ack {
417 epoch: req.epoch,
418 ok: true,
419 error: None,
420 local_watermark_ms: None,
421 }))
422 }
423
424 async fn abort(
425 &self,
426 request: tonic::Request<barrier_v1::AbortRequest>,
427 ) -> Result<tonic::Response<barrier_v1::Ack>, tonic::Status> {
428 self.validate_leader(request.metadata()).await?;
429 let req = request.into_inner();
430
431 {
432 let mut completed = self.completed_acks.lock();
433 completed.remove(&req.epoch);
434 completed.retain(|&epoch, _| epoch >= req.epoch);
435 }
436
437 let ann = BarrierAnnouncement {
438 epoch: req.epoch,
439 checkpoint_id: req.checkpoint_id,
440 phase: Phase::Abort,
441 flags: req.flags,
442 min_watermark_ms: None,
443 };
444 if self.incoming_tx.send(ann).await.is_err() {
445 return Err(tonic::Status::aborted("Follower coordinator shutdown"));
446 }
447 Ok(tonic::Response::new(barrier_v1::Ack {
448 epoch: req.epoch,
449 ok: true,
450 error: None,
451 local_watermark_ms: None,
452 }))
453 }
454}
455
/// Returns a gRPC client for `peer`, reusing the pooled one when present.
///
/// On a pool miss the peer's address is read from `BARRIER_ADDR_KEY`, a
/// lazily-connecting channel is created for it, and the new client is
/// pooled. Returns `None` when the peer has published no address or the
/// address cannot be turned into an endpoint.
#[cfg(feature = "cluster")]
async fn get_barrier_client(
    peer: NodeId,
    pool: &BarrierClientPool,
    kv: &Arc<dyn ClusterKv>,
) -> Option<barrier_v1::barrier_sync_client::BarrierSyncClient<tonic::transport::Channel>> {
    let pooled = pool.lock().get(&peer).cloned();
    if let Some(hit) = pooled {
        return Some(hit);
    }

    let addr_str = kv.read_from(peer, BARRIER_ADDR_KEY).await?;
    let channel = super::tls::client_endpoint(&addr_str).ok()?.connect_lazy();
    let fresh = barrier_v1::barrier_sync_client::BarrierSyncClient::new(channel);

    let for_pool = fresh.clone();
    pool.lock().insert(peer, for_pool);
    Some(fresh)
}
474
/// Adds an `x-leader-id` metadata entry holding `local_id` in decimal.
///
/// Does nothing when `local_id` is `None` or the value cannot be converted
/// into a metadata value.
#[cfg(feature = "cluster")]
fn stamp_leader_id<T>(req: &mut tonic::Request<T>, local_id: Option<NodeId>) {
    let Some(NodeId(raw_id)) = local_id else {
        return;
    };
    let Ok(header) = raw_id.to_string().parse() else {
        return;
    };
    req.metadata_mut().insert("x-leader-id", header);
}
485
/// Sends the RPC matching `ann.phase` (Aligned, Commit or Abort) to `peer`,
/// stamping `local_id` as the leader identity.
///
/// `Phase::Prepare` sends nothing and succeeds once a client was obtained.
///
/// # Errors
/// Returns a description when no client can be obtained for the peer or the
/// RPC fails; on RPC failure the peer's pooled client is evicted.
#[cfg(feature = "cluster")]
async fn send_phase_rpc(
    peer: NodeId,
    clients_pool: BarrierClientPool,
    kv: Arc<dyn ClusterKv>,
    ann: BarrierAnnouncement,
    local_id: Option<NodeId>,
) -> Result<(), String> {
    let Some(mut client) = get_barrier_client(peer, &clients_pool, &kv).await else {
        return Err(format!("failed to get client for peer {}", peer.0));
    };

    let BarrierAnnouncement {
        epoch,
        checkpoint_id,
        phase,
        flags,
        min_watermark_ms,
    } = ann;

    let (rpc, outcome) = match phase {
        Phase::Prepare => return Ok(()),
        Phase::Aligned => {
            let mut req = tonic::Request::new(barrier_v1::AlignedRequest {
                epoch,
                checkpoint_id,
                flags,
                min_watermark_ms,
            });
            stamp_leader_id(&mut req, local_id);
            ("aligned", client.aligned(req).await.map(drop))
        }
        Phase::Commit => {
            let mut req = tonic::Request::new(barrier_v1::CommitRequest {
                epoch,
                checkpoint_id,
                flags,
                min_watermark_ms,
            });
            stamp_leader_id(&mut req, local_id);
            ("commit", client.commit(req).await.map(drop))
        }
        Phase::Abort => {
            let mut req = tonic::Request::new(barrier_v1::AbortRequest {
                epoch,
                checkpoint_id,
                flags,
            });
            stamp_leader_id(&mut req, local_id);
            ("abort", client.abort(req).await.map(drop))
        }
    };

    match outcome {
        Ok(()) => Ok(()),
        Err(e) => {
            // Evict the client so the next attempt re-resolves the peer.
            clients_pool.lock().remove(&peer);
            Err(format!("{rpc} RPC to peer {} failed: {e}", peer.0))
        }
    }
}
548
/// Why a Prepare RPC to a peer did not produce a positive ack.
#[cfg(feature = "cluster")]
enum PeerFailure {
    /// No client, a timeout, or a transport-style status; reported as
    /// "missing" in `QuorumOutcome::TimedOut`.
    Unreachable,
    /// The peer (or the RPC layer) reported an error; reported in
    /// `QuorumOutcome::Failed` with this message.
    Nack(String),
}
558
/// Coordinates barrier announcements and acknowledgements between nodes,
/// over the cluster KV and — once `start_server` has run with the `cluster`
/// feature — over gRPC.
pub struct BarrierCoordinator {
    /// Cluster KV used for announcements, acks and address discovery.
    kv: Arc<dyn ClusterKv>,
    /// gRPC state; `None` until `start_server` succeeds.
    #[cfg(feature = "cluster")]
    grpc: Arc<parking_lot::Mutex<Option<Arc<GrpcState>>>>,
    /// Leader-election input shared with the gRPC server for sender checks.
    #[cfg(feature = "cluster")]
    leader_election: Arc<parking_lot::Mutex<ActiveLeaderState>>,
}
567
568impl std::fmt::Debug for BarrierCoordinator {
569 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
570 f.debug_struct("BarrierCoordinator").finish_non_exhaustive()
571 }
572}
573
574impl Drop for BarrierCoordinator {
575 fn drop(&mut self) {
576 #[cfg(feature = "cluster")]
577 {
578 let grpc_opt = self.grpc.lock().take();
579 if let Some(state) = grpc_opt {
580 let handle_opt = state.server_handle.lock().take();
581 if let Some(handle) = handle_opt {
582 handle.abort();
583 }
584 let relay_opt = state.relay_handle.lock().take();
585 if let Some(handle) = relay_opt {
586 handle.abort();
587 }
588 }
589 }
590 }
591}
592
593impl BarrierCoordinator {
594 #[must_use]
596 pub fn new(kv: Arc<dyn ClusterKv>) -> Self {
597 Self {
598 kv,
599 #[cfg(feature = "cluster")]
600 grpc: Arc::new(parking_lot::Mutex::new(None)),
601 #[cfg(feature = "cluster")]
602 leader_election: Arc::new(parking_lot::Mutex::new(None)),
603 }
604 }
605
606 #[cfg(feature = "cluster")]
608 pub fn set_leader_election(
609 &mut self,
610 instance_id: NodeId,
611 members_rx: watch::Receiver<Vec<NodeInfo>>,
612 ) {
613 *self.leader_election.lock() = Some((instance_id, members_rx));
614 }
615
616 #[cfg(feature = "cluster")]
617 async fn local_node_id(&self) -> Option<NodeId> {
618 let grpc_opt = self.grpc.lock().clone();
619 let state = grpc_opt?;
620 let local_addr_str = state.advertise_addr.clone();
621 for (node_id, addr) in self.kv.scan(BARRIER_ADDR_KEY).await {
622 if addr == local_addr_str {
623 return Some(node_id);
624 }
625 }
626 None
627 }
628
    /// Binds `bind_addr`, starts serving the barrier and query gRPC services
    /// on it, and publishes the advertised address under `BARRIER_ADDR_KEY`.
    ///
    /// The advertised address is `advertise_host:port` when a host is given;
    /// otherwise, for an unspecified bind IP, `hostname:port` (falling back
    /// to the socket address if the hostname is empty); otherwise the bound
    /// socket address.
    ///
    /// Returns the locally bound socket address.
    ///
    /// # Errors
    /// Returns the stringified error when binding or converting the listener
    /// fails, or when the TLS configuration is rejected.
    #[cfg(feature = "cluster")]
    pub async fn start_server(
        &self,
        bind_addr: std::net::SocketAddr,
        advertise_host: Option<String>,
        query_handler: super::QueryHandlerSlot,
    ) -> Result<std::net::SocketAddr, String> {
        use super::query::query_service_server;
        use barrier_v1::barrier_sync_server::BarrierSyncServer;
        use std::net::TcpListener;
        use tonic::transport::Server;

        // Bind synchronously so the actual port (e.g. for port 0) is known
        // before anything is spawned, then hand the socket to tokio.
        let listener = TcpListener::bind(bind_addr).map_err(|e| e.to_string())?;
        let local_addr = listener.local_addr().map_err(|e| e.to_string())?;
        listener.set_nonblocking(true).map_err(|e| e.to_string())?;
        let tokio_listener =
            tokio::net::TcpListener::from_std(listener).map_err(|e| e.to_string())?;

        // Bounded queue (capacity 128) from the RPC handlers to the relay task.
        let (incoming_tx, incoming_rx) = crossfire::mpsc::bounded_async::<BarrierAnnouncement>(128);
        let pending_acks = Arc::new(parking_lot::Mutex::new(FxHashMap::default()));
        let completed_acks = Arc::new(parking_lot::Mutex::new(FxHashMap::default()));
        let clients = Arc::new(parking_lot::Mutex::new(FxHashMap::default()));

        // The handler shares the ack maps and election state with `self`.
        let server_impl = GrpcBarrierServer {
            kv: Arc::clone(&self.kv),
            incoming_tx: incoming_tx.clone(),
            pending_acks: Arc::clone(&pending_acks),
            completed_acks: Arc::clone(&completed_acks),
            leader_election: Arc::clone(&self.leader_election),
        };

        let query_svc = query_service_server(query_handler);
        let mut builder = Server::builder();
        // TLS is applied only when the shared helper provides a config.
        if let Some(tls) = super::tls::server_tls() {
            builder = builder
                .tls_config(tls.clone())
                .map_err(|e| format!("cluster control-plane TLS config: {e}"))?;
        }
        let router = builder
            .add_service(BarrierSyncServer::new(server_impl))
            .add_service(query_svc);
        // Serve in the background; the serve result is discarded.
        let server_task = tokio::spawn(async move {
            let incoming_stream = tokio_stream::wrappers::TcpListenerStream::new(tokio_listener);
            let _ = router.serve_with_incoming(incoming_stream).await;
        });

        let advertise_addr = if let Some(ref host) = advertise_host {
            format!("{host}:{}", local_addr.port())
        } else if local_addr.ip().is_unspecified() {
            // An unspecified IP is not dialable by peers; advertise the hostname.
            let hostname = gethostname::gethostname();
            let hostname = hostname.to_string_lossy();
            if hostname.is_empty() {
                local_addr.to_string()
            } else {
                format!("{hostname}:{}", local_addr.port())
            }
        } else {
            local_addr.to_string()
        };

        let (latest_tx, latest_rx) = watch::channel::<Option<BarrierAnnouncement>>(None);
        // Relay: copy each queued announcement into the watch channel until
        // the queue is closed. Only the latest value is retained.
        let relay_task = tokio::spawn(async move {
            while let Ok(ann) = incoming_rx.recv().await {
                let _ = latest_tx.send(Some(ann));
            }
        });

        let grpc_state = Arc::new(GrpcState {
            latest_rx,
            incoming_tx,
            pending_acks,
            completed_acks,
            clients,
            server_handle: Arc::new(parking_lot::Mutex::new(Some(server_task))),
            relay_handle: Arc::new(parking_lot::Mutex::new(Some(relay_task))),
            advertise_addr: advertise_addr.clone(),
        });

        *self.grpc.lock() = Some(grpc_state);

        // Publish the address last, after the state is installed.
        self.kv.write(BARRIER_ADDR_KEY, advertise_addr).await;

        Ok(local_addr)
    }
726
727 pub async fn announce(&self, ann: &BarrierAnnouncement) -> Result<(), String> {
732 #[cfg(feature = "cluster")]
733 {
734 let grpc_opt = self.grpc.lock().clone();
735 if let Some(state) = grpc_opt {
736 let local_id = self.local_node_id().await;
737 if ann.phase == Phase::Prepare {
738 } else {
741 let mut expected = Vec::new();
742 for (node_id, addr) in self.kv.scan(BARRIER_ADDR_KEY).await {
743 if addr == state.advertise_addr {
744 continue;
745 }
746 expected.push(node_id);
747 }
748
749 let mut futures = Vec::new();
750 for peer in expected {
751 let clients_pool = Arc::clone(&state.clients);
752 let kv = Arc::clone(&self.kv);
753 let ann_clone = ann.clone();
754 futures.push(send_phase_rpc(peer, clients_pool, kv, ann_clone, local_id));
755 }
756 let results = futures::future::join_all(futures).await;
757 for res in results {
758 match res {
759 Ok(()) => {}
760 Err(e) if ann.phase == Phase::Aligned => {
766 tracing::warn!(
767 epoch = ann.epoch,
768 error = %e,
769 "aligned announcement RPC failed; peer resumes on Commit"
770 );
771 }
772 Err(e) => return Err(e),
773 }
774 }
775 }
776
777 let json = serde_json::to_string(ann).map_err(|e| e.to_string())?;
778 self.kv.write(ANNOUNCEMENT_KEY, json).await;
779 return Ok(());
780 }
781 }
782
783 let json = serde_json::to_string(ann).map_err(|e| e.to_string())?;
784 self.kv.write(ANNOUNCEMENT_KEY, json).await;
785 Ok(())
786 }
787
788 #[cfg(feature = "cluster")]
793 #[must_use]
794 pub fn announcement_watch(&self) -> Option<watch::Receiver<Option<BarrierAnnouncement>>> {
795 self.grpc.lock().as_ref().map(|s| s.latest_rx.clone())
796 }
797
798 pub async fn observe(&self, leader: NodeId) -> Result<Option<BarrierAnnouncement>, String> {
810 #[cfg(feature = "cluster")]
811 let grpc_latest: Option<BarrierAnnouncement> = {
812 let grpc_opt = self.grpc.lock().clone();
813 grpc_opt.and_then(|state| state.latest_rx.borrow().clone())
814 };
815 #[cfg(not(feature = "cluster"))]
816 let grpc_latest: Option<BarrierAnnouncement> = None;
817
818 let kv_latest: Option<BarrierAnnouncement> =
819 match self.kv.read_from(leader, ANNOUNCEMENT_KEY).await {
820 Some(json) => match serde_json::from_str(&json) {
821 Ok(a) => Some(a),
822 Err(e) if grpc_latest.is_some() => {
825 tracing::warn!(error = %e, "corrupt gossip announcement; using gRPC value");
826 None
827 }
828 Err(e) => return Err(e.to_string()),
829 },
830 None => None,
831 };
832
833 Ok(match (grpc_latest, kv_latest) {
834 (Some(g), Some(k)) => {
835 if k.epoch > g.epoch {
836 Some(k)
837 } else {
838 Some(g)
839 }
840 }
841 (Some(g), None) => Some(g),
842 (None, k) => k,
843 })
844 }
845
846 pub async fn ack(&self, ack: &BarrierAck) -> Result<(), String> {
851 #[cfg(feature = "cluster")]
852 {
853 let grpc_opt = self.grpc.lock().clone();
854 if let Some(state) = grpc_opt {
855 {
856 let mut completed = state.completed_acks.lock();
857 completed.insert(ack.epoch, ack.clone());
858 }
859 let tx_opt = {
860 let mut guard = state.pending_acks.lock();
861 guard.remove(&ack.epoch)
862 };
863 if let Some(tx) = tx_opt {
864 let _ = tx.send(ack.clone());
865 }
866 return Ok(());
867 }
868 }
869
870 let json = serde_json::to_string(ack).map_err(|e| e.to_string())?;
871 self.kv.write(ACK_KEY, json).await;
872 Ok(())
873 }
874
875 #[allow(clippy::too_many_lines)]
877 pub async fn wait_for_quorum(
880 &self,
881 epoch: u64,
882 expected: &[NodeId],
883 deadline: Duration,
884 ) -> QuorumOutcome {
885 #[cfg(feature = "cluster")]
886 {
887 let grpc_opt = self.grpc.lock().clone();
888 if let Some(state) = grpc_opt {
889 let checkpoint_id =
890 match self
891 .kv
892 .scan(ANNOUNCEMENT_KEY)
893 .await
894 .into_iter()
895 .find(|(_, json)| {
896 serde_json::from_str::<BarrierAnnouncement>(json)
897 .is_ok_and(|a| a.epoch == epoch)
898 }) {
899 Some((_, json)) => serde_json::from_str::<BarrierAnnouncement>(&json)
900 .map_or(0, |a| a.checkpoint_id),
901 None => 0,
902 };
903
904 let local_id = self.local_node_id().await;
905 let mut futures = Vec::new();
906 for &peer in expected {
907 let clients_pool = Arc::clone(&state.clients);
908 let kv = Arc::clone(&self.kv);
909 futures.push(async move {
910 let client_opt = get_barrier_client(peer, &clients_pool, &kv).await;
911 let Some(mut client) = client_opt else {
912 return Err((peer, PeerFailure::Unreachable));
913 };
914
915 let mut req = tonic::Request::new(barrier_v1::PrepareRequest {
916 epoch,
917 checkpoint_id,
918 flags: 0,
919 });
920 stamp_leader_id(&mut req, local_id);
921
922 match tokio::time::timeout(deadline, client.prepare(req)).await {
923 Ok(Ok(response)) => {
924 let ack = response.into_inner();
925 if ack.ok {
926 Ok((peer, ack.local_watermark_ms))
927 } else {
928 Err((
929 peer,
930 PeerFailure::Nack(ack.error.unwrap_or_else(|| {
931 "Unknown prepare failure".to_string()
932 })),
933 ))
934 }
935 }
936 Ok(Err(status)) => {
937 clients_pool.lock().remove(&peer);
938 match status.code() {
944 tonic::Code::Unavailable
945 | tonic::Code::DeadlineExceeded
946 | tonic::Code::Cancelled
947 | tonic::Code::Aborted => Err((peer, PeerFailure::Unreachable)),
948 _ => Err((peer, PeerFailure::Nack(status.to_string()))),
949 }
950 }
951 Err(_) => {
952 clients_pool.lock().remove(&peer);
953 Err((peer, PeerFailure::Unreachable))
954 }
955 }
956 });
957 }
958
959 let results = futures::future::join_all(futures).await;
960
961 let mut successful = Vec::new();
962 let mut failures = Vec::new();
963 let mut min_follower_wm: Option<i64> = None;
964 let mut timed_out = Vec::new();
965
966 for res in results {
967 match res {
968 Ok((peer, wm)) => {
969 successful.push(peer);
970 if let Some(w) = wm {
971 min_follower_wm = Some(match min_follower_wm {
972 Some(cur) => cur.min(w),
973 None => w,
974 });
975 }
976 }
977 Err((peer, PeerFailure::Unreachable)) => timed_out.push(peer),
978 Err((peer, PeerFailure::Nack(msg))) => failures.push((peer, msg)),
979 }
980 }
981
982 if !failures.is_empty() {
983 return QuorumOutcome::Failed { failures };
984 }
985
986 if !timed_out.is_empty() || successful.len() < expected.len() {
987 let got = successful;
988 let mut missing = timed_out;
989 for &peer in expected {
990 if !got.contains(&peer) && !missing.contains(&peer) {
991 missing.push(peer);
992 }
993 }
994 return QuorumOutcome::TimedOut { got, missing };
995 }
996
997 return QuorumOutcome::Reached {
998 acks: successful,
999 min_follower_watermark_ms: min_follower_wm,
1000 };
1001 }
1002 }
1003
1004 let start = Instant::now();
1005 let expected_set: FxHashSet<NodeId> = expected.iter().copied().collect();
1006 let mut successful: Vec<NodeId> = Vec::new();
1007 let mut failures: Vec<(NodeId, String)> = Vec::new();
1008 let mut min_follower_wm: Option<i64>;
1009
1010 loop {
1011 successful.clear();
1012 failures.clear();
1013 min_follower_wm = None;
1014
1015 for (from, json) in self.kv.scan(ACK_KEY).await {
1016 if !expected_set.contains(&from) {
1017 continue;
1018 }
1019 let Ok(ack) = serde_json::from_str::<BarrierAck>(&json) else {
1020 continue;
1021 };
1022 if ack.epoch != epoch {
1023 continue;
1024 }
1025 if ack.ok {
1026 successful.push(from);
1027 if let Some(wm) = ack.local_watermark_ms {
1028 min_follower_wm = Some(match min_follower_wm {
1029 Some(cur) => cur.min(wm),
1030 None => wm,
1031 });
1032 }
1033 } else {
1034 failures.push((from, ack.error.unwrap_or_default()));
1035 }
1036 }
1037
1038 if !failures.is_empty() {
1039 return QuorumOutcome::Failed { failures };
1040 }
1041 if successful.len() == expected.len() {
1042 return QuorumOutcome::Reached {
1043 acks: successful,
1044 min_follower_watermark_ms: min_follower_wm,
1045 };
1046 }
1047 if start.elapsed() >= deadline {
1048 let got: FxHashSet<NodeId> = successful.iter().copied().collect();
1049 let missing: Vec<NodeId> = expected
1050 .iter()
1051 .copied()
1052 .filter(|n| !got.contains(n))
1053 .collect();
1054 return QuorumOutcome::TimedOut {
1055 got: successful,
1056 missing,
1057 };
1058 }
1059 tokio::time::sleep(Duration::from_millis(50)).await;
1060 }
1061 }
1062}
1063
1064#[cfg(test)]
1065mod tests {
1066 use super::*;
1067
1068 fn kv(id: NodeId) -> Arc<InMemoryKv> {
1069 Arc::new(InMemoryKv::new(id))
1070 }
1071
1072 #[cfg(all(test, feature = "cluster"))]
1073 mod grpc_tests {
1074 use super::*;
1075 use std::net::SocketAddr;
1076
1077 async fn wait_observe(
1080 coord: &BarrierCoordinator,
1081 leader: NodeId,
1082 phase: Phase,
1083 ) -> BarrierAnnouncement {
1084 for _ in 0..100 {
1085 if let Some(ann) = coord.observe(leader).await.unwrap() {
1086 if ann.phase == phase {
1087 return ann;
1088 }
1089 }
1090 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1091 }
1092 panic!("timed out waiting for {phase:?} announcement from leader {leader:?}");
1093 }
1094
        /// End-to-end Prepare → Aligned → Commit over gRPC between a leader
        /// (node 1) and one follower (node 2), each with its own KV.
        #[tokio::test]
        async fn test_grpc_barrier_flow() {
            let leader_kv = kv(NodeId(1));
            let follower_kv = kv(NodeId(2));
            let leader_coord = BarrierCoordinator::new(leader_kv.clone());
            let follower_coord = BarrierCoordinator::new(follower_kv.clone());

            // Port 0: let the OS pick free ports for both servers.
            let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let slot = || Arc::new(parking_lot::RwLock::new(None));
            let leader_addr = leader_coord.start_server(addr, None, slot()).await.unwrap();
            let bound_addr = follower_coord
                .start_server(addr, None, slot())
                .await
                .unwrap();

            // The KVs are independent, so cross-publish each side's address.
            leader_kv.seed(NodeId(2), BARRIER_ADDR_KEY, bound_addr.to_string());
            follower_kv.seed(NodeId(1), BARRIER_ADDR_KEY, leader_addr.to_string());

            // Lets the leader hold back Commit until the follower saw Aligned.
            let (aligned_seen_tx, aligned_seen_rx) = tokio::sync::oneshot::channel::<()>();

            let follower_task = tokio::spawn(async move {
                let ann = wait_observe(&follower_coord, NodeId(1), Phase::Prepare).await;
                assert_eq!(ann.epoch, 1);
                assert_eq!(ann.checkpoint_id, 42);

                follower_coord
                    .ack(&BarrierAck {
                        epoch: 1,
                        ok: true,
                        error: None,
                        local_watermark_ms: Some(100),
                    })
                    .await
                    .unwrap();

                let aligned_ann = wait_observe(&follower_coord, NodeId(1), Phase::Aligned).await;
                assert_eq!(aligned_ann.epoch, 1);
                assert_eq!(aligned_ann.min_watermark_ms, Some(100));
                aligned_seen_tx.send(()).unwrap();

                let commit_ann = wait_observe(&follower_coord, NodeId(1), Phase::Commit).await;
                assert_eq!(commit_ann.min_watermark_ms, Some(100));
            });

            // Prepare is written to the leader's KV only; the follower learns
            // of it through the Prepare RPC sent by `wait_for_quorum` below.
            leader_coord
                .announce(&BarrierAnnouncement {
                    epoch: 1,
                    checkpoint_id: 42,
                    phase: Phase::Prepare,
                    flags: 0,
                    min_watermark_ms: None,
                })
                .await
                .unwrap();

            let outcome = leader_coord
                .wait_for_quorum(1, &[NodeId(2)], Duration::from_secs(5))
                .await;
            match outcome {
                QuorumOutcome::Reached {
                    acks,
                    min_follower_watermark_ms,
                } => {
                    assert_eq!(acks, vec![NodeId(2)]);
                    assert_eq!(min_follower_watermark_ms, Some(100));

                    leader_coord
                        .announce(&BarrierAnnouncement {
                            epoch: 1,
                            checkpoint_id: 42,
                            phase: Phase::Aligned,
                            flags: 0,
                            min_watermark_ms: min_follower_watermark_ms,
                        })
                        .await
                        .unwrap();
                    // Only the latest announcement is retained, so wait until
                    // the follower has observed Aligned before sending Commit.
                    aligned_seen_rx.await.unwrap();

                    leader_coord
                        .announce(&BarrierAnnouncement {
                            epoch: 1,
                            checkpoint_id: 42,
                            phase: Phase::Commit,
                            flags: 0,
                            min_watermark_ms: min_follower_watermark_ms,
                        })
                        .await
                        .unwrap();
                }
                other => panic!("expected Reached, got {other:?}"),
            }

            follower_task.await.unwrap();
        }
1194 }
1195
    /// `observe` must prefer the KV (gossip) announcement only when its epoch
    /// is strictly newer than the one received over gRPC.
    #[cfg(feature = "cluster")]
    #[tokio::test]
    async fn observe_merges_grpc_and_gossip_by_epoch() {
        let leader_kv = kv(NodeId(1));
        let follower_kv = kv(NodeId(2));
        let leader_coord = BarrierCoordinator::new(leader_kv.clone());
        let follower_coord = BarrierCoordinator::new(follower_kv.clone());

        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let slot = || Arc::new(parking_lot::RwLock::new(None));
        let leader_addr = leader_coord.start_server(addr, None, slot()).await.unwrap();
        let bound_addr = follower_coord
            .start_server(addr, None, slot())
            .await
            .unwrap();
        // The KVs are independent, so cross-publish each side's address.
        leader_kv.seed(NodeId(2), BARRIER_ADDR_KEY, bound_addr.to_string());
        follower_kv.seed(NodeId(1), BARRIER_ADDR_KEY, leader_addr.to_string());

        // Deliver an epoch-5 Abort to the follower over gRPC.
        leader_coord
            .announce(&BarrierAnnouncement {
                epoch: 5,
                checkpoint_id: 9,
                phase: Phase::Abort,
                flags: 0,
                min_watermark_ms: None,
            })
            .await
            .unwrap();
        // Wait (up to ~2 s) for the relay to surface it.
        for _ in 0..100 {
            if let Some(ann) = follower_coord.observe(NodeId(1)).await.unwrap() {
                if ann.phase == Phase::Abort {
                    break;
                }
            }
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }

        // A newer epoch in the follower's KV wins over the gRPC value.
        let next = serde_json::to_string(&BarrierAnnouncement {
            epoch: 6,
            checkpoint_id: 10,
            phase: Phase::Prepare,
            flags: 0,
            min_watermark_ms: None,
        })
        .unwrap();
        follower_kv.seed(NodeId(1), ANNOUNCEMENT_KEY, next);
        let got = follower_coord.observe(NodeId(1)).await.unwrap().unwrap();
        assert_eq!(got.epoch, 6);
        assert_eq!(got.phase, Phase::Prepare);

        // A same-epoch (stale) KV value must not override the gRPC value.
        let stale = serde_json::to_string(&BarrierAnnouncement {
            epoch: 5,
            checkpoint_id: 9,
            phase: Phase::Prepare,
            flags: 0,
            min_watermark_ms: None,
        })
        .unwrap();
        follower_kv.seed(NodeId(1), ANNOUNCEMENT_KEY, stale);
        let got = follower_coord.observe(NodeId(1)).await.unwrap().unwrap();
        assert_eq!(
            got.phase,
            Phase::Abort,
            "lagging gossip must not mask the fresher gRPC announcement",
        );
    }
1275
1276 #[tokio::test]
1277 async fn leader_announces_follower_observes() {
1278 let leader_kv = kv(NodeId(1));
1279 let coord = BarrierCoordinator::new(leader_kv.clone());
1280 coord
1281 .announce(&BarrierAnnouncement {
1282 epoch: 5,
1283 checkpoint_id: 42,
1284 phase: Phase::Prepare,
1285 flags: 0,
1286 min_watermark_ms: None,
1287 })
1288 .await
1289 .unwrap();
1290 let got = coord.observe(NodeId(1)).await.unwrap().unwrap();
1291 assert_eq!(got.epoch, 5);
1292 assert_eq!(got.checkpoint_id, 42);
1293 }
1294
1295 #[tokio::test]
1296 async fn observe_returns_none_when_leader_silent() {
1297 let k = kv(NodeId(1));
1298 let coord = BarrierCoordinator::new(k);
1299 assert!(coord.observe(NodeId(1)).await.unwrap().is_none());
1300 }
1301
1302 #[tokio::test]
1303 async fn quorum_reached_when_all_ack_success() {
1304 let k = kv(NodeId(1));
1305 let ack_json = serde_json::to_string(&BarrierAck {
1306 epoch: 7,
1307 ok: true,
1308 error: None,
1309 local_watermark_ms: None,
1310 })
1311 .unwrap();
1312 k.seed(NodeId(2), ACK_KEY, ack_json.clone());
1313 k.seed(NodeId(3), ACK_KEY, ack_json);
1314
1315 let coord = BarrierCoordinator::new(k);
1316 let outcome = coord
1317 .wait_for_quorum(7, &[NodeId(2), NodeId(3)], Duration::from_millis(200))
1318 .await;
1319 match outcome {
1320 QuorumOutcome::Reached {
1321 mut acks,
1322 min_follower_watermark_ms,
1323 } => {
1324 acks.sort_by_key(|n| n.0);
1325 assert_eq!(acks, vec![NodeId(2), NodeId(3)]);
1326 assert_eq!(
1327 min_follower_watermark_ms, None,
1328 "no follower reported a watermark — min is None"
1329 );
1330 }
1331 other => panic!("expected Reached, got {other:?}"),
1332 }
1333 }
1334
1335 #[tokio::test]
1336 async fn quorum_timeout_when_follower_silent() {
1337 let k = kv(NodeId(1));
1338 let ack_json = serde_json::to_string(&BarrierAck {
1339 epoch: 8,
1340 ok: true,
1341 error: None,
1342 local_watermark_ms: None,
1343 })
1344 .unwrap();
1345 k.seed(NodeId(2), ACK_KEY, ack_json);
1346
1347 let coord = BarrierCoordinator::new(k);
1348 let outcome = coord
1349 .wait_for_quorum(8, &[NodeId(2), NodeId(3)], Duration::from_millis(150))
1350 .await;
1351 match outcome {
1352 QuorumOutcome::TimedOut { got, missing } => {
1353 assert_eq!(got, vec![NodeId(2)]);
1354 assert_eq!(missing, vec![NodeId(3)]);
1355 }
1356 other => panic!("expected TimedOut, got {other:?}"),
1357 }
1358 }
1359
1360 #[tokio::test]
1361 async fn quorum_fails_fast_on_reported_error() {
1362 let k = kv(NodeId(1));
1363 let good = serde_json::to_string(&BarrierAck {
1364 epoch: 9,
1365 ok: true,
1366 error: None,
1367 local_watermark_ms: None,
1368 })
1369 .unwrap();
1370 let bad = serde_json::to_string(&BarrierAck {
1371 epoch: 9,
1372 ok: false,
1373 error: Some("state snapshot failed: disk full".into()),
1374 local_watermark_ms: None,
1375 })
1376 .unwrap();
1377 k.seed(NodeId(2), ACK_KEY, good);
1378 k.seed(NodeId(3), ACK_KEY, bad);
1379
1380 let coord = BarrierCoordinator::new(k);
1381 let outcome = coord
1382 .wait_for_quorum(9, &[NodeId(2), NodeId(3)], Duration::from_secs(2))
1383 .await;
1384 match outcome {
1385 QuorumOutcome::Failed { failures } => {
1386 assert_eq!(failures.len(), 1);
1387 assert_eq!(failures[0].0, NodeId(3));
1388 assert!(failures[0].1.contains("disk full"));
1389 }
1390 other => panic!("expected Failed, got {other:?}"),
1391 }
1392 }
1393
1394 #[tokio::test]
1395 async fn wrong_epoch_ack_is_ignored() {
1396 let k = kv(NodeId(1));
1397 let stale = serde_json::to_string(&BarrierAck {
1398 epoch: 9,
1399 ok: true,
1400 error: None,
1401 local_watermark_ms: None,
1402 })
1403 .unwrap();
1404 k.seed(NodeId(2), ACK_KEY, stale);
1405
1406 let coord = BarrierCoordinator::new(k);
1407 let outcome = coord
1408 .wait_for_quorum(10, &[NodeId(2)], Duration::from_millis(100))
1409 .await;
1410 assert!(
1411 matches!(outcome, QuorumOutcome::TimedOut { .. }),
1412 "stale-epoch ack must not satisfy quorum"
1413 );
1414 }
1415}