1use std::net::{SocketAddr, UdpSocket};
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use async_trait::async_trait;
31use chitchat::transport::{Socket, Transport, UdpTransport};
32use object_store::{
33 path::Path as OsPath, CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload,
34 ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult,
35};
36use parking_lot::Mutex;
37use rustc_hash::FxHashSet;
38use tokio::sync::watch;
39
40use super::control::{AssignmentSnapshotStore, ChitchatKv, ClusterController, ClusterKv};
41use super::discovery::{
42 Discovery, GossipDiscovery, GossipDiscoveryConfig, NodeId, NodeInfo, NodeMetadata, NodeState,
43};
44
/// Shared fault-injection rules for the mini cluster's gossip network:
/// a set of unidirectional `(src, dst)` address pairs whose outgoing
/// messages are silently discarded by [`PartitionableSocket`].
pub struct NetworkRules {
    // (src, dst) pairs currently being dropped; guarded by a parking_lot
    // mutex so rules can be mutated from any test thread.
    dropped: Mutex<FxHashSet<(SocketAddr, SocketAddr)>>,
}
53
54impl std::fmt::Debug for NetworkRules {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("NetworkRules")
57 .field("drop_count", &self.dropped.lock().len())
58 .finish()
59 }
60}
61
62impl NetworkRules {
63 #[must_use]
65 pub fn new() -> Self {
66 Self {
67 dropped: Mutex::new(FxHashSet::default()),
68 }
69 }
70
71 pub fn partition(&self, side_a: &[SocketAddr], side_b: &[SocketAddr]) {
75 let mut set = self.dropped.lock();
76 for a in side_a {
77 for b in side_b {
78 set.insert((*a, *b));
79 set.insert((*b, *a));
80 }
81 }
82 }
83
84 pub fn drop_pair(&self, src: SocketAddr, dst: SocketAddr) {
86 self.dropped.lock().insert((src, dst));
87 }
88
89 pub fn heal(&self) {
91 self.dropped.lock().clear();
92 }
93
94 #[must_use]
96 pub fn is_dropped(&self, src: SocketAddr, dst: SocketAddr) -> bool {
97 self.dropped.lock().contains(&(src, dst))
98 }
99}
100
101impl Default for NetworkRules {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
/// Fault-injection mode for [`FaultyObjectStore`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObjectStoreFault {
    /// Pass every operation through to the inner store unchanged.
    None,
    /// Fail write-path operations (`put_opts`, `put_multipart_opts`,
    /// `copy_opts`, `delete_stream`) with a `Generic` error.
    FailWrites,
    /// Fail `get_opts` with a `NotFound` error; listing is never faulted.
    FailReads,
    /// Fail both the read and write paths.
    FailAll,
}
119
120impl ObjectStoreFault {
121 fn fails_writes(self) -> bool {
122 matches!(self, Self::FailWrites | Self::FailAll)
123 }
124 fn fails_reads(self) -> bool {
125 matches!(self, Self::FailReads | Self::FailAll)
126 }
127}
128
/// An [`ObjectStore`] decorator that injects failures according to the
/// currently configured [`ObjectStoreFault`] mode, delegating to `inner`
/// whenever no fault applies.
pub struct FaultyObjectStore {
    // The real store that serves requests when no fault applies.
    inner: Arc<dyn ObjectStore>,
    // Current fault mode; switchable at runtime via `set_fault`.
    fault: Mutex<ObjectStoreFault>,
}
134
135impl std::fmt::Debug for FaultyObjectStore {
136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 f.debug_struct("FaultyObjectStore")
138 .field("fault", &self.fault())
139 .finish_non_exhaustive()
140 }
141}
142
143impl std::fmt::Display for FaultyObjectStore {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 write!(f, "FaultyObjectStore({:?})", self.fault())
146 }
147}
148
149impl FaultyObjectStore {
150 #[must_use]
152 pub fn new(inner: Arc<dyn ObjectStore>) -> Self {
153 Self {
154 inner,
155 fault: Mutex::new(ObjectStoreFault::None),
156 }
157 }
158
159 #[must_use]
161 pub fn fault(&self) -> ObjectStoreFault {
162 *self.fault.lock()
163 }
164
165 pub fn set_fault(&self, mode: ObjectStoreFault) {
167 *self.fault.lock() = mode;
168 }
169
170 fn check_write(&self) -> object_store::Result<()> {
171 if self.fault().fails_writes() {
172 return Err(object_store::Error::Generic {
173 store: "FaultyObjectStore",
174 source: "injected write failure".into(),
175 });
176 }
177 Ok(())
178 }
179
180 fn check_read(&self, path: &OsPath) -> object_store::Result<()> {
181 if self.fault().fails_reads() {
182 return Err(object_store::Error::NotFound {
183 path: path.to_string(),
184 source: "injected read failure".into(),
185 });
186 }
187 Ok(())
188 }
189}
190
#[async_trait]
impl ObjectStore for FaultyObjectStore {
    /// Write path: rejected with the injected `Generic` error while the
    /// current mode fails writes; otherwise delegates to the inner store.
    async fn put_opts(
        &self,
        location: &OsPath,
        payload: PutPayload,
        opts: PutOptions,
    ) -> object_store::Result<PutResult> {
        self.check_write()?;
        self.inner.put_opts(location, payload, opts).await
    }

