1#![allow(clippy::disallowed_types)] use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::Duration;
22
23use parking_lot::RwLock;
24use tokio::net::{TcpListener, TcpStream};
25use tokio::sync::watch;
26use tokio_util::sync::CancellationToken;
27
28use super::{Discovery, DiscoveryError, NodeId, NodeInfo, NodeMetadata, NodeState};
29
/// Maximum time to wait for a TCP connection to a peer before giving up.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(2);

/// Maximum time allowed for a single read or write exchange on a peer socket.
const IO_TIMEOUT: Duration = Duration::from_secs(5);

/// Upper bound on concurrently running inbound-connection handler tasks.
const MAX_HANDLER_TASKS: usize = 64;

/// Largest frame payload (in bytes) accepted on the wire, in either direction.
const MAX_MESSAGE_SIZE: usize = 1_048_576;

/// Number of heartbeat ticks a peer may remain in `NodeState::Left` before it
/// is reaped from the peer map.
const LEFT_REAP_THRESHOLD: u32 = 30;
/// Configuration for [`StaticDiscovery`].
#[derive(Debug, Clone)]
pub struct StaticDiscoveryConfig {
    /// Identity and addresses advertised for this node.
    pub local_node: NodeInfo,
    /// Addresses of seed nodes to send heartbeats to.
    pub seeds: Vec<String>,
    /// Interval between outgoing heartbeat rounds.
    pub heartbeat_interval: Duration,
    /// Consecutive missed heartbeats before a peer is marked `Suspected`.
    pub suspect_threshold: u32,
    /// Consecutive missed heartbeats before a peer is marked `Left`.
    pub dead_threshold: u32,
    /// Local address the inbound heartbeat listener binds to.
    pub listen_address: String,
}
61
62impl Default for StaticDiscoveryConfig {
63 fn default() -> Self {
64 Self {
65 local_node: NodeInfo {
66 id: NodeId(1),
67 name: "node-1".into(),
68 rpc_address: "127.0.0.1:9000".into(),
69 raft_address: "127.0.0.1:9001".into(),
70 state: NodeState::Active,
71 metadata: NodeMetadata::default(),
72 last_heartbeat_ms: 0,
73 },
74 seeds: Vec::new(),
75 heartbeat_interval: Duration::from_secs(1),
76 suspect_threshold: 3,
77 dead_threshold: 10,
78 listen_address: "127.0.0.1:9002".into(),
79 }
80 }
81}
82
/// Internal bookkeeping for one known peer.
#[derive(Debug)]
struct PeerState {
    /// Last known information for the peer.
    info: NodeInfo,
    /// Consecutive heartbeat rounds in which this peer failed to respond.
    missed_heartbeats: u32,
    /// Heartbeat ticks spent in `NodeState::Left`; used to reap stale entries.
    left_ticks: u32,
}
92
/// Discovery implementation that heartbeats a static list of seed addresses
/// over TCP and tracks peer liveness.
#[derive(Debug)]
pub struct StaticDiscovery {
    config: StaticDiscoveryConfig,
    /// Peer table keyed by node id, shared with the listener and heartbeater tasks.
    peers: Arc<RwLock<HashMap<u64, PeerState>>>,
    #[allow(dead_code)]
    membership_tx: watch::Sender<Vec<NodeInfo>>,
    membership_rx: watch::Receiver<Vec<NodeInfo>>,
    /// Cancels the background tasks on `stop`; replaced on each `start`.
    cancel: CancellationToken,
    listener_handle: Option<tokio::task::JoinHandle<Result<(), DiscoveryError>>>,
    heartbeater_handle: Option<tokio::task::JoinHandle<()>>,
    started: bool,
}
106
107impl StaticDiscovery {
108 #[must_use]
110 pub fn new(config: StaticDiscoveryConfig) -> Self {
111 debug_assert!(
112 config.suspect_threshold < config.dead_threshold,
113 "suspect_threshold ({}) must be less than dead_threshold ({})",
114 config.suspect_threshold,
115 config.dead_threshold,
116 );
117 let (tx, rx) = watch::channel(Vec::new());
118 Self {
119 config,
120 peers: Arc::new(RwLock::new(HashMap::new())),
121 membership_tx: tx,
122 membership_rx: rx,
123 cancel: CancellationToken::new(),
124 listener_handle: None,
125 heartbeater_handle: None,
126 started: false,
127 }
128 }
129
130 fn serialize_node_info(info: &NodeInfo) -> Result<Vec<u8>, DiscoveryError> {
132 rkyv::to_bytes::<rkyv::rancor::Error>(info)
133 .map(|v| v.to_vec())
134 .map_err(|e| DiscoveryError::Serialization(e.to_string()))
135 }
136
137 fn deserialize_node_info(data: &[u8]) -> Result<NodeInfo, DiscoveryError> {
139 rkyv::from_bytes::<NodeInfo, rkyv::rancor::Error>(data)
140 .map_err(|e| DiscoveryError::Serialization(e.to_string()))
141 }
142
143 fn broadcast_membership(&self) {
145 let peers = self.peers.read();
146 let peer_list: Vec<NodeInfo> = peers.values().map(|p| p.info.clone()).collect();
147 let _ = self.membership_tx.send(peer_list);
148 }
149
150 #[allow(clippy::cast_possible_truncation)]
152 async fn send_heartbeat(address: &str, data: &[u8]) -> Result<Option<Vec<u8>>, DiscoveryError> {
153 use tokio::io::{AsyncReadExt, AsyncWriteExt};
154
155 let mut stream = tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(address))
157 .await
158 .map_err(|_| DiscoveryError::Connection {
159 address: address.into(),
160 reason: "connect timeout".into(),
161 })?
162 .map_err(|e| DiscoveryError::Connection {
163 address: address.into(),
164 reason: e.to_string(),
165 })?;
166
167 if data.len() > MAX_MESSAGE_SIZE {
169 return Err(DiscoveryError::Serialization(
170 "message too large to send".into(),
171 ));
172 }
173
174 let len = data.len() as u32;
176 tokio::time::timeout(IO_TIMEOUT, async {
177 stream.write_all(&len.to_be_bytes()).await?;
178 stream.write_all(data).await
179 })
180 .await
181 .map_err(|_| DiscoveryError::Connection {
182 address: address.into(),
183 reason: "write timeout".into(),
184 })?
185 .map_err(|e| DiscoveryError::Connection {
186 address: address.into(),
187 reason: e.to_string(),
188 })?;
189
190 let resp = tokio::time::timeout(IO_TIMEOUT, async {
192 let mut len_buf = [0u8; 4];
193 if stream.read_exact(&mut len_buf).await.is_err() {
194 return Ok(None);
195 }
196
197 let resp_len = u32::from_be_bytes(len_buf) as usize;
198 if resp_len > MAX_MESSAGE_SIZE {
199 return Err(DiscoveryError::Serialization("response too large".into()));
200 }
201 let mut resp = vec![0u8; resp_len];
202 stream.read_exact(&mut resp).await?;
203 Ok(Some(resp))
204 })
205 .await
206 .map_err(|_| DiscoveryError::Connection {
207 address: address.into(),
208 reason: "read timeout".into(),
209 })?;
210
211 resp.map_err(|e: DiscoveryError| e)
212 }
213
    /// Accepts inbound heartbeat connections on `listen_address` until
    /// `cancel` fires.
    ///
    /// Each accepted connection is handled on its own task (bounded by
    /// `MAX_HANDLER_TASKS`): read one length-prefixed `NodeInfo` frame, merge
    /// the sender into `peers`, broadcast the updated membership, and reply
    /// with our own `NodeInfo`.
    ///
    /// # Errors
    /// Returns an error if binding `listen_address` fails, or if `accept`
    /// returns an error. NOTE(review): a single transient accept failure
    /// terminates the whole listener task — confirm that is intended.
    #[allow(clippy::cast_possible_truncation)]
    async fn run_listener(
        listen_address: String,
        local_info: NodeInfo,
        peers: Arc<RwLock<HashMap<u64, PeerState>>>,
        membership_tx: watch::Sender<Vec<NodeInfo>>,
        cancel: CancellationToken,
    ) -> Result<(), DiscoveryError> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let listener = TcpListener::bind(&listen_address)
            .await
            .map_err(|e| DiscoveryError::Bind(e.to_string()))?;

        // Bounds the number of concurrently running handler tasks.
        let semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_HANDLER_TASKS));

        loop {
            tokio::select! {
                () = cancel.cancelled() => break,
                accept = listener.accept() => {
                    let (mut stream, _) = accept?;
                    let local_info = local_info.clone();
                    let peers = Arc::clone(&peers);
                    let membership_tx = membership_tx.clone();
                    let permit = Arc::clone(&semaphore);

                    tokio::spawn(async move {
                        // At capacity: drop the connection without reading.
                        // The sender observes this as a missed heartbeat.
                        let Ok(_permit) = permit.try_acquire() else {
                            return; };

                        let result = tokio::time::timeout(IO_TIMEOUT, async {
                            // Frame: 4-byte big-endian length, then payload.
                            let mut len_buf = [0u8; 4];
                            if stream.read_exact(&mut len_buf).await.is_err() {
                                return;
                            }
                            let msg_len = u32::from_be_bytes(len_buf) as usize;
                            if msg_len > MAX_MESSAGE_SIZE {
                                return;
                            }
                            let mut data = vec![0u8; msg_len];
                            if stream.read_exact(&mut data).await.is_err() {
                                return;
                            }

                            if let Ok(remote_info) = Self::deserialize_node_info(&data) {
                                // A node heartbeating itself (e.g. listed as
                                // its own seed): reply, but record no peer.
                                if remote_info.id == local_info.id {
                                    if let Ok(resp) = Self::serialize_node_info(&local_info) {
                                        let len = resp.len() as u32;
                                        let _ = stream.write_all(&len.to_be_bytes()).await;
                                        let _ = stream.write_all(&resp).await;
                                    }
                                    return;
                                }

                                let peer_list = {
                                    let mut guard = peers.write();
                                    let now = chrono::Utc::now().timestamp_millis();
                                    // New peers enter as Joining; only the
                                    // heartbeater promotes a peer to Active,
                                    // once our own outgoing heartbeat succeeds.
                                    let peer =
                                        guard.entry(remote_info.id.0).or_insert_with(|| {
                                            PeerState {
                                                info: NodeInfo {
                                                    last_heartbeat_ms: now,
                                                    state: NodeState::Joining,
                                                    ..remote_info.clone()
                                                },
                                                missed_heartbeats: 0,
                                                left_ticks: 0,
                                            }
                                        });
                                    // Refresh the mutable fields on every
                                    // inbound heartbeat; `state` is left
                                    // untouched here on purpose.
                                    peer.info.rpc_address.clone_from(&remote_info.rpc_address);
                                    peer.info.raft_address.clone_from(&remote_info.raft_address);
                                    peer.info.name.clone_from(&remote_info.name);
                                    peer.info.metadata = remote_info.metadata.clone();
                                    peer.info.last_heartbeat_ms = now;
                                    // Snapshot inside the lock so the
                                    // broadcast below is consistent.
                                    guard.values().map(|p| p.info.clone()).collect::<Vec<_>>()
                                };
                                let _ = membership_tx.send(peer_list);
                            }

                            // Reply with our own info so the sender learns
                            // this node's id and addresses.
                            if let Ok(resp) = Self::serialize_node_info(&local_info) {
                                let len = resp.len() as u32;
                                let _ = stream.write_all(&len.to_be_bytes()).await;
                                let _ = stream.write_all(&resp).await;
                            }
                        })
                        .await;

                        if result.is_err() {
                            // Handler exceeded IO_TIMEOUT; the connection is
                            // simply dropped. NOTE(review): empty branch —
                            // looks like a stripped log/metric call.
                        }
                    });
                }
            }
        }

        Ok(())
    }
