1#![allow(clippy::disallowed_types)] use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use parking_lot::RwLock;
9use tokio::net::{TcpListener, TcpStream};
10use tokio::sync::watch;
11use tokio_util::sync::CancellationToken;
12
13use super::{Discovery, DiscoveryError, NodeId, NodeInfo, NodeMetadata, NodeState};
14
/// Maximum time to wait when establishing a TCP connection to a peer.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(2);

/// Maximum time allowed for each read or write phase on a connection.
const IO_TIMEOUT: Duration = Duration::from_secs(5);

/// Upper bound on concurrently running inbound-connection handler tasks.
const MAX_HANDLER_TASKS: usize = 64;

/// Largest length-prefixed message (1 MiB) accepted or sent on the wire.
const MAX_MESSAGE_SIZE: usize = 1_048_576;

/// Number of heartbeat ticks a peer may stay in the `Left` state before it is
/// removed from the peer table entirely.
const LEFT_REAP_THRESHOLD: u32 = 30;
/// Configuration for [`StaticDiscovery`]: the local node's identity, the
/// static seed list to heartbeat, and the liveness thresholds.
#[derive(Debug, Clone)]
pub struct StaticDiscoveryConfig {
    /// Identity and addresses announced for this node.
    pub local_node: NodeInfo,
    /// Addresses of seed nodes that heartbeats are sent to.
    pub seeds: Vec<String>,
    /// Interval between heartbeat rounds.
    pub heartbeat_interval: Duration,
    /// Consecutive missed heartbeats before a peer is marked `Suspected`.
    pub suspect_threshold: u32,
    /// Consecutive missed heartbeats before a peer is marked `Left`.
    pub dead_threshold: u32,
    /// Address the inbound heartbeat listener binds to.
    pub listen_address: String,
}
46
47impl Default for StaticDiscoveryConfig {
48 fn default() -> Self {
49 Self {
50 local_node: NodeInfo {
51 id: NodeId(1),
52 name: "node-1".into(),
53 rpc_address: "127.0.0.1:9000".into(),
54 raft_address: "127.0.0.1:9001".into(),
55 state: NodeState::Active,
56 metadata: NodeMetadata::default(),
57 last_heartbeat_ms: 0,
58 },
59 seeds: Vec::new(),
60 heartbeat_interval: Duration::from_secs(1),
61 suspect_threshold: 3,
62 dead_threshold: 10,
63 listen_address: "127.0.0.1:9002".into(),
64 }
65 }
66}
67
/// Internal per-peer bookkeeping maintained by the discovery tasks.
#[derive(Debug)]
struct PeerState {
    /// Last known info for the peer (state and heartbeat time updated in place).
    info: NodeInfo,
    /// Consecutive heartbeat rounds without a response from this peer.
    missed_heartbeats: u32,
    /// Heartbeat ticks spent in the `Left` state; drives eventual reaping.
    left_ticks: u32,
}
77
/// Discovery implementation that heartbeats a fixed ("static") seed list over
/// TCP and tracks peer liveness from the responses.
#[derive(Debug)]
pub struct StaticDiscovery {
    config: StaticDiscoveryConfig,
    /// Shared peer table, written by both the listener and heartbeater tasks.
    peers: Arc<RwLock<HashMap<u64, PeerState>>>,
    // NOTE(review): `broadcast_membership` does use this sender, so the
    // dead_code allow may no longer be needed — confirm before removing.
    #[allow(dead_code)]
    membership_tx: watch::Sender<Vec<NodeInfo>>,
    membership_rx: watch::Receiver<Vec<NodeInfo>>,
    /// Cancels the background tasks; replaced with a fresh token on `start()`.
    cancel: CancellationToken,
    listener_handle: Option<tokio::task::JoinHandle<Result<(), DiscoveryError>>>,
    heartbeater_handle: Option<tokio::task::JoinHandle<()>>,
    /// Whether `start()` has run without a matching `stop()`.
    started: bool,
}
91
92impl StaticDiscovery {
93 #[must_use]
95 pub fn new(config: StaticDiscoveryConfig) -> Self {
96 debug_assert!(
97 config.suspect_threshold < config.dead_threshold,
98 "suspect_threshold ({}) must be less than dead_threshold ({})",
99 config.suspect_threshold,
100 config.dead_threshold,
101 );
102 let (tx, rx) = watch::channel(Vec::new());
103 Self {
104 config,
105 peers: Arc::new(RwLock::new(HashMap::new())),
106 membership_tx: tx,
107 membership_rx: rx,
108 cancel: CancellationToken::new(),
109 listener_handle: None,
110 heartbeater_handle: None,
111 started: false,
112 }
113 }
114
115 fn serialize_node_info(info: &NodeInfo) -> Result<Vec<u8>, DiscoveryError> {
117 rkyv::to_bytes::<rkyv::rancor::Error>(info)
118 .map(|v| v.to_vec())
119 .map_err(|e| DiscoveryError::Serialization(e.to_string()))
120 }
121
122 fn deserialize_node_info(data: &[u8]) -> Result<NodeInfo, DiscoveryError> {
124 rkyv::from_bytes::<NodeInfo, rkyv::rancor::Error>(data)
125 .map_err(|e| DiscoveryError::Serialization(e.to_string()))
126 }
127
128 fn broadcast_membership(&self) {
130 let peers = self.peers.read();
131 let peer_list: Vec<NodeInfo> = peers.values().map(|p| p.info.clone()).collect();
132 let _ = self.membership_tx.send(peer_list);
133 }
134
135 #[allow(clippy::cast_possible_truncation)]
137 async fn send_heartbeat(address: &str, data: &[u8]) -> Result<Option<Vec<u8>>, DiscoveryError> {
138 use tokio::io::{AsyncReadExt, AsyncWriteExt};
139
140 let mut stream = tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(address))
142 .await
143 .map_err(|_| DiscoveryError::Connection {
144 address: address.into(),
145 reason: "connect timeout".into(),
146 })?
147 .map_err(|e| DiscoveryError::Connection {
148 address: address.into(),
149 reason: e.to_string(),
150 })?;
151
152 if data.len() > MAX_MESSAGE_SIZE {
154 return Err(DiscoveryError::Serialization(
155 "message too large to send".into(),
156 ));
157 }
158
159 let len = data.len() as u32;
161 tokio::time::timeout(IO_TIMEOUT, async {
162 stream.write_all(&len.to_be_bytes()).await?;
163 stream.write_all(data).await
164 })
165 .await
166 .map_err(|_| DiscoveryError::Connection {
167 address: address.into(),
168 reason: "write timeout".into(),
169 })?
170 .map_err(|e| DiscoveryError::Connection {
171 address: address.into(),
172 reason: e.to_string(),
173 })?;
174
175 let resp = tokio::time::timeout(IO_TIMEOUT, async {
177 let mut len_buf = [0u8; 4];
178 if stream.read_exact(&mut len_buf).await.is_err() {
179 return Ok(None);
180 }
181
182 let resp_len = u32::from_be_bytes(len_buf) as usize;
183 if resp_len > MAX_MESSAGE_SIZE {
184 return Err(DiscoveryError::Serialization("response too large".into()));
185 }
186 let mut resp = vec![0u8; resp_len];
187 stream.read_exact(&mut resp).await?;
188 Ok(Some(resp))
189 })
190 .await
191 .map_err(|_| DiscoveryError::Connection {
192 address: address.into(),
193 reason: "read timeout".into(),
194 })?;
195
196 resp.map_err(|e: DiscoveryError| e)
197 }
198
199 #[allow(clippy::cast_possible_truncation)]
205 async fn run_listener(
206 listen_address: String,
207 local_info: NodeInfo,
208 peers: Arc<RwLock<HashMap<u64, PeerState>>>,
209 membership_tx: watch::Sender<Vec<NodeInfo>>,
210 cancel: CancellationToken,
211 ) -> Result<(), DiscoveryError> {
212 use tokio::io::{AsyncReadExt, AsyncWriteExt};
213
214 let listener = TcpListener::bind(&listen_address)
215 .await
216 .map_err(|e| DiscoveryError::Bind(e.to_string()))?;
217
218 let semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_HANDLER_TASKS));
220
221 loop {
222 tokio::select! {
223 () = cancel.cancelled() => break,
224 accept = listener.accept() => {
225 let (mut stream, _) = accept?;
226 let local_info = local_info.clone();
227 let peers = Arc::clone(&peers);
228 let membership_tx = membership_tx.clone();
229 let permit = Arc::clone(&semaphore);
230
231 tokio::spawn(async move {
232 let Ok(_permit) = permit.try_acquire() else {
234 return; };
236
237 let result = tokio::time::timeout(IO_TIMEOUT, async {
239 let mut len_buf = [0u8; 4];
240 if stream.read_exact(&mut len_buf).await.is_err() {
241 return;
242 }
243 let msg_len = u32::from_be_bytes(len_buf) as usize;
244 if msg_len > MAX_MESSAGE_SIZE {
245 return;
246 }
247 let mut data = vec![0u8; msg_len];
248 if stream.read_exact(&mut data).await.is_err() {
249 return;
250 }
251
252 if let Ok(remote_info) = Self::deserialize_node_info(&data) {
253 if remote_info.id == local_info.id {
255 if let Ok(resp) = Self::serialize_node_info(&local_info) {
257 let len = resp.len() as u32;
258 let _ = stream.write_all(&len.to_be_bytes()).await;
259 let _ = stream.write_all(&resp).await;
260 }
261 return;
262 }
263
264 let peer_list = {
265 let mut guard = peers.write();
266 let now = chrono::Utc::now().timestamp_millis();
267 let peer =
268 guard.entry(remote_info.id.0).or_insert_with(|| {
269 PeerState {
273 info: NodeInfo {
274 last_heartbeat_ms: now,
275 state: NodeState::Joining,
276 ..remote_info.clone()
277 },
278 missed_heartbeats: 0,
279 left_ticks: 0,
280 }
281 });
282 peer.info.rpc_address.clone_from(&remote_info.rpc_address);
286 peer.info.raft_address.clone_from(&remote_info.raft_address);
287 peer.info.name.clone_from(&remote_info.name);
288 peer.info.metadata = remote_info.metadata.clone();
289 peer.info.last_heartbeat_ms = now;
290 guard.values().map(|p| p.info.clone()).collect::<Vec<_>>()
291 };
292 let _ = membership_tx.send(peer_list);
293 }
294
295 if let Ok(resp) = Self::serialize_node_info(&local_info) {
296 let len = resp.len() as u32;
297 let _ = stream.write_all(&len.to_be_bytes()).await;
298 let _ = stream.write_all(&resp).await;
299 }
300 })
301 .await;
302
303 if result.is_err() {
304 }
306 });
307 }
308 }
309 }
310
311 Ok(())
312 }
313
314 async fn run_heartbeater(config: StaticDiscoveryConfig, ctx: HeartbeatContext) {
320 let local_id = config.local_node.id;
321 let mut interval = tokio::time::interval(config.heartbeat_interval);
322 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
324
325 let seed_to_peer = Arc::new(parking_lot::Mutex::new(HashMap::<String, u64>::new()));
328
329 loop {
330 tokio::select! {
331 () = ctx.cancel.cancelled() => break,
332 _ = interval.tick() => {
333 let Ok(data) = Self::serialize_node_info(&config.local_node) else {
334 continue;
335 };
336 let data = Arc::new(data);
337
338 let mut tasks = Vec::with_capacity(config.seeds.len());
340 for seed in &config.seeds {
341 let seed = seed.clone();
342 let data = Arc::clone(&data);
343 tasks.push(tokio::spawn(async move {
344 let result = Self::send_heartbeat(&seed, &data).await;
345 (seed, result)
346 }));
347 }
348
349 for task in tasks {
351 let Ok((seed, result)) = task.await else {
352 continue; };
354
355 if let Ok(Some(resp_data)) = result {
356 if let Ok(remote_info) = Self::deserialize_node_info(&resp_data) {
357 if remote_info.id == local_id {
359 continue;
360 }
361
362 seed_to_peer.lock().insert(seed, remote_info.id.0);
364
365 let mut peers = ctx.peers.write();
366 let now = chrono::Utc::now().timestamp_millis();
367 let peer =
368 peers.entry(remote_info.id.0).or_insert_with(|| {
369 PeerState {
370 info: remote_info.clone(),
371 missed_heartbeats: 0,
372 left_ticks: 0,
373 }
374 });
375 peer.info = NodeInfo {
376 last_heartbeat_ms: now,
377 state: NodeState::Active,
378 ..remote_info
379 };
380 peer.missed_heartbeats = 0;
381 peer.left_ticks = 0;
382 }
383 } else {
384 let map = seed_to_peer.lock();
386 if let Some(&peer_id) = map.get(seed.as_str()) {
387 drop(map);
388 let mut peers = ctx.peers.write();
389 if let Some(peer) = peers.get_mut(&peer_id) {
390 peer.missed_heartbeats += 1;
391 if peer.missed_heartbeats >= config.dead_threshold {
392 peer.info.state = NodeState::Left;
393 } else if peer.missed_heartbeats
394 >= config.suspect_threshold
395 {
396 peer.info.state = NodeState::Suspected;
397 }
398 }
399 }
400 }
401 }
402
403 {
405 let mut peers = ctx.peers.write();
406 peers.retain(|_id, peer| {
407 if peer.info.state == NodeState::Left {
408 peer.left_ticks += 1;
409 peer.left_ticks < LEFT_REAP_THRESHOLD
410 } else {
411 true
412 }
413 });
414 }
415
416 {
418 let peers = ctx.peers.read();
419 let mut map = seed_to_peer.lock();
420 map.retain(|_, peer_id| peers.contains_key(peer_id));
421 }
422
423 let peer_list: Vec<NodeInfo> = {
424 let peers = ctx.peers.read();
425 peers.values().map(|p| p.info.clone()).collect()
426 };
427 let _ = ctx.membership_tx.send(peer_list);
428 }
429 }
430 }
431 }
432}
433
/// Shared handles handed to [`StaticDiscovery::run_heartbeater`].
struct HeartbeatContext {
    /// Peer table shared with the listener task.
    peers: Arc<RwLock<HashMap<u64, PeerState>>>,
    /// Publishes membership snapshots to watchers.
    membership_tx: watch::Sender<Vec<NodeInfo>>,
    /// Signals the heartbeater task to shut down.
    cancel: CancellationToken,
}
440
441impl Discovery for StaticDiscovery {
442 async fn start(&mut self) -> Result<(), DiscoveryError> {
443 if self.started {
444 return Ok(());
445 }
446
447 self.cancel = CancellationToken::new();
449
450 let peers = Arc::clone(&self.peers);
451 let membership_tx = self.membership_tx.clone();
452 let cancel = self.cancel.clone();
453 let listen_address = self.config.listen_address.clone();
454 let local_info = self.config.local_node.clone();
455
456 self.listener_handle = Some(tokio::spawn(Self::run_listener(
458 listen_address,
459 local_info,
460 Arc::clone(&peers),
461 membership_tx.clone(),
462 cancel.clone(),
463 )));
464
465 self.heartbeater_handle = Some(tokio::spawn(Self::run_heartbeater(
467 self.config.clone(),
468 HeartbeatContext {
469 peers,
470 membership_tx,
471 cancel,
472 },
473 )));
474
475 self.started = true;
476 Ok(())
477 }
478
479 async fn peers(&self) -> Result<Vec<NodeInfo>, DiscoveryError> {
480 if !self.started {
481 return Err(DiscoveryError::NotStarted);
482 }
483 let peers = self.peers.read();
484 Ok(peers.values().map(|p| p.info.clone()).collect())
485 }
486
487 async fn announce(&self, info: NodeInfo) -> Result<(), DiscoveryError> {
488 if !self.started {
489 return Err(DiscoveryError::NotStarted);
490 }
491 {
492 let mut peers = self.peers.write();
493 peers.insert(
494 info.id.0,
495 PeerState {
496 info,
497 missed_heartbeats: 0,
498 left_ticks: 0,
499 },
500 );
501 }
502 self.broadcast_membership();
503 Ok(())
504 }
505
506 fn membership_watch(&self) -> watch::Receiver<Vec<NodeInfo>> {
507 self.membership_rx.clone()
508 }
509
510 async fn stop(&mut self) -> Result<(), DiscoveryError> {
511 self.cancel.cancel();
512 self.started = false;
513
514 if let Some(h) = self.listener_handle.take() {
516 let _ = h.await;
517 }
518 if let Some(h) = self.heartbeater_handle.take() {
519 let _ = h.await;
520 }
521
522 Ok(())
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    // Default config carries the documented thresholds and interval.
    #[test]
    fn test_config_default() {
        let config = StaticDiscoveryConfig::default();
        assert_eq!(config.heartbeat_interval, Duration::from_secs(1));
        assert_eq!(config.suspect_threshold, 3);
        assert_eq!(config.dead_threshold, 10);
    }

    // NodeInfo survives a serialize/deserialize round trip intact.
    #[test]
    fn test_serialize_round_trip() {
        let info = NodeInfo {
            id: NodeId(42),
            name: "test".into(),
            rpc_address: "127.0.0.1:9000".into(),
            raft_address: "127.0.0.1:9001".into(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 1000,
        };

        let data = StaticDiscovery::serialize_node_info(&info).unwrap();
        let back = StaticDiscovery::deserialize_node_info(&data).unwrap();
        assert_eq!(back.id, NodeId(42));
        assert_eq!(back.name, "test");
    }

    // Garbage bytes are rejected rather than panicking.
    #[test]
    fn test_deserialize_invalid() {
        let result = StaticDiscovery::deserialize_node_info(&[0xff, 0xff]);
        assert!(result.is_err());
    }

    // peers() refuses to answer before start().
    #[tokio::test]
    async fn test_not_started_errors() {
        let config = StaticDiscoveryConfig::default();
        let disc = StaticDiscovery::new(config);
        assert!(disc.peers().await.is_err());
    }

    // start()/stop() toggle the started flag as expected.
    // Port 0 lets the OS pick a free port, avoiding collisions between tests.
    #[tokio::test]
    async fn test_start_stop() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let mut disc = StaticDiscovery::new(config);
        disc.start().await.unwrap();
        assert!(disc.started);
        disc.stop().await.unwrap();
        assert!(!disc.started);
    }

    // A second start() on a running instance is a no-op, not an error.
    #[tokio::test]
    async fn test_double_start_ok() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let mut disc = StaticDiscovery::new(config);
        disc.start().await.unwrap();
        disc.start().await.unwrap();
        disc.stop().await.unwrap();
    }

    // The watch channel starts out with an empty membership list.
    #[tokio::test]
    async fn test_membership_watch() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let disc = StaticDiscovery::new(config);
        let rx = disc.membership_watch();
        assert!(rx.borrow().is_empty());
    }

    // announce() makes the peer immediately visible via peers().
    #[tokio::test]
    async fn test_announce_adds_peer() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let mut disc = StaticDiscovery::new(config);
        disc.start().await.unwrap();

        let peer = NodeInfo {
            id: NodeId(99),
            name: "peer".into(),
            rpc_address: "127.0.0.1:8000".into(),
            raft_address: "127.0.0.1:8001".into(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 0,
        };
        disc.announce(peer).await.unwrap();

        let peers = disc.peers().await.unwrap();
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].id, NodeId(99));

        disc.stop().await.unwrap();
    }

    // Two nodes seeded with each other's addresses should discover one
    // another within a few heartbeat intervals.
    // NOTE(review): the bind-then-drop trick to reserve ports is racy — the
    // port could be reused by another process before the node binds it.
    #[tokio::test]
    async fn test_two_node_heartbeat() {
        let listener1 = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr1 = listener1.local_addr().unwrap().to_string();
        drop(listener1);

        let listener2 = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr2 = listener2.local_addr().unwrap().to_string();
        drop(listener2);

        let config1 = StaticDiscoveryConfig {
            local_node: NodeInfo {
                id: NodeId(1),
                name: "node-1".into(),
                rpc_address: addr1.clone(),
                raft_address: addr1.clone(),
                state: NodeState::Active,
                metadata: NodeMetadata::default(),
                last_heartbeat_ms: 0,
            },
            seeds: vec![addr2.clone()],
            heartbeat_interval: Duration::from_millis(100),
            listen_address: addr1.clone(),
            ..StaticDiscoveryConfig::default()
        };

        let config2 = StaticDiscoveryConfig {
            local_node: NodeInfo {
                id: NodeId(2),
                name: "node-2".into(),
                rpc_address: addr2.clone(),
                raft_address: addr2.clone(),
                state: NodeState::Active,
                metadata: NodeMetadata::default(),
                last_heartbeat_ms: 0,
            },
            seeds: vec![addr1],
            heartbeat_interval: Duration::from_millis(100),
            listen_address: addr2,
            ..StaticDiscoveryConfig::default()
        };

        let mut disc1 = StaticDiscovery::new(config1);
        let mut disc2 = StaticDiscovery::new(config2);

        disc1.start().await.unwrap();
        disc2.start().await.unwrap();

        // Allow several 100 ms heartbeat rounds to complete.
        tokio::time::sleep(Duration::from_millis(500)).await;

        let peers1 = disc1.peers().await.unwrap();
        let peers2 = disc2.peers().await.unwrap();

        assert!(
            !peers1.is_empty() || !peers2.is_empty(),
            "at least one node should have discovered peers"
        );

        disc1.stop().await.unwrap();
        disc2.stop().await.unwrap();
    }

    // A stopped instance can be started again (fresh cancellation token).
    #[tokio::test]
    async fn test_restart_after_stop() {
        let config = StaticDiscoveryConfig {
            listen_address: "127.0.0.1:0".into(),
            ..StaticDiscoveryConfig::default()
        };
        let mut disc = StaticDiscovery::new(config);

        disc.start().await.unwrap();
        disc.stop().await.unwrap();

        disc.start().await.unwrap();
        assert!(disc.started);
        disc.stop().await.unwrap();
    }
}