1use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use tokio::signal;
8use tokio::sync::watch;
9use tracing::{info, warn};
10
11use laminar_core::cluster::discovery::{
12 Discovery, DiscoveryError, GossipDiscovery, GossipDiscoveryConfig, NodeId, NodeInfo,
13 NodeMetadata, NodeState, StaticDiscovery, StaticDiscoveryConfig,
14};
/// The discovery backend selected for this node.
///
/// Wraps the two concrete discovery implementations so the rest of this file
/// can start, query and stop discovery without caring which one is active.
enum DiscoveryImpl {
    /// Discovery backed by [`StaticDiscovery`].
    Static(StaticDiscovery),
    /// Discovery backed by [`GossipDiscovery`].
    Gossip(GossipDiscovery),
}
20
21impl DiscoveryImpl {
22 async fn start(&mut self) -> Result<(), DiscoveryError> {
23 match self {
24 Self::Static(d) => d.start().await,
25 Self::Gossip(d) => d.start().await,
26 }
27 }
28
29 async fn peers(&self) -> Result<Vec<NodeInfo>, DiscoveryError> {
30 match self {
31 Self::Static(d) => d.peers().await,
32 Self::Gossip(d) => d.peers().await,
33 }
34 }
35
36 fn membership_watch(&self) -> watch::Receiver<Vec<NodeInfo>> {
37 match self {
38 Self::Static(d) => d.membership_watch(),
39 Self::Gossip(d) => d.membership_watch(),
40 }
41 }
42
43 async fn stop(&mut self) -> Result<(), DiscoveryError> {
44 match self {
45 Self::Static(d) => d.stop().await,
46 Self::Gossip(d) => d.stop().await,
47 }
48 }
49}
50
51fn spawn_membership_watcher(
53 local_node_id: &str,
54 mut rx: watch::Receiver<Vec<NodeInfo>>,
55) -> tokio::task::JoinHandle<()> {
56 let local_name = local_node_id.to_string();
57 tokio::spawn(async move {
58 let mut known: HashMap<u64, (String, NodeState)> = HashMap::new();
59 for node in rx.borrow_and_update().iter() {
60 known.insert(node.id.0, (node.name.clone(), node.state));
61 }
62
63 loop {
64 if rx.changed().await.is_err() {
65 info!("[{local_name}] Membership watcher stopping (discovery shut down)");
67 break;
68 }
69
70 let current_peers = rx.borrow_and_update().clone();
71
72 let mut current: HashMap<u64, (String, NodeState)> = HashMap::new();
73 for node in ¤t_peers {
74 current.insert(node.id.0, (node.name.clone(), node.state));
75 }
76
77 for (id, (name, state)) in ¤t {
78 if !known.contains_key(id) {
79 info!(
80 "[{local_name}] Peer joined: '{}' (id={}, state={})",
81 name, id, state
82 );
83 }
84 }
85
86 for (id, (name, old_state)) in &known {
87 if !current.contains_key(id) {
88 if *old_state == NodeState::Suspected {
89 warn!(
90 "[{local_name}] Peer crashed: '{}' (id={}, was suspected)",
91 name, id
92 );
93 } else {
94 warn!(
95 "[{local_name}] Peer left: '{}' (id={}, was {})",
96 name, id, old_state
97 );
98 }
99 }
100 }
101
102 for (id, (name, new_state)) in ¤t {
103 if let Some((_, old_state)) = known.get(id) {
104 if old_state != new_state {
105 let level = match new_state {
106 NodeState::Suspected => "WARN",
107 NodeState::Left | NodeState::Draining => "WARN",
108 _ => "INFO",
109 };
110 if level == "WARN" {
111 warn!(
112 "[{local_name}] Peer state changed: '{}' (id={}) {} -> {}",
113 name, id, old_state, new_state
114 );
115 } else {
116 info!(
117 "[{local_name}] Peer state changed: '{}' (id={}) {} -> {}",
118 name, id, old_state, new_state
119 );
120 }
121 }
122 }
123 }
124
125 known = current;
126 }
127 })
128}
129
130use laminar_db::{LaminarDB, Profile};
131
132use crate::cluster_config::ClusterConfig;
133use crate::config::ServerConfig;
134use crate::server;
135
/// Errors that can be returned while starting (or waiting on) a cluster node.
#[derive(Debug, thiserror::Error)]
pub enum ClusterStartupError {
    /// Discovery could not be started or configured. Also used in this file
    /// for an empty seed list and for a failure to install the Ctrl-C handler.
    #[error("discovery failed: {0}")]
    Discovery(String),
    /// The formation deadline passed before enough peers were discovered.
    #[error("formation timeout: only {found} of {needed} peers discovered")]
    FormationTimeout { found: usize, needed: usize },
    /// Building or starting the engine failed (state backend, assignment
    /// snapshot, shuffle bind, builder, config DDL or pipeline start).
    #[error("engine construction failed: {0}")]
    EngineConstruction(String),
    /// The HTTP API could not be started.
    #[error("HTTP startup failed: {0}")]
    HttpStartup(String),
    /// The configured raft port leaves no room for the RPC port, which is
    /// derived as `raft_port + 1`.
    #[error(
        "invalid coordination.raft_port={0}: RPC port = raft_port + 1 would \
         overflow u16; choose a raft_port below {max}",
        max = u16::MAX
    )]
    InvalidRaftPort(u16),
}
153
/// Owns everything a running cluster node needs to shut down cleanly.
///
/// Returned by `start_cluster`; consumed by `ClusterHandle::wait_for_shutdown`.
pub struct ClusterHandle {
    /// The engine instance; `shutdown()` is called on it during teardown.
    db: Arc<LaminarDB>,
    /// The active discovery backend; stopped during teardown.
    discovery: DiscoveryImpl,
    /// Task handle returned by `server::start_http_api`.
    api_handle: tokio::task::JoinHandle<()>,
    /// Task handle returned by `server::spawn_config_watcher`, if any.
    watcher_handle: Option<tokio::task::JoinHandle<()>>,
    /// Task handle of the membership-logging task.
    membership_handle: tokio::task::JoinHandle<()>,
    /// Handles of the rebalance tasks (snapshot watcher and rebalance
    /// controller); empty when the rebalance control plane was not started.
    rebalance_tasks: Vec<tokio::task::JoinHandle<()>>,
    /// Notifier shared with the rebalance tasks; signalled at shutdown.
    rebalance_shutdown: Arc<tokio::sync::Notify>,
}
169
170impl ClusterHandle {
171 pub async fn wait_for_shutdown(mut self) -> Result<(), ClusterStartupError> {
172 signal::ctrl_c()
173 .await
174 .map_err(|e| ClusterStartupError::Discovery(format!("signal handler: {e}")))?;
175
176 info!("Received shutdown signal, shutting down cluster node...");
177
178 self.rebalance_shutdown.notify_waiters();
182 for task in &self.rebalance_tasks {
183 task.abort();
184 }
185 for task in self.rebalance_tasks.drain(..) {
186 let _ = task.await;
187 }
188
189 self.membership_handle.abort();
191
192 if let Err(e) = self.discovery.stop().await {
194 warn!("Discovery stop error: {e}");
195 }
196
197 if let Some(wh) = &self.watcher_handle {
199 wh.abort();
200 }
201
202 if let Err(e) = self.db.shutdown().await {
204 tracing::warn!("Engine shutdown error: {e}");
205 }
206
207 self.api_handle.abort();
209
210 info!("Cluster node shutdown complete");
211 Ok(())
212 }
213}
214
215pub async fn start_cluster(
217 config: ServerConfig,
218 cluster_cfg: ClusterConfig,
219 config_path: PathBuf,
220) -> Result<ClusterHandle, ClusterStartupError> {
221 let node_id_str = cluster_cfg.node_id.as_str().to_string();
222 let node_id_num = {
225 let h = xxhash_rust::xxh3::xxh3_64(node_id_str.as_bytes());
226 if h == 0 {
228 1
229 } else {
230 h
231 }
232 };
233 let node_id = NodeId(node_id_num);
234
235 let bind_addr = &config.server.bind;
236 let coordination = &cluster_cfg.coordination;
237 let raft_port = coordination.raft_port;
238 let rpc_port = raft_port
239 .checked_add(1)
240 .ok_or(ClusterStartupError::InvalidRaftPort(raft_port))?;
241
242 let bind_host = if let Some(bracket_end) = bind_addr.rfind(']') {
245 &bind_addr[..=bracket_end]
247 } else if let Some(colon) = bind_addr.rfind(':') {
248 &bind_addr[..colon]
250 } else {
251 bind_addr.as_str()
252 };
253
254 let local_node = NodeInfo {
255 id: node_id,
256 name: node_id_str.clone(),
257 rpc_address: format!("{bind_host}:{rpc_port}"),
258 raft_address: format!("{bind_host}:{raft_port}"),
259 state: NodeState::Joining,
260 metadata: NodeMetadata {
261 cores: num_cpus(),
262 memory_bytes: 0,
263 failure_domain: None,
264 tags: std::collections::HashMap::new(),
265 owned_partitions: Vec::new(),
266 version: env!("CARGO_PKG_VERSION").to_string(),
267 },
268 last_heartbeat_ms: 0,
269 };
270
271 info!(
273 "Starting cluster discovery (strategy: {})",
274 cluster_cfg.discovery.strategy
275 );
276
277 let mut discovery: DiscoveryImpl = match cluster_cfg.discovery.strategy.as_str() {
278 "gossip" => {
279 let gossip_config = GossipDiscoveryConfig {
280 gossip_address: format!("0.0.0.0:{}", cluster_cfg.discovery.gossip_port),
281 seed_nodes: cluster_cfg.discovery.seeds.clone(),
282 gossip_interval: std::time::Duration::from_secs(1),
283 phi_threshold: 8.0,
284 dead_node_grace_period: std::time::Duration::from_secs(60),
285 cluster_id: "laminardb".to_string(),
286 node_id,
287 local_node: local_node.clone(),
288 };
289 DiscoveryImpl::Gossip(GossipDiscovery::new(gossip_config))
290 }
291 _ => {
292 let static_config = StaticDiscoveryConfig {
294 local_node: local_node.clone(),
295 seeds: cluster_cfg.discovery.seeds.clone(),
296 heartbeat_interval: std::time::Duration::from_secs(1),
297 suspect_threshold: 3,
298 dead_threshold: 10,
299 listen_address: format!("0.0.0.0:{}", cluster_cfg.discovery.gossip_port),
300 };
301 DiscoveryImpl::Static(StaticDiscovery::new(static_config))
302 }
303 };
304
305 discovery
306 .start()
307 .await
308 .map_err(|e| ClusterStartupError::Discovery(e.to_string()))?;
309 info!("Discovery layer started");
310
311 if cluster_cfg.discovery.seeds.is_empty() {
316 return Err(ClusterStartupError::Discovery(
317 "cluster mode requires at least one seed address".into(),
318 ));
319 }
320 let expected_peers = cluster_cfg.discovery.seeds.len().saturating_sub(1);
321 let deadline = std::time::Instant::now() + cluster_cfg.formation_timeout;
322 let mut last_seen = 0usize;
323 let peers: Vec<NodeInfo> = loop {
324 if let Ok(p) = discovery.peers().await {
325 last_seen = p.len();
326 if p.len() >= expected_peers {
327 break p;
328 }
329 }
330 if std::time::Instant::now() >= deadline {
331 return Err(ClusterStartupError::FormationTimeout {
332 found: last_seen,
333 needed: expected_peers,
334 });
335 }
336 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
337 };
338 info!(
339 "Discovered {} peer(s) (expected {})",
340 peers.len(),
341 expected_peers
342 );
343
344 let mut builder = LaminarDB::builder();
346 builder = builder.profile(Profile::Cluster);
347
348 if let Some(path) = config.state.local_storage_dir() {
349 builder = builder.storage_dir(path);
350 }
351
352 let state_backend: Arc<dyn laminar_core::state::StateBackend> = config
357 .state
358 .build()
359 .await
360 .map_err(|e| ClusterStartupError::EngineConstruction(format!("state backend: {e}")))?;
361 let vnode_count = config.state.vnode_capacity();
362
363 let (vnode_registry, snapshot_store) =
369 resolve_vnode_assignment(node_id, &peers, &config.state).await?;
370
371 let cluster_controller: Option<Arc<laminar_core::cluster::control::ClusterController>> =
374 if let DiscoveryImpl::Gossip(ref gossip) = discovery {
375 if let Some(handle) = gossip.chitchat_handle() {
376 use laminar_core::cluster::control::{ChitchatKv, ClusterController, ClusterKv};
377 let kv: Arc<dyn ClusterKv> = Arc::new(ChitchatKv::from_handle(handle));
378 let members_rx = discovery.membership_watch();
379 let controller = Arc::new(ClusterController::new(
380 node_id,
381 kv,
382 snapshot_store.clone(),
383 members_rx,
384 ));
385 info!(
386 "ClusterController installed (leader={})",
387 controller.is_leader()
388 );
389 builder = builder.cluster_controller(Arc::clone(&controller));
390 Some(controller)
391 } else {
392 None
393 }
394 } else {
395 info!(
396 "Static discovery — cluster control plane skipped \
397 (no chitchat KV). Leader/follower barrier protocol \
398 inactive in this mode."
399 );
400 None
401 };
402
403 let checkpoint_url = {
405 let base = &config.checkpoint.url;
406 if base.is_empty() {
407 String::new()
408 } else if base.ends_with('/') {
409 format!("{base}nodes/{node_id_str}/")
410 } else {
411 format!("{base}/nodes/{node_id_str}/")
412 }
413 };
414 builder = server::apply_checkpoint_config(builder, &checkpoint_url, &config.checkpoint);
415
416 builder = builder
417 .state_backend(Arc::clone(&state_backend))
418 .vnode_registry(Arc::clone(&vnode_registry));
419
420 if let Some(decision_os) = config
427 .state
428 .build_object_store()
429 .map_err(|e| ClusterStartupError::EngineConstruction(format!("decision store: {e}")))?
430 {
431 let decision_store =
432 Arc::new(laminar_core::cluster::control::CheckpointDecisionStore::new(decision_os));
433 builder = builder.decision_store(decision_store);
434 }
435
436 if let Some(snap_store) = snapshot_store.clone() {
443 builder = builder.assignment_snapshot_store(snap_store);
444 }
445
446 let shuffle_receiver = build_shuffle_receiver(&discovery, node_id).await?;
451 let shuffle_advertise = shuffle_advertise_addr(shuffle_receiver.local_addr(), bind_host);
452 let shuffle_sender =
453 Arc::new(build_shuffle_sender(node_id.0, &discovery, shuffle_advertise).await);
454
455 builder = builder
459 .shuffle_sender(Arc::clone(&shuffle_sender))
460 .shuffle_receiver(Arc::clone(&shuffle_receiver))
461 .target_partitions(vnode_count as usize);
462
463 let db = builder
464 .build()
465 .await
466 .map_err(|e| ClusterStartupError::EngineConstruction(e.to_string()))?;
467 let db = Arc::new(db);
468
469 let hostname = gethostname::gethostname().to_string_lossy().into_owned();
471 let pipeline_name = config
472 .pipelines
473 .first()
474 .map_or("default", |p| p.name.as_str())
475 .to_string();
476 let registry = Arc::new(crate::metrics::build_registry([
477 ("instance".into(), hostname),
478 ("pipeline".into(), pipeline_name),
479 ]));
480 let engine_metrics = Arc::new(laminar_db::EngineMetrics::new(®istry));
481 db.set_engine_metrics(engine_metrics);
482 db.set_prometheus_registry(Arc::clone(®istry));
483
484 server::execute_config_ddl(&db, &config)
485 .await
486 .map_err(|e| ClusterStartupError::EngineConstruction(e.to_string()))?;
487
488 db.start()
489 .await
490 .map_err(|e| ClusterStartupError::EngineConstruction(format!("pipeline start: {e}")))?;
491 info!("Pipeline started");
492
493 let rebalance_shutdown = Arc::new(tokio::sync::Notify::new());
497 let mut rebalance_tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new();
498 if let (Some(snap_store), Some(controller)) =
499 (snapshot_store.clone(), cluster_controller.as_ref())
500 {
501 let cfg = laminar_db::rebalance::RebalanceConfig::default();
502 rebalance_tasks.push(laminar_db::rebalance::spawn_snapshot_watcher(
503 Arc::clone(&db),
504 Arc::clone(&snap_store),
505 Arc::clone(&vnode_registry),
506 Arc::clone(&rebalance_shutdown),
507 cfg,
508 ));
509 rebalance_tasks.push(laminar_db::rebalance::spawn_rebalance_controller(
510 Arc::clone(&db),
511 Arc::clone(controller),
512 snap_store,
513 Arc::clone(&vnode_registry),
514 Arc::clone(&rebalance_shutdown),
515 cfg,
516 ));
517 info!("Rebalance control plane started");
518 }
519
520 let (app_state, api_handle) =
521 server::start_http_api(Arc::clone(&db), registry, config_path.clone(), config)
522 .await
523 .map_err(|e| ClusterStartupError::HttpStartup(e.to_string()))?;
524 let watcher_handle = server::spawn_config_watcher(&app_state, config_path);
525
526 let membership_rx = discovery.membership_watch();
527 let membership_handle = spawn_membership_watcher(&node_id_str, membership_rx);
528 info!("Membership watcher started");
529
530 info!("Cluster node '{node_id_str}' started");
531
532 Ok(ClusterHandle {
533 db,
534 discovery,
535 api_handle,
536 watcher_handle,
537 membership_handle,
538 rebalance_tasks,
539 rebalance_shutdown,
540 })
541}
542
/// Returns the parallelism available to this process, as reported by
/// [`std::thread::available_parallelism`].
///
/// Falls back to `1` when the value cannot be determined, and saturates at
/// `u32::MAX` instead of silently truncating if the count does not fit.
fn num_cpus() -> u32 {
    std::thread::available_parallelism()
        .map_or(1, |n| u32::try_from(n.get()).unwrap_or(u32::MAX))
}
548
549async fn resolve_vnode_assignment(
562 self_id: laminar_core::cluster::discovery::NodeId,
563 peers: &[laminar_core::cluster::discovery::NodeInfo],
564 state_cfg: &laminar_core::state::StateBackendConfig,
565) -> Result<
566 (
567 Arc<laminar_core::state::VnodeRegistry>,
568 Option<Arc<laminar_core::cluster::control::AssignmentSnapshotStore>>,
569 ),
570 ClusterStartupError,
571> {
572 use laminar_core::cluster::control::{AssignmentSnapshot, AssignmentSnapshotStore};
573 use laminar_core::state::{round_robin_assignment, NodeId, VnodeRegistry};
574
575 let vnode_count = state_cfg.vnode_capacity();
576 let peer_ids: Vec<NodeId> = peers
577 .iter()
578 .map(|p| NodeId(p.id.0))
579 .chain(std::iter::once(NodeId(self_id.0)))
580 .collect();
581 let assignment: Arc<[NodeId]> = round_robin_assignment(vnode_count, &peer_ids);
582
583 let maybe_store = state_cfg
584 .build_object_store()
585 .map_err(|e| ClusterStartupError::EngineConstruction(format!("state object store: {e}")))?;
586 let Some(store) = maybe_store else {
587 let registry = VnodeRegistry::new(vnode_count);
589 registry.set_assignment(Arc::clone(&assignment));
590 return Ok((Arc::new(registry), None));
591 };
592 let snapshot_store = Arc::new(AssignmentSnapshotStore::new(store));
593
594 if let Some(existing) = snapshot_store
597 .load()
598 .await
599 .map_err(|e| ClusterStartupError::EngineConstruction(format!("snapshot load: {e}")))?
600 {
601 let registry = VnodeRegistry::new(vnode_count);
602 registry.set_assignment_and_version(
603 existing.to_vnode_vec(vnode_count).into(),
604 existing.version,
605 );
606 info!("Adopted existing assignment snapshot v{}", existing.version);
607 return Ok((Arc::new(registry), Some(snapshot_store)));
608 }
609
610 let proposal =
613 AssignmentSnapshot::empty().next(AssignmentSnapshot::vnodes_from_vec(&assignment));
614 let winner = match snapshot_store
615 .save_if_absent(&proposal)
616 .await
617 .map_err(|e| ClusterStartupError::EngineConstruction(format!("snapshot save: {e}")))?
618 {
619 Some(w) => {
620 info!("Created assignment snapshot v{}", w.version);
621 w
622 }
623 None => {
624 let w = snapshot_store
625 .load()
626 .await
627 .map_err(|e| {
628 ClusterStartupError::EngineConstruction(format!("snapshot re-load: {e}"))
629 })?
630 .ok_or_else(|| {
631 ClusterStartupError::EngineConstruction(
632 "snapshot CAS lost but re-load returned None".into(),
633 )
634 })?;
635 info!("Adopted snapshot v{} after CAS race", w.version);
636 w
637 }
638 };
639 let registry = VnodeRegistry::new(vnode_count);
640 registry.set_assignment_and_version(winner.to_vnode_vec(vnode_count).into(), winner.version);
641 Ok((Arc::new(registry), Some(snapshot_store)))
642}
643
644async fn build_shuffle_receiver(
648 discovery: &DiscoveryImpl,
649 node_id: NodeId,
650) -> Result<Arc<laminar_core::shuffle::ShuffleReceiver>, ClusterStartupError> {
651 use laminar_core::cluster::control::{ChitchatKv, ClusterKv};
652 use laminar_core::shuffle::ShuffleReceiver;
653
654 let bind: std::net::SocketAddr = "0.0.0.0:0".parse().unwrap();
655 let recv = if let DiscoveryImpl::Gossip(gossip) = discovery {
656 if let Some(handle) = gossip.chitchat_handle() {
657 let kv: Arc<dyn ClusterKv> = Arc::new(ChitchatKv::from_handle(handle));
658 ShuffleReceiver::bind_with_kv(node_id.0, bind, kv)
659 .await
660 .map_err(|e| {
661 ClusterStartupError::EngineConstruction(format!("shuffle bind: {e}"))
662 })?
663 } else {
664 ShuffleReceiver::bind(node_id.0, bind).await.map_err(|e| {
665 ClusterStartupError::EngineConstruction(format!("shuffle bind: {e}"))
666 })?
667 }
668 } else {
669 ShuffleReceiver::bind(node_id.0, bind)
670 .await
671 .map_err(|e| ClusterStartupError::EngineConstruction(format!("shuffle bind: {e}")))?
672 };
673 Ok(Arc::new(recv))
674}
675
676async fn build_shuffle_sender(
682 node_id: u64,
683 discovery: &DiscoveryImpl,
684 advertise_addr: String,
685) -> laminar_core::shuffle::ShuffleSender {
686 use laminar_core::cluster::control::{ChitchatKv, ClusterKv};
687 use laminar_core::shuffle::{ShuffleSender, SHUFFLE_ADDR_KEY};
688
689 let DiscoveryImpl::Gossip(gossip) = discovery else {
690 return ShuffleSender::new(node_id);
691 };
692 let Some(handle) = gossip.chitchat_handle() else {
693 return ShuffleSender::new(node_id);
694 };
695 let kv: Arc<dyn ClusterKv> = Arc::new(ChitchatKv::from_handle(handle));
696 kv.write(SHUFFLE_ADDR_KEY, advertise_addr).await;
697 ShuffleSender::with_kv(node_id, kv)
698}
699
700fn shuffle_advertise_addr(local_addr: std::net::SocketAddr, bind_host: &str) -> String {
708 let port = local_addr.port();
709 let host = bind_host.trim_start_matches('[').trim_end_matches(']');
710 let ip_wildcard = host == "0.0.0.0" || host == "::" || host.is_empty();
711 if !ip_wildcard {
712 return format!("{bind_host}:{port}");
713 }
714 let hostname = gethostname::gethostname();
715 let hostname = hostname.to_string_lossy();
716 if hostname.is_empty() {
717 local_addr.to_string()
718 } else {
719 format!("{hostname}:{port}")
720 }
721}
722
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must render a non-empty message.
    #[test]
    fn test_cluster_startup_error_display() {
        let errors = [
            ClusterStartupError::Discovery("connection refused".into()),
            ClusterStartupError::FormationTimeout {
                found: 1,
                needed: 3,
            },
            ClusterStartupError::EngineConstruction("build failed".into()),
            ClusterStartupError::HttpStartup("port in use".into()),
            ClusterStartupError::InvalidRaftPort(u16::MAX),
        ];
        for err in &errors {
            assert!(!err.to_string().is_empty(), "empty message for {err:?}");
        }
    }

    /// The timeout message must report both counts in the right positions.
    #[test]
    fn test_formation_timeout_includes_counts() {
        let err = ClusterStartupError::FormationTimeout {
            found: 1,
            needed: 3,
        };
        assert_eq!(
            err.to_string(),
            "formation timeout: only 1 of 3 peers discovered"
        );
    }

    /// The raft-port message must name the offending port and the limit.
    #[test]
    fn test_invalid_raft_port_mentions_port_and_limit() {
        let msg = ClusterStartupError::InvalidRaftPort(u16::MAX).to_string();
        assert!(msg.contains("raft_port=65535"));
        assert!(msg.contains("below 65535"));
    }
}