328
    /// Periodically heartbeats every configured seed until `ctx.cancel` fires.
    ///
    /// Each tick: sends this node's serialized `NodeInfo` to all seeds in
    /// parallel, applies responses to the peer table, escalates unresponsive
    /// peers (`Suspected` at `suspect_threshold`, `Left` at `dead_threshold`),
    /// reaps peers that stayed `Left` for `LEFT_REAP_THRESHOLD` ticks, and
    /// broadcasts the resulting membership snapshot.
    async fn run_heartbeater(config: StaticDiscoveryConfig, ctx: HeartbeatContext) {
        let local_id = config.local_node.id;
        let mut interval = tokio::time::interval(config.heartbeat_interval);
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        // Maps seed address -> node id of the peer that answered there, so a
        // later failure against that seed can be charged to the right peer.
        let seed_to_peer = Arc::new(parking_lot::Mutex::new(HashMap::<String, u64>::new()));

        loop {
            tokio::select! {
                () = ctx.cancel.cancelled() => break,
                _ = interval.tick() => {
                    let Ok(data) = Self::serialize_node_info(&config.local_node) else {
                        continue;
                    };
                    // Arc so each spawned task can share the payload cheaply.
                    let data = Arc::new(data);

                    // Fan out one heartbeat task per seed.
                    let mut tasks = Vec::with_capacity(config.seeds.len());
                    for seed in &config.seeds {
                        let seed = seed.clone();
                        let data = Arc::clone(&data);
                        tasks.push(tokio::spawn(async move {
                            let result = Self::send_heartbeat(&seed, &data).await;
                            (seed, result)
                        }));
                    }

                    for task in tasks {
                        // A panicked/cancelled heartbeat task is skipped.
                        let Ok((seed, result)) = task.await else {
                            continue; };

                        if let Ok(Some(resp_data)) = result {
                            if let Ok(remote_info) = Self::deserialize_node_info(&resp_data) {
                                // A seed pointing back at ourselves is ignored.
                                if remote_info.id == local_id {
                                    continue;
                                }

                                seed_to_peer.lock().insert(seed, remote_info.id.0);

                                // Successful round trip: mark the peer Active
                                // and reset its failure counters.
                                let mut peers = ctx.peers.write();
                                let now = chrono::Utc::now().timestamp_millis();
                                let peer =
                                    peers.entry(remote_info.id.0).or_insert_with(|| {
                                        PeerState {
                                            info: remote_info.clone(),
                                            missed_heartbeats: 0,
                                            left_ticks: 0,
                                        }
                                    });
                                peer.info = NodeInfo {
                                    last_heartbeat_ms: now,
                                    state: NodeState::Active,
                                    ..remote_info
                                };
                                peer.missed_heartbeats = 0;
                                peer.left_ticks = 0;
                            }
                            // NOTE(review): an undecodable response is silently
                            // ignored here — it counts neither as success nor
                            // as a missed heartbeat.
                        } else {
                            // Send error or no response: only charged against
                            // seeds that have previously answered, since we
                            // need a node id to attribute the failure to.
                            let map = seed_to_peer.lock();
                            if let Some(&peer_id) = map.get(seed.as_str()) {
                                // Release the seed map before taking the peer
                                // write lock to keep lock scopes minimal.
                                drop(map);
                                let mut peers = ctx.peers.write();
                                if let Some(peer) = peers.get_mut(&peer_id) {
                                    peer.missed_heartbeats += 1;
                                    if peer.missed_heartbeats >= config.dead_threshold {
                                        peer.info.state = NodeState::Left;
                                    } else if peer.missed_heartbeats
                                        >= config.suspect_threshold
                                    {
                                        peer.info.state = NodeState::Suspected;
                                    }
                                }
                            }
                        }
                    }

                    // Reap peers that have stayed Left long enough.
                    {
                        let mut peers = ctx.peers.write();
                        peers.retain(|_id, peer| {
                            if peer.info.state == NodeState::Left {
                                peer.left_ticks += 1;
                                peer.left_ticks < LEFT_REAP_THRESHOLD
                            } else {
                                true
                            }
                        });
                    }

                    // Drop seed mappings that point at reaped peers.
                    {
                        let peers = ctx.peers.read();
                        let mut map = seed_to_peer.lock();
                        map.retain(|_, peer_id| peers.contains_key(peer_id));
                    }

                    // Publish the post-tick membership snapshot.
                    let peer_list: Vec<NodeInfo> = {
                        let peers = ctx.peers.read();
                        peers.values().map(|p| p.info.clone()).collect()
                    };
                    let _ = ctx.membership_tx.send(peer_list);
                }
            }
        }
    }
