1use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use tokio::sync::watch;
9
10use super::barrier::{
11 BarrierAck, BarrierAnnouncement, BarrierCoordinator, ClusterKv, Phase, QuorumOutcome,
12};
13use super::leader::leader_of;
14use super::snapshot::AssignmentSnapshotStore;
15use crate::cluster::discovery::{assignable_node_ids, NodeId, NodeInfo, NodeState};
16use crate::state::Locality;
17
/// Per-node handle bundling cluster membership, barrier coordination and a
/// few pieces of locally tracked state (watermark, drain/active flags,
/// unresponsive-peer notes, own locality).
pub struct ClusterController {
    /// Identifier of this node; appended to membership-derived id lists while
    /// `active` is set.
    instance_id: NodeId,
    /// Shared key/value handle; a clone is also handed to the
    /// `BarrierCoordinator` in `new`.
    kv: Arc<dyn ClusterKv>,
    /// Coordinator that the barrier announce/observe/ack/quorum calls are
    /// forwarded to.
    barrier: BarrierCoordinator,
    /// Optional assignment snapshot store, exposed through `snapshot_store`.
    snapshot: Option<Arc<AssignmentSnapshotStore>>,
    /// Watch receiver carrying the current list of member `NodeInfo`s.
    members_rx: watch::Receiver<Vec<NodeInfo>>,
    /// Highest watermark published so far; `i64::MIN` is the
    /// "nothing published yet" sentinel.
    cluster_min_watermark: Arc<AtomicI64>,
    /// Set by `begin_drain`; while set, this node leaves itself out of
    /// `assignable_instances`.
    draining: Arc<AtomicBool>,
    /// Whether this node counts itself in the live/leader/assignable lists;
    /// starts as `true`.
    active: Arc<AtomicBool>,
    /// Raw node id -> instant at which the peer was last noted as
    /// unresponsive.
    unresponsive: Arc<parking_lot::Mutex<rustc_hash::FxHashMap<u64, std::time::Instant>>>,
    /// Locality reported for this node by `assignable_with_locality`; starts
    /// as `Locality::default()`.
    self_locality: parking_lot::RwLock<Locality>,
    /// Slot for the remote query handler; filled by `register_query_handler`
    /// and shared with the barrier server when it is started.
    #[cfg(feature = "cluster")]
    query_handler: super::query::QueryHandlerSlot,
    /// Pool returned by `query_client_pool`.
    // NOTE(review): presumably caches per-peer query clients; confirm against
    // `super::query`.
    #[cfg(feature = "cluster")]
    query_client_pool: super::query::QueryClientPool,
}
55
56impl std::fmt::Debug for ClusterController {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("ClusterController")
59 .field("instance_id", &self.instance_id)
60 .finish_non_exhaustive()
61 }
62}
63
64impl ClusterController {
65 #[must_use]
67 pub fn new(
68 instance_id: NodeId,
69 kv: Arc<dyn ClusterKv>,
70 snapshot: Option<Arc<AssignmentSnapshotStore>>,
71 members_rx: watch::Receiver<Vec<NodeInfo>>,
72 ) -> Self {
73 let mut barrier = BarrierCoordinator::new(Arc::clone(&kv));
74 #[cfg(feature = "cluster")]
75 barrier.set_leader_election(instance_id, members_rx.clone());
76 Self {
77 instance_id,
78 barrier,
79 kv,
80 snapshot,
81 members_rx,
82 cluster_min_watermark: Arc::new(AtomicI64::new(i64::MIN)),
83 draining: Arc::new(AtomicBool::new(false)),
84 active: Arc::new(AtomicBool::new(true)),
85 unresponsive: Arc::new(parking_lot::Mutex::new(rustc_hash::FxHashMap::default())),
86 self_locality: parking_lot::RwLock::new(Locality::default()),
87 #[cfg(feature = "cluster")]
88 query_handler: Arc::new(parking_lot::RwLock::new(None)),
89 #[cfg(feature = "cluster")]
90 query_client_pool: Arc::new(parking_lot::Mutex::new(rustc_hash::FxHashMap::default())),
91 }
92 }
93
/// Installs `handler` in the query-handler slot, replacing any handler that
/// was registered before.
#[cfg(feature = "cluster")]
pub fn register_query_handler(&self, handler: Arc<dyn super::query::RemoteQueryHandler>) {
    let mut slot = self.query_handler.write();
    *slot = Some(handler);
}
99
/// Borrows the controller's query client pool.
#[cfg(feature = "cluster")]
#[must_use]
pub fn query_client_pool(&self) -> &super::query::QueryClientPool {
    let Self {
        query_client_pool, ..
    } = self;
    query_client_pool
}
106
107 #[must_use]
111 pub fn cluster_min_watermark(&self) -> Option<i64> {
112 let v = self.cluster_min_watermark.load(Ordering::Acquire);
113 if v == i64::MIN {
114 None
115 } else {
116 Some(v)
117 }
118 }
119
120 pub fn publish_cluster_min_watermark(&self, wm: i64) {
128 let mut cur = self.cluster_min_watermark.load(Ordering::Acquire);
129 while wm > cur {
130 match self.cluster_min_watermark.compare_exchange(
131 cur,
132 wm,
133 Ordering::AcqRel,
134 Ordering::Acquire,
135 ) {
136 Ok(_) => break,
137 Err(observed) => cur = observed,
138 }
139 }
140 }
141
142 #[must_use]
144 pub fn instance_id(&self) -> NodeId {
145 self.instance_id
146 }
147
148 #[must_use]
151 pub fn kv(&self) -> &Arc<dyn ClusterKv> {
152 &self.kv
153 }
154
155 #[must_use]
157 pub fn current_leader(&self) -> Option<NodeId> {
158 let members = self.members_rx.borrow();
159 let mut ids: Vec<NodeId> = members
160 .iter()
161 .filter(|m| matches!(m.state, NodeState::Active))
162 .map(|m| m.id)
163 .collect();
164 if self.active.load(Ordering::SeqCst) {
166 ids.push(self.instance_id);
167 }
168 leader_of(&ids)
169 }
170
171 #[must_use]
173 pub fn is_leader(&self) -> bool {
174 self.current_leader() == Some(self.instance_id)
175 }
176
177 pub fn set_active(&self, active: bool) {
179 self.active.store(active, Ordering::SeqCst);
180 }
181
182 #[must_use]
184 pub fn live_instances(&self) -> Vec<NodeId> {
185 let mut ids: Vec<NodeId> = self
186 .members_rx
187 .borrow()
188 .iter()
189 .filter(|m| matches!(m.state, NodeState::Active))
190 .map(|m| m.id)
191 .collect();
192 if self.active.load(Ordering::SeqCst) {
193 ids.push(self.instance_id);
194 }
195 ids
196 }
197
198 pub fn note_unresponsive(&self, peers: &[NodeId]) {
200 let now = std::time::Instant::now();
201 let mut map = self.unresponsive.lock();
202 for p in peers {
203 map.insert(p.0, now);
204 }
205 }
206
207 pub fn note_responsive(&self, peers: &[NodeId]) {
209 let mut map = self.unresponsive.lock();
210 for p in peers {
211 map.remove(&p.0);
212 }
213 }
214
215 #[must_use]
217 pub fn is_recently_unresponsive(&self, peer: NodeId) -> bool {
218 const UNRESPONSIVE_TTL: Duration = Duration::from_secs(60);
220 self.unresponsive
221 .lock()
222 .get(&peer.0)
223 .is_some_and(|at| at.elapsed() < UNRESPONSIVE_TTL)
224 }
225
226 pub fn begin_drain(&self) {
228 self.draining.store(true, Ordering::SeqCst);
229 }
230
231 #[must_use]
233 pub fn is_draining(&self) -> bool {
234 self.draining.load(Ordering::SeqCst)
235 }
236
237 #[must_use]
242 pub fn assignable_instances(&self) -> Vec<NodeId> {
243 let mut ids = assignable_node_ids(&self.members_rx.borrow());
244 if self.active.load(Ordering::SeqCst)
245 && !self.is_draining()
246 && !self.instance_id.is_unassigned()
247 {
248 ids.push(self.instance_id);
249 }
250 ids.sort_unstable();
251 ids.dedup();
252 ids
253 }
254
255 pub fn set_self_locality(&self, locality: Locality) {
257 *self.self_locality.write() = locality;
258 }
259
260 #[must_use]
263 pub fn assignable_with_locality(&self) -> Vec<(NodeId, Locality)> {
264 let members = self.members_rx.borrow();
265 self.assignable_instances()
266 .into_iter()
267 .map(|id| {
268 let locality = if id == self.instance_id {
269 self.self_locality.read().clone()
270 } else {
271 members
272 .iter()
273 .find(|m| m.id == id)
274 .and_then(|m| m.metadata.failure_domain.as_deref())
275 .map(Locality::parse)
276 .unwrap_or_default()
277 };
278 (id, locality)
279 })
280 .collect()
281 }
282
283 #[must_use]
287 pub fn members_watch(&self) -> watch::Receiver<Vec<NodeInfo>> {
288 self.members_rx.clone()
289 }
290
291 pub async fn announce_snapshot_version(&self, version: u64) {
293 self.kv
294 .write("control:snapshot-version", version.to_string())
295 .await;
296 }
297
298 pub async fn read_snapshot_version(&self) -> Option<u64> {
300 let scans = self.kv.scan("control:snapshot-version").await;
301 scans
302 .into_iter()
303 .filter_map(|(_, v)| v.parse::<u64>().ok())
304 .max()
305 }
306
/// Starts the barrier coordinator's server on `bind_addr`, passing along
/// `advertise_host` and a shared handle to the query-handler slot.
///
/// Returns whatever address or error string the coordinator reports.
#[cfg(feature = "cluster")]
pub async fn start_barrier_server(
    &self,
    bind_addr: std::net::SocketAddr,
    advertise_host: Option<String>,
) -> Result<std::net::SocketAddr, String> {
    let handler_slot = Arc::clone(&self.query_handler);
    self.barrier
        .start_server(bind_addr, advertise_host, handler_slot)
        .await
}
321
322 pub async fn announce_barrier(&self, ann: &BarrierAnnouncement) -> Result<(), String> {
327 self.barrier.announce(ann).await
328 }
329
330 pub async fn observe_barrier(&self) -> Result<Option<BarrierAnnouncement>, String> {
342 let Some(leader) = self.current_leader() else {
343 return Ok(None);
344 };
345 let observed = self.barrier.observe(leader).await?;
346 if let Some(ref ann) = observed {
347 if matches!(ann.phase, Phase::Commit | Phase::Aligned) {
348 if let Some(wm) = ann.min_watermark_ms {
349 let mut cur = self.cluster_min_watermark.load(Ordering::Acquire);
352 while wm > cur {
353 match self.cluster_min_watermark.compare_exchange(
354 cur,
355 wm,
356 Ordering::AcqRel,
357 Ordering::Acquire,
358 ) {
359 Ok(_) => break,
360 Err(observed) => cur = observed,
361 }
362 }
363 }
364 }
365 }
366 Ok(observed)
367 }
368
369 pub async fn ack_barrier(&self, ack: &BarrierAck) -> Result<(), String> {
374 self.barrier.ack(ack).await
375 }
376
/// Waits until `observe_barrier` yields an announcement accepted by `pred`,
/// or until `timeout` has elapsed.
///
/// Each round observes the barrier first. Errors and rejected
/// announcements are treated as "not yet". The round then waits for
/// whichever comes first:
/// - a change notification on the coordinator's announcement watch, when
///   one exists,
/// - the poll interval: 250 ms while a watch is available, 25 ms otherwise,
/// - the overall deadline, which ends the wait with `None`.
///
/// Once the watch reports its sender is gone, it is dropped and the method
/// falls back to the short polling interval.
#[cfg(feature = "cluster")]
pub async fn wait_for_barrier<F>(
    &self,
    mut pred: F,
    timeout: Duration,
) -> Option<BarrierAnnouncement>
where
    F: FnMut(&BarrierAnnouncement) -> bool,
{
    let mut watch = self.barrier.announcement_watch();
    let deadline = tokio::time::Instant::now() + timeout;
    loop {
        match self.observe_barrier().await {
            Ok(Some(ann)) if pred(&ann) => return Some(ann),
            _ => {}
        }
        if tokio::time::Instant::now() >= deadline {
            return None;
        }

        // Push notifications make frequent polling unnecessary; without
        // them, poll tightly.
        let poll = Duration::from_millis(if watch.is_some() { 250 } else { 25 });
        let pushed = async {
            if let Some(rx) = watch.as_mut() {
                rx.changed().await.is_ok()
            } else {
                // No watch: this branch must never win the select.
                std::future::pending::<bool>().await
            }
        };
        tokio::select! {
            sender_alive = pushed => {
                if !sender_alive {
                    watch = None;
                }
            }
            () = tokio::time::sleep(poll) => {}
            () = tokio::time::sleep_until(deadline) => return None,
        }
    }
}
431
432 pub async fn wait_for_quorum(
434 &self,
435 epoch: u64,
436 expected: &[NodeId],
437 deadline: Duration,
438 ) -> QuorumOutcome {
439 self.barrier
440 .wait_for_quorum(epoch, expected, deadline)
441 .await
442 }
443
444 #[must_use]
446 pub fn snapshot_store(&self) -> Option<&AssignmentSnapshotStore> {
447 self.snapshot.as_deref()
448 }
449}
450
451#[cfg(test)]
452mod tests {
453 use super::*;
454 use crate::cluster::control::barrier::InMemoryKv;
455 use crate::cluster::discovery::{NodeMetadata, NodeState};
456
457 fn info(id: u64) -> NodeInfo {
458 NodeInfo {
459 id: NodeId(id),
460 name: format!("n{id}"),
461 rpc_address: String::new(),
462 raft_address: String::new(),
463 state: NodeState::Active,
464 metadata: NodeMetadata::default(),
465 last_heartbeat_ms: 0,
466 }
467 }
468
469 fn ctl(self_id: u64, peers: Vec<NodeInfo>) -> ClusterController {
470 let (_tx, rx) = watch::channel(peers);
471 let kv: Arc<dyn ClusterKv> = Arc::new(InMemoryKv::new(NodeId(self_id)));
472 ClusterController::new(NodeId(self_id), kv, None, rx)
473 }
474
/// Node 1 next to peers 5 and 7 must consider itself the leader.
#[test]
fn is_leader_when_lowest_id() {
    let peers = vec![info(5), info(7)];
    let controller = ctl(1, peers);
    assert!(controller.is_leader());
}
480
/// Node 7 next to peers 3 and 5 is a follower, and node 3 is the leader.
#[test]
fn follower_when_peer_has_lower_id() {
    let controller = ctl(7, vec![info(3), info(5)]);
    assert_eq!(controller.current_leader(), Some(NodeId(3)));
    assert!(!controller.is_leader());
}
487
/// A node with no peers at all is its own leader.
#[test]
fn solo_instance_is_leader() {
    let no_peers: Vec<NodeInfo> = Vec::new();
    let controller = ctl(42, no_peers);
    assert!(controller.is_leader());
}
493
/// A draining peer is never assignable, and this node drops out of the
/// assignable set once it starts draining itself.
#[test]
fn assignable_instances_excludes_draining_peer_and_self_on_drain() {
    let draining_peer = {
        let mut peer = info(5);
        peer.state = NodeState::Draining;
        peer
    };
    let controller = ctl(1, vec![info(3), draining_peer]);

    // Before draining: self and the active peer, sorted.
    assert!(!controller.is_draining());
    assert_eq!(
        controller.assignable_instances(),
        vec![NodeId(1), NodeId(3)]
    );

    controller.begin_drain();

    // After draining: only the active peer remains.
    assert!(controller.is_draining());
    assert_eq!(controller.assignable_instances(), vec![NodeId(3)]);
}
509
/// This node reports the locality it was given, and a peer reports the
/// locality parsed from its `failure_domain` metadata.
#[test]
fn assignable_with_locality_attaches_self_and_peer_domains() {
    let mut peer = info(3);
    peer.metadata.failure_domain = Some("region=r;zone=z2".to_string());
    let controller = ctl(1, vec![peer]);
    controller.set_self_locality(Locality::parse("region=r;zone=z1"));

    let pairs = controller.assignable_with_locality();
    let ids: Vec<NodeId> = pairs.iter().map(|(id, _)| *id).collect();
    assert_eq!(ids, vec![NodeId(1), NodeId(3)]);

    // The id assertion above pins the order: self first, then the peer.
    let (self_loc, peer_loc) = (&pairs[0].1, &pairs[1].1);
    assert_eq!(self_loc.domain_at(1), "r;z1");
    assert_eq!(peer_loc.domain_at(1), "r;z2");
}
527
/// With no locality set on this node and no failure domain on the peer,
/// both entries report an empty top-level domain.
#[test]
fn assignable_with_locality_defaults_unlabeled_to_empty_domain() {
    let pairs = ctl(1, vec![info(3)]).assignable_with_locality();
    assert_eq!(pairs.len(), 2);
    for (_, locality) in &pairs {
        assert!(locality.domain_at(0).is_empty());
    }
}
537
538 #[tokio::test]
539 async fn announce_observe_roundtrip_when_alone() {
540 let c = ctl(1, vec![]);
543 c.announce_barrier(&BarrierAnnouncement {
544 epoch: 5,
545 checkpoint_id: 1,
546 phase: crate::cluster::control::Phase::Prepare,
547 flags: 0,
548 min_watermark_ms: None,
549 })
550 .await
551 .unwrap();
552 let got = c.observe_barrier().await.unwrap().unwrap();
553 assert_eq!(got.epoch, 5);
554 }
555
/// Publishing only ever raises the watermark: lower or equal values leave
/// it untouched.
#[test]
fn publish_cluster_min_watermark_is_monotonic() {
    let controller = ctl(1, vec![]);
    assert_eq!(controller.cluster_min_watermark(), None);

    // (value published, watermark expected afterwards)
    let steps = [(100, 100), (250, 250), (42, 250), (250, 250)];
    for (published, expected) in steps {
        controller.publish_cluster_min_watermark(published);
        assert_eq!(controller.cluster_min_watermark(), Some(expected));
    }
}
578
579 #[tokio::test]
580 async fn observe_commit_publishes_cluster_min_watermark() {
581 let c = ctl(1, vec![]);
585 assert_eq!(c.cluster_min_watermark(), None, "uninitialised");
586
587 c.announce_barrier(&BarrierAnnouncement {
588 epoch: 9,
589 checkpoint_id: 1,
590 phase: crate::cluster::control::Phase::Commit,
591 flags: 0,
592 min_watermark_ms: Some(12_345),
593 })
594 .await
595 .unwrap();
596 c.observe_barrier().await.unwrap();
597 assert_eq!(c.cluster_min_watermark(), Some(12_345));
598
599 c.announce_barrier(&BarrierAnnouncement {
602 epoch: 10,
603 checkpoint_id: 2,
604 phase: crate::cluster::control::Phase::Commit,
605 flags: 0,
606 min_watermark_ms: Some(100), })
608 .await
609 .unwrap();
610 c.observe_barrier().await.unwrap();
611 assert_eq!(
612 c.cluster_min_watermark(),
613 Some(12_345),
614 "stale Commit must not lower the published watermark",
615 );
616
617 c.announce_barrier(&BarrierAnnouncement {
619 epoch: 11,
620 checkpoint_id: 3,
621 phase: crate::cluster::control::Phase::Prepare,
622 flags: 0,
623 min_watermark_ms: None,
624 })
625 .await
626 .unwrap();
627 c.observe_barrier().await.unwrap();
628 assert_eq!(c.cluster_min_watermark(), Some(12_345));
629 }
630}