1use std::io;
7use std::net::SocketAddr;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10
11use rustc_hash::FxHashMap;
12use tokio::net::{TcpListener, TcpStream};
13use tokio::sync::{mpsc, Mutex, RwLock};
14use tokio::task::JoinHandle;
15
16use super::message::{read_message, write_message, ShuffleMessage};
17use crate::checkpoint::barrier::CheckpointBarrier;
18
19#[cfg(feature = "cluster-unstable")]
20use crate::cluster::control::ClusterKv;
21
/// Key under which a node advertises its shuffle listener address in the
/// cluster KV store.
///
/// `ShuffleReceiver::bind_with_kv` writes the bound address under this key and
/// `ShuffleSender::discover_peer` reads it back to locate an unregistered peer.
#[cfg(feature = "cluster-unstable")]
pub const SHUFFLE_ADDR_KEY: &str = "shuffle:addr";

/// Numeric identity of a shuffle endpoint.
///
/// A sender announces its id in the `Hello` frame that opens every
/// connection; the receiver tags each delivered message with it.
pub type ShufflePeerId = u64;
32
33struct ShuffleConnection {
36 writer: Mutex<tokio::io::WriteHalf<TcpStream>>,
39 reader: JoinHandle<()>,
41 alive: Arc<AtomicBool>,
47}
48
49impl ShuffleConnection {
50 async fn send(&self, msg: &ShuffleMessage) -> io::Result<()> {
51 let mut w = self.writer.lock().await;
52 write_message(&mut *w, msg).await
53 }
54
55 fn is_alive(&self) -> bool {
56 self.alive.load(Ordering::Acquire)
57 }
58}
59
60impl Drop for ShuffleConnection {
61 fn drop(&mut self) {
62 self.reader.abort();
63 }
64}
65
66pub struct ShuffleSender {
69 local_id: ShufflePeerId,
70 peers: RwLock<FxHashMap<ShufflePeerId, SocketAddr>>,
71 pool: RwLock<FxHashMap<ShufflePeerId, Arc<ShuffleConnection>>>,
72 #[cfg(feature = "cluster-unstable")]
73 kv: Option<Arc<dyn ClusterKv>>,
74}
75
76impl std::fmt::Debug for ShuffleSender {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.debug_struct("ShuffleSender")
79 .field("local_id", &self.local_id)
80 .finish_non_exhaustive()
81 }
82}
83
84impl ShuffleSender {
85 #[must_use]
88 pub fn new(local_id: ShufflePeerId) -> Self {
89 Self {
90 local_id,
91 peers: RwLock::new(FxHashMap::default()),
92 pool: RwLock::new(FxHashMap::default()),
93 #[cfg(feature = "cluster-unstable")]
94 kv: None,
95 }
96 }
97
98 #[cfg(feature = "cluster-unstable")]
102 #[must_use]
103 pub fn with_kv(local_id: ShufflePeerId, kv: Arc<dyn ClusterKv>) -> Self {
104 let mut s = Self::new(local_id);
105 s.kv = Some(kv);
106 s
107 }
108
109 pub async fn register_peer(&self, peer: ShufflePeerId, addr: SocketAddr) {
112 self.peers.write().await.insert(peer, addr);
113 }
114
115 pub async fn send_to(&self, peer: ShufflePeerId, msg: &ShuffleMessage) -> io::Result<()> {
121 let conn = self.connection_for(peer).await?;
122 conn.send(msg).await
123 }
124
125 pub async fn fan_out_barrier(
132 &self,
133 peers: &[ShufflePeerId],
134 barrier: CheckpointBarrier,
135 ) -> io::Result<()> {
136 let msg = ShuffleMessage::Barrier(barrier);
137 for &peer in peers {
138 self.send_to(peer, &msg).await?;
139 }
140 Ok(())
141 }
142
143 #[cfg(feature = "cluster-unstable")]
147 async fn discover_peer(&self, peer: ShufflePeerId) -> Option<SocketAddr> {
148 let kv = self.kv.as_ref()?;
149 let raw = kv
150 .read_from(crate::cluster::discovery::NodeId(peer), SHUFFLE_ADDR_KEY)
151 .await?;
152 let addr: SocketAddr = raw.parse().ok()?;
153 self.peers.write().await.insert(peer, addr);
154 Some(addr)
155 }
156
157 async fn connection_for(&self, peer: ShufflePeerId) -> io::Result<Arc<ShuffleConnection>> {
158 if let Some(existing) = self.pool.read().await.get(&peer).cloned() {
163 if existing.is_alive() {
164 return Ok(existing);
165 }
166 }
167
168 {
176 let mut pool = self.pool.write().await;
177 if let Some(c) = pool.get(&peer) {
178 if !c.is_alive() {
179 pool.remove(&peer);
180 }
181 }
182 }
183
184 let addr = if let Some(a) = self.peers.read().await.get(&peer).copied() {
185 a
186 } else {
187 #[cfg(feature = "cluster-unstable")]
188 let discovered = self.discover_peer(peer).await;
189 #[cfg(not(feature = "cluster-unstable"))]
190 let discovered: Option<SocketAddr> = None;
191 discovered.ok_or_else(|| {
192 io::Error::new(
193 io::ErrorKind::NotFound,
194 format!("peer {peer} has no registered shuffle address"),
195 )
196 })?
197 };
198
199 let stream = TcpStream::connect(addr).await?;
201 stream.set_nodelay(true)?;
202 let (mut reader_half, mut writer_half) = tokio::io::split(stream);
203 write_message(&mut writer_half, &ShuffleMessage::Hello(self.local_id)).await?;
204
205 let alive = Arc::new(AtomicBool::new(true));
212 let alive_for_reader = Arc::clone(&alive);
213 let reader = tokio::spawn(async move {
214 let _ = peer;
215 loop {
216 match read_message(&mut reader_half).await {
217 Ok(ShuffleMessage::Close(_)) | Err(_) => break,
218 Ok(_) => {}
219 }
220 }
221 alive_for_reader.store(false, Ordering::Release);
222 });
223
224 let conn = Arc::new(ShuffleConnection {
225 writer: Mutex::new(writer_half),
226 reader,
227 alive,
228 });
229
230 let mut pool = self.pool.write().await;
235 if let Some(winner) = pool.get(&peer).cloned() {
236 if winner.is_alive() {
237 return Ok(winner);
238 }
239 }
240 pool.insert(peer, Arc::clone(&conn));
241 Ok(conn)
242 }
243}
244
245pub struct ShuffleReceiver {
250 local_id: ShufflePeerId,
251 local_addr: SocketAddr,
252 accept: JoinHandle<()>,
253 peer_tasks: Arc<parking_lot::Mutex<Vec<JoinHandle<()>>>>,
258 rx: Mutex<mpsc::UnboundedReceiver<(ShufflePeerId, ShuffleMessage)>>,
259}
260
261impl Drop for ShuffleReceiver {
262 fn drop(&mut self) {
263 self.accept.abort();
268 for h in self.peer_tasks.lock().drain(..) {
269 h.abort();
270 }
271 }
272}
273
274impl std::fmt::Debug for ShuffleReceiver {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 f.debug_struct("ShuffleReceiver")
277 .field("local_id", &self.local_id)
278 .field("local_addr", &self.local_addr)
279 .finish_non_exhaustive()
280 }
281}
282
283impl ShuffleReceiver {
284 pub async fn bind(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
291 Self::bind_impl(local_id, addr).await
292 }
293
294 #[cfg(feature = "cluster-unstable")]
301 pub async fn bind_with_kv(
302 local_id: ShufflePeerId,
303 addr: SocketAddr,
304 kv: Arc<dyn ClusterKv>,
305 ) -> io::Result<Self> {
306 let recv = Self::bind_impl(local_id, addr).await?;
307 kv.write(SHUFFLE_ADDR_KEY, recv.local_addr.to_string())
308 .await;
309 Ok(recv)
310 }
311
312 async fn bind_impl(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
313 let listener = TcpListener::bind(addr).await?;
314 let local_addr = listener.local_addr()?;
315 let (tx, rx) = mpsc::unbounded_channel();
316
317 let peer_tasks: Arc<parking_lot::Mutex<Vec<JoinHandle<()>>>> =
318 Arc::new(parking_lot::Mutex::new(Vec::new()));
319 let accept = tokio::spawn(Self::accept_loop(listener, tx, Arc::clone(&peer_tasks)));
320
321 Ok(Self {
322 local_id,
323 local_addr,
324 accept,
325 peer_tasks,
326 rx: Mutex::new(rx),
327 })
328 }
329
330 #[must_use]
332 pub fn local_addr(&self) -> SocketAddr {
333 self.local_addr
334 }
335
336 pub async fn recv(&self) -> Option<(ShufflePeerId, ShuffleMessage)> {
338 self.rx.lock().await.recv().await
339 }
340
341 pub fn drain_available(&self) -> Vec<(ShufflePeerId, ShuffleMessage)> {
350 let Ok(mut guard) = self.rx.try_lock() else {
351 return Vec::new();
352 };
353 let mut out = Vec::new();
354 while let Ok(msg) = guard.try_recv() {
355 out.push(msg);
356 }
357 out
358 }
359
360 async fn accept_loop(
361 listener: TcpListener,
362 tx: mpsc::UnboundedSender<(ShufflePeerId, ShuffleMessage)>,
363 peer_tasks: Arc<parking_lot::Mutex<Vec<JoinHandle<()>>>>,
364 ) {
365 loop {
366 let Ok((stream, _peer_addr)) = listener.accept().await else {
367 break;
368 };
369 if stream.set_nodelay(true).is_err() {
370 continue;
371 }
372 let tx = tx.clone();
373 let handle = tokio::spawn(Self::per_peer_loop(stream, tx));
374 let mut tasks = peer_tasks.lock();
378 tasks.retain(|h| !h.is_finished());
379 tasks.push(handle);
380 }
381 }
382
383 async fn per_peer_loop(
384 stream: TcpStream,
385 tx: mpsc::UnboundedSender<(ShufflePeerId, ShuffleMessage)>,
386 ) {
387 let (mut reader_half, _writer_half) = tokio::io::split(stream);
388 let Ok(ShuffleMessage::Hello(peer)) = read_message(&mut reader_half).await else {
390 return;
391 };
392 loop {
393 match read_message(&mut reader_half).await {
394 Ok(ShuffleMessage::Close(_)) | Err(_) => break,
395 Ok(msg) => {
396 if tx.send((peer, msg)).is_err() {
397 break;
398 }
399 }
400 }
401 }
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::{Duration, Instant};

    /// Binds a receiver for `local_id` on an OS-assigned loopback port.
    async fn bind_on_loopback(local_id: ShufflePeerId) -> ShuffleReceiver {
        let any_port: SocketAddr = "127.0.0.1:0".parse().unwrap();
        ShuffleReceiver::bind(local_id, any_port)
            .await
            .expect("bind")
    }

    /// A message sent by peer 1 arrives at peer 2 tagged with sender id 1.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn sender_to_receiver_delivers_with_peer_attribution() {
        let receiver = bind_on_loopback(2).await;
        let sender = ShuffleSender::new(1);
        sender.register_peer(2, receiver.local_addr()).await;

        let payload = ShuffleMessage::Hello(1234);
        sender.send_to(2, &payload).await.unwrap();

        let (origin, delivered) = receiver.recv().await.unwrap();
        assert_eq!(origin, 1, "receiver attributes frame to sender id");
        assert_eq!(delivered, payload);
    }

    /// Several sends to one peer arrive in order and leave exactly one pooled
    /// connection behind.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn sender_reuses_connection_across_sends() {
        let receiver = bind_on_loopback(2).await;
        let sender = ShuffleSender::new(1);
        sender.register_peer(2, receiver.local_addr()).await;

        let values = [10u64, 20, 30, 40];
        for value in values {
            sender
                .send_to(2, &ShuffleMessage::Hello(value))
                .await
                .unwrap();
        }

        let mut got = Vec::with_capacity(values.len());
        while got.len() < values.len() {
            let (_, frame) = receiver.recv().await.unwrap();
            got.push(frame);
        }

        let expected: Vec<ShuffleMessage> =
            values.into_iter().map(ShuffleMessage::Hello).collect();
        assert_eq!(got, expected);
        assert_eq!(sender.pool.read().await.len(), 1);
    }

    /// Discovery returns the address seeded in the KV and caches it.
    #[cfg(feature = "cluster-unstable")]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn discover_peer_reads_registered_address_from_kv() {
        use crate::cluster::control::{ClusterKv, InMemoryKv};
        use crate::cluster::discovery::NodeId;

        const ADVERTISED: &str = "127.0.0.1:54321";

        let kv = Arc::new(InMemoryKv::new(NodeId(1)));
        kv.seed(NodeId(2), SHUFFLE_ADDR_KEY, ADVERTISED.into());
        let sender = ShuffleSender::with_kv(1, kv as Arc<dyn ClusterKv>);

        let expected: SocketAddr = ADVERTISED.parse().unwrap();
        let found = sender.discover_peer(2).await.expect("peer found");
        assert_eq!(found, expected);

        let cached = sender.peers.read().await.get(&2).copied();
        assert_eq!(cached, Some(expected));
    }

    /// Sending to a peer with no known address fails with `NotFound`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_to_unregistered_peer_errors() {
        let sender = ShuffleSender::new(1);
        let outcome = sender.send_to(99, &ShuffleMessage::Hello(1)).await;
        assert_eq!(outcome.unwrap_err().kind(), io::ErrorKind::NotFound);
    }

    /// After the peer goes away and comes back on a new port, the sender
    /// notices the dead connection and reconnects to the re-registered address.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_reconnects_after_peer_restart_at_new_address() {
        let first_incarnation = bind_on_loopback(2).await;
        let first_addr = first_incarnation.local_addr();

        let sender = ShuffleSender::new(1);
        sender.register_peer(2, first_addr).await;

        sender
            .send_to(2, &ShuffleMessage::Hello(111))
            .await
            .unwrap();
        let (origin, frame) = first_incarnation.recv().await.unwrap();
        assert_eq!(origin, 1);
        assert_eq!(frame, ShuffleMessage::Hello(111));

        {
            let pool = sender.pool.read().await;
            assert_eq!(pool.len(), 1, "one pooled connection");
            assert!(pool.get(&2).unwrap().is_alive(), "alive after first send");
        }

        // Simulate the peer going away.
        drop(first_incarnation);

        // Poll until the sender's reader task marks the connection dead.
        let give_up_at = Instant::now() + Duration::from_secs(5);
        loop {
            let still_alive = match sender.pool.read().await.get(&2) {
                Some(conn) => conn.is_alive(),
                None => true,
            };
            if !still_alive {
                break;
            }
            assert!(
                Instant::now() < give_up_at,
                "reader task did not flip alive=false within 5s",
            );
            tokio::time::sleep(Duration::from_millis(25)).await;
        }

        let second_incarnation = bind_on_loopback(2).await;
        let second_addr = second_incarnation.local_addr();
        assert_ne!(
            first_addr, second_addr,
            "ephemeral rebind must pick a different port",
        );
        sender.register_peer(2, second_addr).await;

        sender
            .send_to(2, &ShuffleMessage::Hello(222))
            .await
            .expect("reconnect after restart");

        let (origin, frame) = second_incarnation.recv().await.unwrap();
        assert_eq!(origin, 1);
        assert_eq!(
            frame,
            ShuffleMessage::Hello(222),
            "delivered to the restarted peer, not the dead one",
        );

        let pool = sender.pool.read().await;
        assert_eq!(pool.len(), 1);
        assert!(pool.get(&2).unwrap().is_alive());
    }
}