447}
448
/// Shared state handed to the heartbeater background task.
struct HeartbeatContext {
    /// Peer table shared with the listener and the public API.
    peers: Arc<RwLock<HashMap<u64, PeerState>>>,
    /// Channel for publishing membership snapshots.
    membership_tx: watch::Sender<Vec<NodeInfo>>,
    /// Token observed to shut the task down.
    cancel: CancellationToken,
}
455
456impl Discovery for StaticDiscovery {
    /// Spawns the listener and heartbeater background tasks.
    ///
    /// Idempotent: calling `start` while already started returns `Ok(())`
    /// immediately. A fresh cancellation token is installed each time so the
    /// instance can be restarted after `stop`.
    ///
    /// NOTE(review): a listener bind failure only surfaces through the
    /// spawned task's join handle (which `stop` ignores), so `start` returns
    /// `Ok(())` even when `listen_address` is unusable — confirm callers are
    /// fine with that.
    async fn start(&mut self) -> Result<(), DiscoveryError> {
        if self.started {
            return Ok(());
        }

        // Replace any token cancelled by a previous stop() so the tasks
        // spawned below observe a live token.
        self.cancel = CancellationToken::new();

        let peers = Arc::clone(&self.peers);
        let membership_tx = self.membership_tx.clone();
        let cancel = self.cancel.clone();
        let listen_address = self.config.listen_address.clone();
        let local_info = self.config.local_node.clone();

        self.listener_handle = Some(tokio::spawn(Self::run_listener(
            listen_address,
            local_info,
            Arc::clone(&peers),
            membership_tx.clone(),
            cancel.clone(),
        )));

        self.heartbeater_handle = Some(tokio::spawn(Self::run_heartbeater(
            self.config.clone(),
            HeartbeatContext {
                peers,
                membership_tx,
                cancel,
            },
        )));

        self.started = true;
        Ok(())
    }
493
494 async fn peers(&self) -> Result<Vec<NodeInfo>, DiscoveryError> {
495 if !self.started {
496 return Err(DiscoveryError::NotStarted);
497 }
498 let peers = self.peers.read();
499 Ok(peers.values().map(|p| p.info.clone()).collect())
500 }
501
502 async fn announce(&self, info: NodeInfo) -> Result<(), DiscoveryError> {
503 if !self.started {
504 return Err(DiscoveryError::NotStarted);
505 }
506 {
507 let mut peers = self.peers.write();
508 peers.insert(
509 info.id.0,
510 PeerState {
511 info,
512 missed_heartbeats: 0,
513 left_ticks: 0,
514 },
515 );
516 }
517 self.broadcast_membership();
518 Ok(())
519 }
520
521 fn membership_watch(&self) -> watch::Receiver<Vec<NodeInfo>> {
522 self.membership_rx.clone()
523 }
524
525 async fn stop(&mut self) -> Result<(), DiscoveryError> {
526 self.cancel.cancel();
527 self.started = false;
528
529 if let Some(h) = self.listener_handle.take() {
531 let _ = h.await;
532 }
533 if let Some(h) = self.heartbeater_handle.take() {
534 let _ = h.await;
535 }
536
537 Ok(())
538 }
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults should match the documented thresholds.
    #[test]
    fn test_config_default() {
        let config = StaticDiscoveryConfig::default();
        assert_eq!(config.heartbeat_interval, Duration::from_secs(1));
        assert_eq!(config.suspect_threshold, 3);
        assert_eq!(config.dead_threshold, 10);
    }

    // rkyv serialize -> deserialize must round-trip a NodeInfo.
    #[test]
    fn test_serialize_round_trip() {
        let info = NodeInfo {
            id: NodeId(42),
            name: "test".into(),
            rpc_address: "127.0.0.1:9000".into(),
            raft_address: "127.0.0.1:9001".into(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 1000,
        };

        let data = StaticDiscovery::serialize_node_info(&info).unwrap();
        let back = StaticDiscovery::deserialize_node_info(&data).unwrap();
        assert_eq!(back.id, NodeId(42));
        assert_eq!(back.name, "test");
    }

    // Garbage bytes must produce an error, not a panic.
    #[test]
    fn test_deserialize_invalid() {
        let result = StaticDiscovery::deserialize_node_info(&[0xff, 0xff]);
        assert!(result.is_err());
    }

    // peers() before start() must return NotStarted.
    #[tokio::test]
    async fn test_not_started_errors() {
        let config = StaticDiscoveryConfig::default();
        let disc = StaticDiscovery::new(config);
        assert!(disc.peers().await.is_err());
    }

    // start()/stop() must toggle the started flag.
    // Port 0 lets the OS pick a free port so tests don't collide.
    #[tokio::test]
    async fn test_start_stop() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let mut disc = StaticDiscovery::new(config);
        disc.start().await.unwrap();
        assert!(disc.started);
        disc.stop().await.unwrap();
        assert!(!disc.started);
    }

    // A second start() while running must be a no-op, not an error.
    #[tokio::test]
    async fn test_double_start_ok() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let mut disc = StaticDiscovery::new(config);
        disc.start().await.unwrap();
        disc.start().await.unwrap(); disc.stop().await.unwrap();
    }

    // The watch channel starts out holding an empty membership list.
    #[tokio::test]
    async fn test_membership_watch() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let disc = StaticDiscovery::new(config);
        let rx = disc.membership_watch();
        assert!(rx.borrow().is_empty());
    }

    // announce() must insert the peer so peers() can observe it.
    #[tokio::test]
    async fn test_announce_adds_peer() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let mut disc = StaticDiscovery::new(config);
        disc.start().await.unwrap();

        let peer = NodeInfo {
            id: NodeId(99),
            name: "peer".into(),
            rpc_address: "127.0.0.1:8000".into(),
            raft_address: "127.0.0.1:8001".into(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 0,
        };
        disc.announce(peer).await.unwrap();

        let peers = disc.peers().await.unwrap();
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].id, NodeId(99));

        disc.stop().await.unwrap();
    }

    // Two nodes seeded with each other should discover one another.
    // NOTE(review): the bind-then-drop trick below reserves a port and then
    // releases it before the discovery instance rebinds — another process
    // could grab the port in between, making this test potentially flaky.
    #[tokio::test]
    async fn test_two_node_heartbeat() {
        let listener1 = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr1 = listener1.local_addr().unwrap().to_string();
        drop(listener1);

        let listener2 = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr2 = listener2.local_addr().unwrap().to_string();
        drop(listener2);

        let config1 = StaticDiscoveryConfig {
            local_node: NodeInfo {
                id: NodeId(1),
                name: "node-1".into(),
                rpc_address: addr1.clone(),
                raft_address: addr1.clone(),
                state: NodeState::Active,
                metadata: NodeMetadata::default(),
                last_heartbeat_ms: 0,
            },
            seeds: vec![addr2.clone()],
            heartbeat_interval: Duration::from_millis(100),
            listen_address: addr1.clone(),
            ..StaticDiscoveryConfig::default()
        };

        let config2 = StaticDiscoveryConfig {
            local_node: NodeInfo {
                id: NodeId(2),
                name: "node-2".into(),
                rpc_address: addr2.clone(),
                raft_address: addr2.clone(),
                state: NodeState::Active,
                metadata: NodeMetadata::default(),
                last_heartbeat_ms: 0,
            },
            seeds: vec![addr1],
            heartbeat_interval: Duration::from_millis(100),
            listen_address: addr2,
            ..StaticDiscoveryConfig::default()
        };

        let mut disc1 = StaticDiscovery::new(config1);
        let mut disc2 = StaticDiscovery::new(config2);

        disc1.start().await.unwrap();
        disc2.start().await.unwrap();

        // Allow several 100ms heartbeat rounds to complete.
        tokio::time::sleep(Duration::from_millis(500)).await;

        let peers1 = disc1.peers().await.unwrap();
        let peers2 = disc2.peers().await.unwrap();

        assert!(
            !peers1.is_empty() || !peers2.is_empty(),
            "at least one node should have discovered peers"
        );

        disc1.stop().await.unwrap();
        disc2.stop().await.unwrap();
    }

    // stop() then start() must work again (fresh cancellation token).
    #[tokio::test]
    async fn test_restart_after_stop() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let mut disc = StaticDiscovery::new(config);

        disc.start().await.unwrap();
        disc.stop().await.unwrap();

        disc.start().await.unwrap();
        assert!(disc.started);
        disc.stop().await.unwrap();
    }
}