    /// Write path: the fault is checked once at upload creation; parts of an
    /// already-created upload are not re-checked here.
    async fn put_multipart_opts(
        &self,
        location: &OsPath,
        opts: PutMultipartOptions,
    ) -> object_store::Result<Box<dyn MultipartUpload>> {
        self.check_write()?;
        self.inner.put_multipart_opts(location, opts).await
    }

    /// Read path: surfaces the injected `NotFound` while reads are faulted.
    async fn get_opts(
        &self,
        location: &OsPath,
        options: GetOptions,
    ) -> object_store::Result<GetResult> {
        self.check_read(location)?;
        self.inner.get_opts(location, options).await
    }

    /// Write path. The fault mode is sampled once, eagerly, when the stream
    /// is built: if writes are faulted, every incoming location is mapped to
    /// an injected error (one error item per path, paths themselves ignored);
    /// later `set_fault` calls do not affect an already-built stream.
    fn delete_stream(
        &self,
        locations: futures::stream::BoxStream<'static, object_store::Result<OsPath>>,
    ) -> futures::stream::BoxStream<'static, object_store::Result<OsPath>> {
        if self.fault().fails_writes() {
            use futures::StreamExt;
            locations
                .map(|_| {
                    Err(object_store::Error::Generic {
                        store: "FaultyObjectStore",
                        source: "injected write failure (delete_stream)".into(),
                    })
                })
                .boxed()
        } else {
            self.inner.delete_stream(locations)
        }
    }

    /// Listing is deliberately never faulted — it passes straight through.
    fn list(
        &self,
        prefix: Option<&OsPath>,
    ) -> futures::stream::BoxStream<'static, object_store::Result<ObjectMeta>> {
        self.inner.list(prefix)
    }

    /// Also unfaulted, like `list`.
    async fn list_with_delimiter(
        &self,
        prefix: Option<&OsPath>,
    ) -> object_store::Result<ListResult> {
        self.inner.list_with_delimiter(prefix).await
    }

    /// Write path: copy counts as a write and honors the write fault.
    async fn copy_opts(
        &self,
        from: &OsPath,
        to: &OsPath,
        options: CopyOptions,
    ) -> object_store::Result<()> {
        self.check_write()?;
        self.inner.copy_opts(from, to, options).await
    }
}
266
/// A chitchat [`Transport`] wrapper whose sockets consult shared
/// [`NetworkRules`] before each send, silently discarding traffic between
/// partitioned address pairs.
pub struct PartitionableTransport {
    // Shared drop rules consulted on every outgoing message.
    rules: Arc<NetworkRules>,
    // The real UDP transport that performs actual delivery.
    inner: UdpTransport,
}
274
275impl PartitionableTransport {
276 #[must_use]
278 pub fn new(rules: Arc<NetworkRules>) -> Self {
279 Self {
280 rules,
281 inner: UdpTransport,
282 }
283 }
284}
285
286#[async_trait]
287impl Transport for PartitionableTransport {
288 async fn open(&self, listen_addr: SocketAddr) -> anyhow::Result<Box<dyn Socket>> {
289 let socket = self.inner.open(listen_addr).await?;
290 Ok(Box::new(PartitionableSocket {
291 my_addr: listen_addr,
292 rules: Arc::clone(&self.rules),
293 inner: socket,
294 }))
295 }
296}
297
/// A chitchat [`Socket`] that drops outgoing messages matching the shared
/// [`NetworkRules`]; incoming traffic is never filtered (partitions are
/// enforced on the sender side only).
struct PartitionableSocket {
    // Local listen address, used as the `src` half of rule lookups.
    my_addr: SocketAddr,
    // Shared drop rules.
    rules: Arc<NetworkRules>,
    // Underlying socket that performs real I/O.
    inner: Box<dyn Socket>,
}
303
304#[async_trait]
305impl Socket for PartitionableSocket {
306 async fn send(&mut self, to: SocketAddr, msg: chitchat::ChitchatMessage) -> anyhow::Result<()> {
307 if self.rules.is_dropped(self.my_addr, to) {
308 return Ok(());
312 }
313 self.inner.send(to, msg).await
314 }
315
316 async fn recv(&mut self) -> anyhow::Result<(SocketAddr, chitchat::ChitchatMessage)> {
317 self.inner.recv().await
318 }
319}
320
/// Reserves an ephemeral loopback UDP port by binding to port 0 and reading
/// the OS-assigned port, then releasing the socket. Inherently racy (another
/// process may claim the port before the caller rebinds), which is an
/// accepted trade-off in this test harness.
fn grab_port() -> u16 {
    UdpSocket::bind("127.0.0.1:0")
        .expect("bind 127.0.0.1:0")
        .local_addr()
        .expect("local_addr")
        .port()
}
330
/// Handle to one running node of a [`MiniCluster`].
pub struct NodeHandle {
    /// Stable identity of this node within the cluster.
    pub instance_id: NodeId,
    /// Loopback `ip:port` string this node gossips on.
    pub gossip_addr: String,
    /// Per-node controller backed by the gossip KV store.
    pub controller: Arc<ClusterController>,
    // Owned gossip runtime; consumed/stopped by `kill` and `crash`.
    discovery: GossipDiscovery,
}
343
impl NodeHandle {
    /// Graceful shutdown: announces a `Left` state so peers can remove the
    /// node promptly, waits briefly for the announcement to gossip out, then
    /// stops the discovery loop. Errors from announce/stop are deliberately
    /// ignored — this is best-effort teardown for tests.
    pub async fn kill(mut self) {
        let left = NodeInfo {
            state: NodeState::Left,
            ..current_info(&self)
        };
        let _ = self.discovery.announce(left).await;
        // 150ms ≈ a few rounds of the 50ms gossip interval used by spawn,
        // giving the Left state a chance to propagate before we go dark.
        tokio::time::sleep(Duration::from_millis(150)).await;
        let _ = self.discovery.stop().await;
    }

