1#![allow(clippy::disallowed_types)] use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use parking_lot::RwLock;
9use tokio::net::{TcpListener, TcpStream};
10use tokio::sync::watch;
11use tokio_util::sync::CancellationToken;
12
13use super::{Discovery, DiscoveryError, NodeId, NodeInfo, NodeMetadata, NodeState};
14
/// Upper bound on how long an outbound TCP connect to a seed may take.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(2);

/// Timeout applied to each I/O phase: the outbound heartbeat write, the
/// outbound response read, and the handling of one inbound connection.
const IO_TIMEOUT: Duration = Duration::from_secs(5);

/// Number of permits in the listener's semaphore; inbound connections that
/// cannot obtain a permit are dropped without being processed.
const MAX_HANDLER_TASKS: usize = 64;

/// Largest frame payload, in bytes, that is sent or accepted (1 MiB).
const MAX_MESSAGE_SIZE: usize = 1_048_576;

/// Number of heartbeat ticks a peer may stay in `NodeState::Left` before it is
/// removed from the peer table.
const LEFT_REAP_THRESHOLD: u32 = 30;
29
/// Configuration for [`StaticDiscovery`].
#[derive(Debug, Clone)]
pub struct StaticDiscoveryConfig {
    /// Description of this node; it is serialized and sent as the heartbeat
    /// payload and as the reply to inbound heartbeats.
    pub local_node: NodeInfo,
    /// Addresses that receive a heartbeat on every tick.
    pub seeds: Vec<String>,
    /// Period between heartbeat rounds.
    pub heartbeat_interval: Duration,
    /// Missed-heartbeat count at which a peer is marked `NodeState::Suspected`.
    pub suspect_threshold: u32,
    /// Missed-heartbeat count at which a peer is marked `NodeState::Left`.
    /// Expected to be greater than `suspect_threshold` (checked by a
    /// `debug_assert!` in `StaticDiscovery::new`).
    pub dead_threshold: u32,
    /// Address the heartbeat listener binds its TCP socket to.
    pub listen_address: String,
}
46
47impl Default for StaticDiscoveryConfig {
48 fn default() -> Self {
49 Self {
50 local_node: NodeInfo {
51 id: NodeId(1),
52 name: "node-1".into(),
53 rpc_address: "127.0.0.1:9000".into(),
54 raft_address: "127.0.0.1:9001".into(),
55 state: NodeState::Active,
56 metadata: NodeMetadata::default(),
57 last_heartbeat_ms: 0,
58 },
59 seeds: Vec::new(),
60 heartbeat_interval: Duration::from_secs(1),
61 suspect_threshold: 3,
62 dead_threshold: 10,
63 listen_address: "127.0.0.1:9002".into(),
64 }
65 }
66}
67
/// Bookkeeping kept for each known peer in the peer table.
#[derive(Debug)]
struct PeerState {
    /// Most recent information recorded for the peer.
    info: NodeInfo,
    /// Failed outbound heartbeats counted since the last successful one;
    /// compared against the suspect and dead thresholds.
    missed_heartbeats: u32,
    /// Heartbeat ticks the peer has spent in `NodeState::Left`; drives reaping.
    left_ticks: u32,
}
77
/// Peer discovery driven by a fixed seed list and periodic TCP heartbeats.
#[derive(Debug)]
pub struct StaticDiscovery {
    /// Settings supplied at construction time.
    config: StaticDiscoveryConfig,
    /// Peer table keyed by the raw node id, shared with the background tasks.
    peers: Arc<RwLock<HashMap<u64, PeerState>>>,
    /// Sending half of the membership watch channel.
    membership_tx: watch::Sender<Vec<NodeInfo>>,
    /// Receiving half kept so `membership_watch` can hand out clones.
    membership_rx: watch::Receiver<Vec<NodeInfo>>,
    /// Token used to stop the background tasks; replaced on every `start`.
    cancel: CancellationToken,
    /// Handle of the listener task while running.
    listener_handle: Option<tokio::task::JoinHandle<Result<(), DiscoveryError>>>,
    /// Handle of the heartbeater task while running.
    heartbeater_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set by `start`, cleared by `stop`; `peers` and `announce` fail with
    /// `DiscoveryError::NotStarted` while it is `false`.
    started: bool,
}
90
91impl StaticDiscovery {
92 #[must_use]
94 pub fn new(config: StaticDiscoveryConfig) -> Self {
95 debug_assert!(
96 config.suspect_threshold < config.dead_threshold,
97 "suspect_threshold ({}) must be less than dead_threshold ({})",
98 config.suspect_threshold,
99 config.dead_threshold,
100 );
101 let (tx, rx) = watch::channel(Vec::new());
102 Self {
103 config,
104 peers: Arc::new(RwLock::new(HashMap::new())),
105 membership_tx: tx,
106 membership_rx: rx,
107 cancel: CancellationToken::new(),
108 listener_handle: None,
109 heartbeater_handle: None,
110 started: false,
111 }
112 }
113
114 fn serialize_node_info(info: &NodeInfo) -> Result<Vec<u8>, DiscoveryError> {
116 rkyv::to_bytes::<rkyv::rancor::Error>(info)
117 .map(|v| v.to_vec())
118 .map_err(|e| DiscoveryError::Serialization(e.to_string()))
119 }
120
121 fn deserialize_node_info(data: &[u8]) -> Result<NodeInfo, DiscoveryError> {
123 rkyv::from_bytes::<NodeInfo, rkyv::rancor::Error>(data)
124 .map_err(|e| DiscoveryError::Serialization(e.to_string()))
125 }
126
127 fn broadcast_membership(&self) {
129 let peers = self.peers.read();
130 let peer_list: Vec<NodeInfo> = peers.values().map(|p| p.info.clone()).collect();
131 publish_if_changed(&self.membership_tx, peer_list);
132 }
133
134 #[allow(clippy::cast_possible_truncation)]
136 async fn send_heartbeat(address: &str, data: &[u8]) -> Result<Option<Vec<u8>>, DiscoveryError> {
137 use tokio::io::{AsyncReadExt, AsyncWriteExt};
138
139 let mut stream = tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(address))
141 .await
142 .map_err(|_| DiscoveryError::Connection {
143 address: address.into(),
144 reason: "connect timeout".into(),
145 })?
146 .map_err(|e| DiscoveryError::Connection {
147 address: address.into(),
148 reason: e.to_string(),
149 })?;
150
151 if data.len() > MAX_MESSAGE_SIZE {
153 return Err(DiscoveryError::Serialization(
154 "message too large to send".into(),
155 ));
156 }
157
158 let len = data.len() as u32;
160 tokio::time::timeout(IO_TIMEOUT, async {
161 stream.write_all(&len.to_be_bytes()).await?;
162 stream.write_all(data).await
163 })
164 .await
165 .map_err(|_| DiscoveryError::Connection {
166 address: address.into(),
167 reason: "write timeout".into(),
168 })?
169 .map_err(|e| DiscoveryError::Connection {
170 address: address.into(),
171 reason: e.to_string(),
172 })?;
173
174 let resp = tokio::time::timeout(IO_TIMEOUT, async {
176 let mut len_buf = [0u8; 4];
177 if stream.read_exact(&mut len_buf).await.is_err() {
178 return Ok(None);
179 }
180
181 let resp_len = u32::from_be_bytes(len_buf) as usize;
182 if resp_len > MAX_MESSAGE_SIZE {
183 return Err(DiscoveryError::Serialization("response too large".into()));
184 }
185 let mut resp = vec![0u8; resp_len];
186 stream.read_exact(&mut resp).await?;
187 Ok(Some(resp))
188 })
189 .await
190 .map_err(|_| DiscoveryError::Connection {
191 address: address.into(),
192 reason: "read timeout".into(),
193 })?;
194
195 resp.map_err(|e: DiscoveryError| e)
196 }
197
198 #[allow(clippy::cast_possible_truncation)]
204 async fn run_listener(
205 listen_address: String,
206 local_info: NodeInfo,
207 peers: Arc<RwLock<HashMap<u64, PeerState>>>,
208 membership_tx: watch::Sender<Vec<NodeInfo>>,
209 cancel: CancellationToken,
210 ) -> Result<(), DiscoveryError> {
211 use tokio::io::{AsyncReadExt, AsyncWriteExt};
212
213 let listener = TcpListener::bind(&listen_address)
214 .await
215 .map_err(|e| DiscoveryError::Bind(e.to_string()))?;
216
217 let semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_HANDLER_TASKS));
219
220 loop {
221 tokio::select! {
222 () = cancel.cancelled() => break,
223 accept = listener.accept() => {
224 let (mut stream, _) = accept?;
225 let local_info = local_info.clone();
226 let peers = Arc::clone(&peers);
227 let membership_tx = membership_tx.clone();
228 let permit = Arc::clone(&semaphore);
229
230 tokio::spawn(async move {
231 let Ok(_permit) = permit.try_acquire() else {
233 return; };
235
236 let result = tokio::time::timeout(IO_TIMEOUT, async {
238 let mut len_buf = [0u8; 4];
239 if stream.read_exact(&mut len_buf).await.is_err() {
240 return;
241 }
242 let msg_len = u32::from_be_bytes(len_buf) as usize;
243 if msg_len > MAX_MESSAGE_SIZE {
244 return;
245 }
246 let mut data = vec![0u8; msg_len];
247 if stream.read_exact(&mut data).await.is_err() {
248 return;
249 }
250
251 if let Ok(remote_info) = Self::deserialize_node_info(&data) {
252 if remote_info.id == local_info.id {
254 if let Ok(resp) = Self::serialize_node_info(&local_info) {
256 let len = resp.len() as u32;
257 let _ = stream.write_all(&len.to_be_bytes()).await;
258 let _ = stream.write_all(&resp).await;
259 }
260 return;
261 }
262
263 let peer_list = {
264 let mut guard = peers.write();
265 let now = chrono::Utc::now().timestamp_millis();
266 let peer =
267 guard.entry(remote_info.id.0).or_insert_with(|| {
268 PeerState {
272 info: NodeInfo {
273 last_heartbeat_ms: now,
274 state: NodeState::Joining,
275 ..remote_info.clone()
276 },
277 missed_heartbeats: 0,
278 left_ticks: 0,
279 }
280 });
281 peer.info.rpc_address.clone_from(&remote_info.rpc_address);
285 peer.info.raft_address.clone_from(&remote_info.raft_address);
286 peer.info.name.clone_from(&remote_info.name);
287 peer.info.metadata = remote_info.metadata.clone();
288 peer.info.last_heartbeat_ms = now;
289 guard.values().map(|p| p.info.clone()).collect::<Vec<_>>()
290 };
291 publish_if_changed(&membership_tx, peer_list);
292 }
293
294 if let Ok(resp) = Self::serialize_node_info(&local_info) {
295 let len = resp.len() as u32;
296 let _ = stream.write_all(&len.to_be_bytes()).await;
297 let _ = stream.write_all(&resp).await;
298 }
299 })
300 .await;
301
302 if result.is_err() {
303 }
305 });
306 }
307 }
308 }
309
310 Ok(())
311 }
312
313 async fn run_heartbeater(config: StaticDiscoveryConfig, ctx: HeartbeatContext) {
319 let local_id = config.local_node.id;
320 let mut interval = tokio::time::interval(config.heartbeat_interval);
321 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
323
324 let seed_to_peer = Arc::new(parking_lot::Mutex::new(HashMap::<String, u64>::new()));
327
328 loop {
329 tokio::select! {
330 () = ctx.cancel.cancelled() => break,
331 _ = interval.tick() => {
332 let Ok(data) = Self::serialize_node_info(&config.local_node) else {
333 continue;
334 };
335 let data = Arc::new(data);
336
337 let mut tasks = Vec::with_capacity(config.seeds.len());
339 for seed in &config.seeds {
340 let seed = seed.clone();
341 let data = Arc::clone(&data);
342 tasks.push(tokio::spawn(async move {
343 let result = Self::send_heartbeat(&seed, &data).await;
344 (seed, result)
345 }));
346 }
347
348 for task in tasks {
350 let Ok((seed, result)) = task.await else {
351 continue; };
353
354 if let Ok(Some(resp_data)) = result {
355 if let Ok(remote_info) = Self::deserialize_node_info(&resp_data) {
356 if remote_info.id == local_id {
358 continue;
359 }
360
361 seed_to_peer.lock().insert(seed, remote_info.id.0);
363
364 let mut peers = ctx.peers.write();
365 let now = chrono::Utc::now().timestamp_millis();
366 let peer =
367 peers.entry(remote_info.id.0).or_insert_with(|| {
368 PeerState {
369 info: remote_info.clone(),
370 missed_heartbeats: 0,
371 left_ticks: 0,
372 }
373 });
374 peer.info = NodeInfo {
375 last_heartbeat_ms: now,
376 state: NodeState::Active,
377 ..remote_info
378 };
379 peer.missed_heartbeats = 0;
380 peer.left_ticks = 0;
381 }
382 } else {
383 let map = seed_to_peer.lock();
385 if let Some(&peer_id) = map.get(seed.as_str()) {
386 drop(map);
387 let mut peers = ctx.peers.write();
388 if let Some(peer) = peers.get_mut(&peer_id) {
389 peer.missed_heartbeats += 1;
390 if peer.missed_heartbeats >= config.dead_threshold {
391 peer.info.state = NodeState::Left;
392 } else if peer.missed_heartbeats
393 >= config.suspect_threshold
394 {
395 peer.info.state = NodeState::Suspected;
396 }
397 }
398 }
399 }
400 }
401
402 {
404 let mut peers = ctx.peers.write();
405 peers.retain(|_id, peer| {
406 if peer.info.state == NodeState::Left {
407 peer.left_ticks += 1;
408 peer.left_ticks < LEFT_REAP_THRESHOLD
409 } else {
410 true
411 }
412 });
413 }
414
415 {
417 let peers = ctx.peers.read();
418 let mut map = seed_to_peer.lock();
419 map.retain(|_, peer_id| peers.contains_key(peer_id));
420 }
421
422 let peer_list: Vec<NodeInfo> = {
423 let peers = ctx.peers.read();
424 peers.values().map(|p| p.info.clone()).collect()
425 };
426 publish_if_changed(&ctx.membership_tx, peer_list);
427 }
428 }
429 }
430 }
431}
432
433fn publish_if_changed(tx: &watch::Sender<Vec<NodeInfo>>, mut peer_list: Vec<NodeInfo>) {
440 peer_list.sort_by_key(|n| n.id.0);
443 tx.send_if_modified(|cur| {
444 fn same_member(a: &NodeInfo, b: &NodeInfo) -> bool {
450 let mut a = a.clone();
451 a.last_heartbeat_ms = b.last_heartbeat_ms;
452 a == *b
453 }
454 let same = cur.len() == peer_list.len()
455 && cur.iter().zip(&peer_list).all(|(a, b)| same_member(a, b));
456 if same {
457 false
458 } else {
459 *cur = peer_list;
460 true
461 }
462 });
463}
464
/// Shared handles passed to the heartbeater task.
struct HeartbeatContext {
    /// Peer table keyed by the raw node id.
    peers: Arc<RwLock<HashMap<u64, PeerState>>>,
    /// Channel on which membership snapshots are published.
    membership_tx: watch::Sender<Vec<NodeInfo>>,
    /// Token whose cancellation ends the heartbeat loop.
    cancel: CancellationToken,
}
471
472impl Discovery for StaticDiscovery {
473 async fn start(&mut self) -> Result<(), DiscoveryError> {
474 if self.started {
475 return Ok(());
476 }
477
478 self.cancel = CancellationToken::new();
480
481 let peers = Arc::clone(&self.peers);
482 let membership_tx = self.membership_tx.clone();
483 let cancel = self.cancel.clone();
484 let listen_address = self.config.listen_address.clone();
485 let local_info = self.config.local_node.clone();
486
487 self.listener_handle = Some(tokio::spawn(Self::run_listener(
489 listen_address,
490 local_info,
491 Arc::clone(&peers),
492 membership_tx.clone(),
493 cancel.clone(),
494 )));
495
496 self.heartbeater_handle = Some(tokio::spawn(Self::run_heartbeater(
498 self.config.clone(),
499 HeartbeatContext {
500 peers,
501 membership_tx,
502 cancel,
503 },
504 )));
505
506 self.started = true;
507 Ok(())
508 }
509
510 async fn peers(&self) -> Result<Vec<NodeInfo>, DiscoveryError> {
511 if !self.started {
512 return Err(DiscoveryError::NotStarted);
513 }
514 let peers = self.peers.read();
515 Ok(peers.values().map(|p| p.info.clone()).collect())
516 }
517
518 async fn announce(&self, info: NodeInfo) -> Result<(), DiscoveryError> {
519 if !self.started {
520 return Err(DiscoveryError::NotStarted);
521 }
522 {
523 let mut peers = self.peers.write();
524 peers.insert(
525 info.id.0,
526 PeerState {
527 info,
528 missed_heartbeats: 0,
529 left_ticks: 0,
530 },
531 );
532 }
533 self.broadcast_membership();
534 Ok(())
535 }
536
537 fn membership_watch(&self) -> watch::Receiver<Vec<NodeInfo>> {
538 self.membership_rx.clone()
539 }
540
541 async fn stop(&mut self) -> Result<(), DiscoveryError> {
542 self.cancel.cancel();
543 self.started = false;
544
545 if let Some(h) = self.listener_handle.take() {
547 let _ = h.await;
548 }
549 if let Some(h) = self.heartbeater_handle.take() {
550 let _ = h.await;
551 }
552
553 Ok(())
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration, but listening on an OS-assigned loopback port.
    fn ephemeral_config() -> StaticDiscoveryConfig {
        StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        }
    }

    /// Builds an active node with default metadata.
    fn active_node(
        id: u64,
        name: &str,
        rpc_address: &str,
        raft_address: &str,
        last_heartbeat_ms: i64,
    ) -> NodeInfo {
        NodeInfo {
            id: NodeId(id),
            name: name.into(),
            rpc_address: rpc_address.into(),
            raft_address: raft_address.into(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms,
        }
    }

    /// Binds an ephemeral loopback port, records its address, and releases it.
    async fn reserve_loopback_addr() -> String {
        let probe = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = probe.local_addr().unwrap().to_string();
        drop(probe);
        addr
    }

    #[test]
    fn test_config_default() {
        let config = StaticDiscoveryConfig::default();
        assert_eq!(config.heartbeat_interval, Duration::from_secs(1));
        assert_eq!(config.suspect_threshold, 3);
        assert_eq!(config.dead_threshold, 10);
    }

    #[test]
    fn test_serialize_round_trip() {
        let info = active_node(42, "test", "127.0.0.1:9000", "127.0.0.1:9001", 1000);

        let data = StaticDiscovery::serialize_node_info(&info).unwrap();
        let back = StaticDiscovery::deserialize_node_info(&data).unwrap();
        assert_eq!(back.id, NodeId(42));
        assert_eq!(back.name, "test");
    }

    #[test]
    fn test_deserialize_invalid() {
        let result = StaticDiscovery::deserialize_node_info(&[0xff, 0xff]);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_not_started_errors() {
        let disc = StaticDiscovery::new(StaticDiscoveryConfig::default());
        assert!(disc.peers().await.is_err());
    }

    #[tokio::test]
    async fn test_start_stop() {
        let mut disc = StaticDiscovery::new(ephemeral_config());
        disc.start().await.unwrap();
        assert!(disc.started);
        disc.stop().await.unwrap();
        assert!(!disc.started);
    }

    #[tokio::test]
    async fn test_double_start_ok() {
        let mut disc = StaticDiscovery::new(ephemeral_config());
        disc.start().await.unwrap();
        // The second start must be accepted as a no-op.
        disc.start().await.unwrap();
        disc.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_membership_watch() {
        let disc = StaticDiscovery::new(ephemeral_config());
        let rx = disc.membership_watch();
        assert!(rx.borrow().is_empty());
    }

    #[tokio::test]
    async fn test_announce_adds_peer() {
        let mut disc = StaticDiscovery::new(ephemeral_config());
        disc.start().await.unwrap();

        let peer = active_node(99, "peer", "127.0.0.1:8000", "127.0.0.1:8001", 0);
        disc.announce(peer).await.unwrap();

        let peers = disc.peers().await.unwrap();
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].id, NodeId(99));

        disc.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_two_node_heartbeat() {
        // Reserve two distinct loopback addresses, one per node.
        let addr1 = reserve_loopback_addr().await;
        let addr2 = reserve_loopback_addr().await;

        let config1 = StaticDiscoveryConfig {
            local_node: active_node(1, "node-1", &addr1, &addr1, 0),
            seeds: vec![addr2.clone()],
            heartbeat_interval: Duration::from_millis(100),
            listen_address: addr1.clone(),
            ..StaticDiscoveryConfig::default()
        };

        let config2 = StaticDiscoveryConfig {
            local_node: active_node(2, "node-2", &addr2, &addr2, 0),
            seeds: vec![addr1],
            heartbeat_interval: Duration::from_millis(100),
            listen_address: addr2,
            ..StaticDiscoveryConfig::default()
        };

        let mut disc1 = StaticDiscovery::new(config1);
        let mut disc2 = StaticDiscovery::new(config2);

        disc1.start().await.unwrap();
        disc2.start().await.unwrap();

        // Allow a few heartbeat rounds to complete.
        tokio::time::sleep(Duration::from_millis(500)).await;

        let peers1 = disc1.peers().await.unwrap();
        let peers2 = disc2.peers().await.unwrap();

        assert!(
            !peers1.is_empty() || !peers2.is_empty(),
            "at least one node should have discovered peers"
        );

        disc1.stop().await.unwrap();
        disc2.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_restart_after_stop() {
        let mut disc = StaticDiscovery::new(ephemeral_config());

        disc.start().await.unwrap();
        disc.stop().await.unwrap();

        // A stopped instance must be startable again.
        disc.start().await.unwrap();
        assert!(disc.started);
        disc.stop().await.unwrap();
    }
}