    /// Abrupt shutdown: stops gossiping without announcing `Left`, so peers
    /// must detect the failure themselves (phi accrual / grace period).
    pub async fn crash(mut self) {
        let _ = self.discovery.stop().await;
    }
}
382
383fn current_info(node: &NodeHandle) -> NodeInfo {
384 NodeInfo {
385 id: node.instance_id,
386 name: format!("minicluster-n{}", node.instance_id.0),
387 rpc_address: String::new(),
388 raft_address: String::new(),
389 state: NodeState::Active,
390 metadata: NodeMetadata {
391 cores: 1,
392 ..NodeMetadata::default()
393 },
394 last_heartbeat_ms: 0,
395 }
396}
397
/// An in-process gossip cluster for integration tests.
pub struct MiniCluster {
    /// One handle per live node, in spawn order.
    pub nodes: Vec<NodeHandle>,
    /// Shared partition rules; `Some` only for `spawn_partitionable`.
    pub rules: Option<Arc<NetworkRules>>,
    /// Optional assignment snapshot store shared by every controller.
    pub snapshot: Option<Arc<AssignmentSnapshotStore>>,
}
412
413impl MiniCluster {
414 pub async fn spawn(n: usize) -> Self {
422 Self::spawn_inner(n, None, None).await
423 }
424
425 pub async fn spawn_partitionable(n: usize) -> Self {
434 let rules = Arc::new(NetworkRules::new());
435 Self::spawn_inner(n, Some(rules), None).await
436 }
437
438 pub async fn spawn_with_snapshot(n: usize, snapshot: Arc<AssignmentSnapshotStore>) -> Self {
444 Self::spawn_inner(n, None, Some(snapshot)).await
445 }
446
447 pub async fn join_node(&mut self, instance_id: NodeId) {
456 assert!(!self.nodes.is_empty(), "cannot join empty cluster");
457 let seeds: Vec<String> = self.nodes.iter().map(|n| n.gossip_addr.clone()).collect();
461 let port = grab_port();
462 let gossip_addr = format!("127.0.0.1:{port}");
463
464 let local_node = NodeInfo {
465 id: instance_id,
466 name: format!("minicluster-rejoin-{}", instance_id.0),
467 rpc_address: String::new(),
468 raft_address: String::new(),
469 state: NodeState::Active,
470 metadata: NodeMetadata {
471 cores: 1,
472 ..NodeMetadata::default()
473 },
474 last_heartbeat_ms: 0,
475 };
476
477 let cfg = GossipDiscoveryConfig {
478 gossip_address: gossip_addr.clone(),
479 seed_nodes: seeds,
480 gossip_interval: Duration::from_millis(50),
481 phi_threshold: 3.0,
482 dead_node_grace_period: Duration::from_secs(1),
483 cluster_id: "minicluster".to_string(),
484 node_id: instance_id,
485 local_node,
486 };
487 let mut discovery = GossipDiscovery::new(cfg);
488 match &self.rules {
489 Some(rules) => {
490 let transport = PartitionableTransport::new(Arc::clone(rules));
491 discovery
492 .start_with_transport(&transport)
493 .await
494 .expect("partitionable chitchat start on rejoin");
495 }
496 None => discovery.start().await.expect("chitchat start on rejoin"),
497 }
498
499 let handle = discovery
500 .chitchat_handle()
501 .expect("chitchat handle available after start");
502 let kv: Arc<dyn ClusterKv> = Arc::new(ChitchatKv::from_handle(handle));
503 let members_rx = discovery.membership_watch();
504 let controller = Arc::new(ClusterController::new(
505 instance_id,
506 kv,
507 self.snapshot.clone(),
508 members_rx,
509 ));
510
511 self.nodes.push(NodeHandle {
512 instance_id,
513 gossip_addr,
514 controller,
515 discovery,
516 });
517 }
518
519 async fn spawn_inner(
520 n: usize,
521 rules: Option<Arc<NetworkRules>>,
522 snapshot: Option<Arc<AssignmentSnapshotStore>>,
523 ) -> Self {
524 assert!(n >= 1, "MiniCluster needs at least one node");
525
526 let ports: Vec<u16> = (0..n).map(|_| grab_port()).collect();
527 let seed = format!("127.0.0.1:{}", ports[0]);
528 let transport = rules
529 .as_ref()
530 .map(|r| PartitionableTransport::new(Arc::clone(r)));
531
532 let mut nodes = Vec::with_capacity(n);
533 for (idx, port) in ports.iter().enumerate() {
534 let instance_id = NodeId((idx as u64) + 1); let gossip_addr = format!("127.0.0.1:{port}");
536
537 let local_node = NodeInfo {
538 id: instance_id,
539 name: format!("minicluster-n{idx}"),
540 rpc_address: String::new(),
541 raft_address: String::new(),
542 state: NodeState::Active,
543 metadata: NodeMetadata {
544 cores: 1,
545 ..NodeMetadata::default()
546 },
547 last_heartbeat_ms: 0,
548 };
549
550 let seeds = if idx == 0 {
551 Vec::new()
552 } else {
553 vec![seed.clone()]
554 };
555 let cfg = GossipDiscoveryConfig {
559 gossip_address: gossip_addr.clone(),
560 seed_nodes: seeds,
561 gossip_interval: Duration::from_millis(50),
562 phi_threshold: 3.0,
563 dead_node_grace_period: Duration::from_secs(1),
564 cluster_id: "minicluster".to_string(),
565 node_id: instance_id,
566 local_node,
567 };
568 let mut discovery = GossipDiscovery::new(cfg);
569 match &transport {
570 Some(t) => discovery
571 .start_with_transport(t)
572 .await
573 .expect("partitionable chitchat start"),
574 None => discovery.start().await.expect("chitchat start on loopback"),
575 }
576
577 let handle = discovery
578 .chitchat_handle()
579 .expect("chitchat handle available after start");
580 let kv: Arc<dyn ClusterKv> = Arc::new(ChitchatKv::from_handle(handle));
581 let members_rx: watch::Receiver<Vec<NodeInfo>> = discovery.membership_watch();
582 let controller = Arc::new(ClusterController::new(
583 instance_id,
584 kv,
585 snapshot.clone(),
586 members_rx,
587 ));
588
589 nodes.push(NodeHandle {
590 instance_id,
591 gossip_addr,
592 controller,
593 discovery,
594 });
595 }
596 Self {
597 nodes,
598 rules,
599 snapshot,
600 }
601 }
602
603 #[must_use]
611 pub fn addrs(&self) -> Vec<SocketAddr> {
612 self.nodes
613 .iter()
614 .map(|n| n.gossip_addr.parse().expect("valid gossip_addr"))
615 .collect()
616 }
617
618 pub async fn wait_for_convergence(&self, deadline: Duration) -> Result<(), String> {
625 let start = Instant::now();
626 loop {
627 let mut all_converged = true;
628 let mut missing_summary = Vec::new();
629 for node in &self.nodes {
630 let peers = node
631 .discovery
632 .peers()
633 .await
634 .map_err(|e| format!("peers() failed on {}: {e}", node.instance_id.0))?;
635 let expected = self.nodes.len() - 1;
636 if peers.len() < expected {
637 all_converged = false;
638 missing_summary.push(format!(
639 "node {} sees {} peers (expected {})",
640 node.instance_id.0,
641 peers.len(),
642 expected
643 ));
644 }
645 }
646 if all_converged {
647 return Ok(());
648 }
649 if start.elapsed() >= deadline {
650 return Err(format!(
651 "convergence timeout after {:?}: {}",
652 deadline,
653 missing_summary.join("; "),
654 ));
655 }
656 tokio::time::sleep(Duration::from_millis(100)).await;
657 }
658 }
659
660 pub async fn shutdown(mut self) {
662 for node in self.nodes.drain(..) {
663 node.kill().await;
664 }
665 }
666}