// laminar_core/shuffle/transport.rs

1//! TCP shuffle: a per-peer connection pool for senders, an accept loop
2//! for receivers. Each frame carries a node id in its handshake so the
3//! receiver can attribute incoming traffic. See
4//! [`super::message`] for the wire format.
5
6use std::io;
7use std::net::SocketAddr;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10
11use rustc_hash::FxHashMap;
12use tokio::net::{TcpListener, TcpStream};
13use tokio::sync::{mpsc, Mutex, RwLock};
14use tokio::task::JoinHandle;
15
16use super::message::{read_message, write_message, ShuffleMessage};
17use crate::checkpoint::barrier::CheckpointBarrier;
18
19#[cfg(feature = "cluster-unstable")]
20use crate::cluster::control::ClusterKv;
21
/// Gossip KV key used by [`ShuffleReceiver::bind_with_kv`] to publish
/// the listener's socket address, and by [`ShuffleSender`] to discover
/// peer addresses on first contact. Value: the bound socket address
/// formatted via `SocketAddr::to_string()` (parsed back with
/// `str::parse::<SocketAddr>()` on the reading side).
#[cfg(feature = "cluster-unstable")]
pub const SHUFFLE_ADDR_KEY: &str = "shuffle:addr";
28
/// Peer-local identifier on the wire. Matches
/// `cluster::discovery::NodeId`'s inner type for seamless conversion.
pub type ShufflePeerId = u64;
32
/// One active TCP connection in the shuffle fabric. Internal to the
/// transport — callers hold a [`ShuffleSender`] or [`ShuffleReceiver`].
struct ShuffleConnection {
    /// Write half. Parked behind a mutex so multiple operators can
    /// share one connection without interleaving frames.
    writer: Mutex<tokio::io::WriteHalf<TcpStream>>,
    /// The reader task (owns the read half). Kept so dropping the
    /// connection cancels it.
    reader: JoinHandle<()>,
    /// Liveness flag shared with the reader task. Flipped to `false`
    /// when the reader exits (peer closed the socket cleanly OR an
    /// IO error hit). [`ShuffleSender::connection_for`] consults it
    /// before handing out the cached `Arc`; dead entries are purged
    /// and the next send reconnects.
    alive: Arc<AtomicBool>,
}
48
49impl ShuffleConnection {
50    async fn send(&self, msg: &ShuffleMessage) -> io::Result<()> {
51        let mut w = self.writer.lock().await;
52        write_message(&mut *w, msg).await
53    }
54
55    fn is_alive(&self) -> bool {
56        self.alive.load(Ordering::Acquire)
57    }
58}
59
impl Drop for ShuffleConnection {
    fn drop(&mut self) {
        // Cancel the background reader task; without this, an evicted
        // pool entry would leak a task holding the socket's read half.
        self.reader.abort();
    }
}
65
/// Lazy pool of outbound connections, keyed by peer id. Addresses go
/// in via `register_peer` (manual) or via the KV on first send.
pub struct ShuffleSender {
    /// Our id, sent as the `Hello` handshake so receivers can
    /// attribute incoming frames.
    local_id: ShufflePeerId,
    /// Peer id → shuffle listener address. Populated by
    /// `register_peer` and (with the feature) `discover_peer`.
    peers: RwLock<FxHashMap<ShufflePeerId, SocketAddr>>,
    /// Peer id → pooled connection (possibly dead; `connection_for`
    /// checks liveness before reuse).
    pool: RwLock<FxHashMap<ShufflePeerId, Arc<ShuffleConnection>>>,
    /// Optional cluster KV consulted for peers with no registered
    /// address.
    #[cfg(feature = "cluster-unstable")]
    kv: Option<Arc<dyn ClusterKv>>,
}
75
76impl std::fmt::Debug for ShuffleSender {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct("ShuffleSender")
79            .field("local_id", &self.local_id)
80            .finish_non_exhaustive()
81    }
82}
83
84impl ShuffleSender {
85    /// Empty sender. Peers must be added via `register_peer` or
86    /// discovered via the KV in `with_kv` before any `send_to`.
87    #[must_use]
88    pub fn new(local_id: ShufflePeerId) -> Self {
89        Self {
90            local_id,
91            peers: RwLock::new(FxHashMap::default()),
92            pool: RwLock::new(FxHashMap::default()),
93            #[cfg(feature = "cluster-unstable")]
94            kv: None,
95        }
96    }
97
98    /// Construct a sender that falls back to `kv` when `send_to` is
99    /// called for a peer not previously registered. The KV is read
100    /// from the peer's own state at [`SHUFFLE_ADDR_KEY`].
101    #[cfg(feature = "cluster-unstable")]
102    #[must_use]
103    pub fn with_kv(local_id: ShufflePeerId, kv: Arc<dyn ClusterKv>) -> Self {
104        let mut s = Self::new(local_id);
105        s.kv = Some(kv);
106        s
107    }
108
109    /// Register (or update) a peer's shuffle address. Must be called
110    /// before `send_to(peer, ..)`.
111    pub async fn register_peer(&self, peer: ShufflePeerId, addr: SocketAddr) {
112        self.peers.write().await.insert(peer, addr);
113    }
114
115    /// Send `msg` to `peer`, opening a connection if necessary.
116    ///
117    /// # Errors
118    /// Returns `io::Error` on connect failure, peer-unregistered, or
119    /// frame write failure.
120    pub async fn send_to(&self, peer: ShufflePeerId, msg: &ShuffleMessage) -> io::Result<()> {
121        let conn = self.connection_for(peer).await?;
122        conn.send(msg).await
123    }
124
125    /// Ship `barrier` to every peer in order. Short-circuits on the
126    /// first send failure; the gossip side-channel is the
127    /// authoritative announcement so a partial fan-out is tolerable.
128    ///
129    /// # Errors
130    /// Returns the first `io::Error` from any peer's `send_to`.
131    pub async fn fan_out_barrier(
132        &self,
133        peers: &[ShufflePeerId],
134        barrier: CheckpointBarrier,
135    ) -> io::Result<()> {
136        let msg = ShuffleMessage::Barrier(barrier);
137        for &peer in peers {
138            self.send_to(peer, &msg).await?;
139        }
140        Ok(())
141    }
142
143    /// Look up `peer`'s shuffle address from the cluster KV and
144    /// register it on success. Returns `None` when no KV is attached,
145    /// the peer has no entry yet, or the entry can't be parsed.
146    #[cfg(feature = "cluster-unstable")]
147    async fn discover_peer(&self, peer: ShufflePeerId) -> Option<SocketAddr> {
148        let kv = self.kv.as_ref()?;
149        let raw = kv
150            .read_from(crate::cluster::discovery::NodeId(peer), SHUFFLE_ADDR_KEY)
151            .await?;
152        let addr: SocketAddr = raw.parse().ok()?;
153        self.peers.write().await.insert(peer, addr);
154        Some(addr)
155    }
156
157    async fn connection_for(&self, peer: ShufflePeerId) -> io::Result<Arc<ShuffleConnection>> {
158        // Fast path: hand out the cached connection if its reader task
159        // is still alive. A dead connection is one whose reader has
160        // exited (peer closed or IO error) — the socket's write half is
161        // useless even though the `Arc` still exists.
162        if let Some(existing) = self.pool.read().await.get(&peer).cloned() {
163            if existing.is_alive() {
164                return Ok(existing);
165            }
166        }
167
168        // Purge the dead pool entry so we reconnect below. We do NOT
169        // touch `peers` here: the caller (gossip watcher / explicit
170        // `register_peer`) owns the address mapping. If the cached
171        // address is stale, `TcpStream::connect` will surface that
172        // as a connect error and the caller retries after re-register;
173        // if callers use KV discovery, a missing `peers` entry triggers
174        // `discover_peer` in the next block.
175        {
176            let mut pool = self.pool.write().await;
177            if let Some(c) = pool.get(&peer) {
178                if !c.is_alive() {
179                    pool.remove(&peer);
180                }
181            }
182        }
183
184        let addr = if let Some(a) = self.peers.read().await.get(&peer).copied() {
185            a
186        } else {
187            #[cfg(feature = "cluster-unstable")]
188            let discovered = self.discover_peer(peer).await;
189            #[cfg(not(feature = "cluster-unstable"))]
190            let discovered: Option<SocketAddr> = None;
191            discovered.ok_or_else(|| {
192                io::Error::new(
193                    io::ErrorKind::NotFound,
194                    format!("peer {peer} has no registered shuffle address"),
195                )
196            })?
197        };
198
199        // Open + handshake without holding the pool write lock.
200        let stream = TcpStream::connect(addr).await?;
201        stream.set_nodelay(true)?;
202        let (mut reader_half, mut writer_half) = tokio::io::split(stream);
203        write_message(&mut writer_half, &ShuffleMessage::Hello(self.local_id)).await?;
204
205        // Outbound connection's read half: drain frames until the peer
206        // closes or the socket errors. Nothing currently acts on reply
207        // traffic on the outbound side (credit frames are handled by
208        // the inbound `ShuffleReceiver` on the other instance), so we
209        // just discard. Flip `alive` to `false` on exit so the next
210        // `connection_for` call evicts this entry and reconnects.
211        let alive = Arc::new(AtomicBool::new(true));
212        let alive_for_reader = Arc::clone(&alive);
213        let reader = tokio::spawn(async move {
214            let _ = peer;
215            loop {
216                match read_message(&mut reader_half).await {
217                    Ok(ShuffleMessage::Close(_)) | Err(_) => break,
218                    Ok(_) => {}
219                }
220            }
221            alive_for_reader.store(false, Ordering::Release);
222        });
223
224        let conn = Arc::new(ShuffleConnection {
225            writer: Mutex::new(writer_half),
226            reader,
227            alive,
228        });
229
230        // Race: another task may have created a connection in the
231        // meantime. Cheap to discard ours — but only if the winner is
232        // still alive. A dead winner (raced with an earlier reader
233        // exit) must be superseded.
234        let mut pool = self.pool.write().await;
235        if let Some(winner) = pool.get(&peer).cloned() {
236            if winner.is_alive() {
237                return Ok(winner);
238            }
239        }
240        pool.insert(peer, Arc::clone(&conn));
241        Ok(conn)
242    }
243}
244
/// Inbound side of the shuffle fabric.
///
/// Binds a `TcpListener` and surfaces every frame received from any
/// peer — prefixed with that peer's id — via [`Self::recv`] and
/// [`Self::drain_available`].
pub struct ShuffleReceiver {
    /// Our node id. Within this file it is only read by the `Debug`
    /// impl; kept for callers/diagnostics.
    local_id: ShufflePeerId,
    /// Actual bound address (resolves an ephemeral port 0 bind).
    local_addr: SocketAddr,
    /// The accept-loop task; aborted by [`Drop`].
    accept: JoinHandle<()>,
    /// Per-peer reader tasks spawned by the accept loop. Tracked so
    /// [`Drop`] can abort them — otherwise detached tasks keep the
    /// socket open, peers never see EOF, and senders can't detect
    /// that we went away.
    peer_tasks: Arc<parking_lot::Mutex<Vec<JoinHandle<()>>>>,
    /// Fan-in of all peers' frames, consumed by `recv` /
    /// `drain_available`.
    rx: Mutex<mpsc::UnboundedReceiver<(ShufflePeerId, ShuffleMessage)>>,
}
260
impl Drop for ShuffleReceiver {
    fn drop(&mut self) {
        // Abort accept first so no new peer tasks are spawned, then
        // abort any in-flight peer tasks. Aborting drops each socket,
        // which surfaces as EOF on the sender's reader half — exactly
        // what the stale-connection purge in `connection_for` relies on.
        // (parking_lot::Mutex is sync, so locking in Drop is fine.)
        self.accept.abort();
        for h in self.peer_tasks.lock().drain(..) {
            h.abort();
        }
    }
}
273
274impl std::fmt::Debug for ShuffleReceiver {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        f.debug_struct("ShuffleReceiver")
277            .field("local_id", &self.local_id)
278            .field("local_addr", &self.local_addr)
279            .finish_non_exhaustive()
280    }
281}
282
283impl ShuffleReceiver {
284    /// Bind on `addr` and start accepting peer connections. The bound
285    /// socket address is surfaced via [`Self::local_addr`] so callers
286    /// using an ephemeral port can register themselves with peers.
287    ///
288    /// # Errors
289    /// Returns `io::Error` on bind failure.
290    pub async fn bind(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
291        Self::bind_impl(local_id, addr).await
292    }
293
294    /// Bind + publish the listener's address into `kv` under
295    /// [`SHUFFLE_ADDR_KEY`] so that peer [`ShuffleSender`]s can
296    /// discover us via [`ShuffleSender::with_kv`].
297    ///
298    /// # Errors
299    /// Returns `io::Error` on bind failure.
300    #[cfg(feature = "cluster-unstable")]
301    pub async fn bind_with_kv(
302        local_id: ShufflePeerId,
303        addr: SocketAddr,
304        kv: Arc<dyn ClusterKv>,
305    ) -> io::Result<Self> {
306        let recv = Self::bind_impl(local_id, addr).await?;
307        kv.write(SHUFFLE_ADDR_KEY, recv.local_addr.to_string())
308            .await;
309        Ok(recv)
310    }
311
312    async fn bind_impl(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
313        let listener = TcpListener::bind(addr).await?;
314        let local_addr = listener.local_addr()?;
315        let (tx, rx) = mpsc::unbounded_channel();
316
317        let peer_tasks: Arc<parking_lot::Mutex<Vec<JoinHandle<()>>>> =
318            Arc::new(parking_lot::Mutex::new(Vec::new()));
319        let accept = tokio::spawn(Self::accept_loop(listener, tx, Arc::clone(&peer_tasks)));
320
321        Ok(Self {
322            local_id,
323            local_addr,
324            accept,
325            peer_tasks,
326            rx: Mutex::new(rx),
327        })
328    }
329
330    /// Local socket address the listener is bound to.
331    #[must_use]
332    pub fn local_addr(&self) -> SocketAddr {
333        self.local_addr
334    }
335
336    /// Await the next `(peer_id, msg)` from any connected peer.
337    pub async fn recv(&self) -> Option<(ShufflePeerId, ShuffleMessage)> {
338        self.rx.lock().await.recv().await
339    }
340
341    /// Drain every currently-available `(peer_id, msg)` without blocking.
342    /// Returns immediately when the internal queue is empty.
343    ///
344    /// Used by the row-shuffle aggregator path to pull remote rows into
345    /// the current streaming cycle without waiting for more. Uses
346    /// `tokio::sync::Mutex::try_lock` so a concurrent `recv()` doesn't
347    /// block us — we just skip this tick when contended (next tick picks
348    /// up the messages).
349    pub fn drain_available(&self) -> Vec<(ShufflePeerId, ShuffleMessage)> {
350        let Ok(mut guard) = self.rx.try_lock() else {
351            return Vec::new();
352        };
353        let mut out = Vec::new();
354        while let Ok(msg) = guard.try_recv() {
355            out.push(msg);
356        }
357        out
358    }
359
360    async fn accept_loop(
361        listener: TcpListener,
362        tx: mpsc::UnboundedSender<(ShufflePeerId, ShuffleMessage)>,
363        peer_tasks: Arc<parking_lot::Mutex<Vec<JoinHandle<()>>>>,
364    ) {
365        loop {
366            let Ok((stream, _peer_addr)) = listener.accept().await else {
367                break;
368            };
369            if stream.set_nodelay(true).is_err() {
370                continue;
371            }
372            let tx = tx.clone();
373            let handle = tokio::spawn(Self::per_peer_loop(stream, tx));
374            // Sweep finished tasks so the vec doesn't grow unbounded
375            // under a long-lived receiver with churning peers, then
376            // track the fresh one so Drop can abort it.
377            let mut tasks = peer_tasks.lock();
378            tasks.retain(|h| !h.is_finished());
379            tasks.push(handle);
380        }
381    }
382
383    async fn per_peer_loop(
384        stream: TcpStream,
385        tx: mpsc::UnboundedSender<(ShufflePeerId, ShuffleMessage)>,
386    ) {
387        let (mut reader_half, _writer_half) = tokio::io::split(stream);
388        // Expect Hello first. Anything else means the peer is broken.
389        let Ok(ShuffleMessage::Hello(peer)) = read_message(&mut reader_half).await else {
390            return;
391        };
392        loop {
393            match read_message(&mut reader_half).await {
394                Ok(ShuffleMessage::Close(_)) | Err(_) => break,
395                Ok(msg) => {
396                    if tx.send((peer, msg)).is_err() {
397                        break;
398                    }
399                }
400            }
401        }
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Bind a receiver on an ephemeral loopback port (`127.0.0.1:0`).
    async fn bind_on_loopback(local_id: ShufflePeerId) -> ShuffleReceiver {
        ShuffleReceiver::bind(local_id, "127.0.0.1:0".parse().unwrap())
            .await
            .expect("bind")
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn sender_to_receiver_delivers_with_peer_attribution() {
        let recv = bind_on_loopback(2).await;
        let recv_addr = recv.local_addr();

        let sender = ShuffleSender::new(1);
        sender.register_peer(2, recv_addr).await;
        sender
            .send_to(2, &ShuffleMessage::Hello(1234))
            .await
            .unwrap();

        // The attributed id comes from the handshake (sender id 1),
        // not from the Hello payload we sent as data (1234).
        let (from, msg) = recv.recv().await.unwrap();
        assert_eq!(from, 1, "receiver attributes frame to sender id");
        assert_eq!(msg, ShuffleMessage::Hello(1234));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn sender_reuses_connection_across_sends() {
        let recv = bind_on_loopback(2).await;
        let sender = ShuffleSender::new(1);
        sender.register_peer(2, recv.local_addr()).await;

        for delta in [10u64, 20, 30, 40] {
            sender
                .send_to(2, &ShuffleMessage::Hello(delta))
                .await
                .unwrap();
        }

        // TCP + a single shared connection ⇒ in-order delivery.
        let mut got = Vec::new();
        for _ in 0..4 {
            got.push(recv.recv().await.unwrap().1);
        }
        assert_eq!(
            got,
            vec![
                ShuffleMessage::Hello(10),
                ShuffleMessage::Hello(20),
                ShuffleMessage::Hello(30),
                ShuffleMessage::Hello(40),
            ]
        );
        // Pool holds exactly one connection to peer 2.
        assert_eq!(sender.pool.read().await.len(), 1);
    }

    #[cfg(feature = "cluster-unstable")]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn discover_peer_reads_registered_address_from_kv() {
        use crate::cluster::control::{ClusterKv, InMemoryKv};
        use crate::cluster::discovery::NodeId;

        // Seed node 2's address into node 1's local KV, then verify
        // `discover_peer` pulls it out and caches it. Covers the
        // discovery glue without involving real TCP.
        let kv = Arc::new(InMemoryKv::new(NodeId(1)));
        kv.seed(NodeId(2), SHUFFLE_ADDR_KEY, "127.0.0.1:54321".into());
        let sender = ShuffleSender::with_kv(1, kv as Arc<dyn ClusterKv>);

        let expected: SocketAddr = "127.0.0.1:54321".parse().unwrap();
        let addr = sender.discover_peer(2).await.expect("peer found");
        assert_eq!(addr, expected);
        assert_eq!(sender.peers.read().await.get(&2).copied(), Some(expected));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_to_unregistered_peer_errors() {
        let sender = ShuffleSender::new(1);
        let err = sender
            .send_to(99, &ShuffleMessage::Hello(1))
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
    }

    /// When a peer restarts at a different address, the sender's cached
    /// connection flips dead (reader exits on EOF), the next `send_to`
    /// purges the stale entry and reconnects against the
    /// freshly-registered address.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_reconnects_after_peer_restart_at_new_address() {
        // 1. Peer binds on an ephemeral port.
        let recv_v1 = bind_on_loopback(2).await;
        let addr_v1 = recv_v1.local_addr();

        let sender = ShuffleSender::new(1);
        sender.register_peer(2, addr_v1).await;

        // 2. First send establishes the pooled connection.
        sender
            .send_to(2, &ShuffleMessage::Hello(111))
            .await
            .unwrap();
        let (from, msg) = recv_v1.recv().await.unwrap();
        assert_eq!(from, 1);
        assert_eq!(msg, ShuffleMessage::Hello(111));

        // Pool has exactly the one connection, and it's alive.
        {
            let pool = sender.pool.read().await;
            assert_eq!(pool.len(), 1, "one pooled connection");
            assert!(pool.get(&2).unwrap().is_alive(), "alive after first send");
        }

        // 3. "Crash" the peer: drop the receiver so the listener and
        //    the per-peer task go away. The sender's reader half will
        //    observe EOF and flip `alive=false`.
        drop(recv_v1);

        // Spin until the reader task notices the shutdown. Bounded
        // wait so a hung test fails loudly instead of running forever.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        loop {
            let alive = {
                let pool = sender.pool.read().await;
                pool.get(&2).is_none_or(|c| c.is_alive())
            };
            if !alive {
                break;
            }
            assert!(
                std::time::Instant::now() < deadline,
                "reader task did not flip alive=false within 5s",
            );
            tokio::time::sleep(std::time::Duration::from_millis(25)).await;
        }

        // 4. Peer "restarts" on a fresh ephemeral port — almost
        //    certainly different from addr_v1. Re-register its new
        //    address on the sender (mimicking gossip discovery).
        let recv_v2 = bind_on_loopback(2).await;
        let addr_v2 = recv_v2.local_addr();
        assert_ne!(
            addr_v1, addr_v2,
            "ephemeral rebind must pick a different port",
        );
        sender.register_peer(2, addr_v2).await;

        // 5. `send_to` must purge the dead entry and reconnect against
        //    addr_v2. If the purge is broken, we'd either reuse the
        //    dead writer (io error) or dial the stale addr_v1.
        sender
            .send_to(2, &ShuffleMessage::Hello(222))
            .await
            .expect("reconnect after restart");

        let (from, msg) = recv_v2.recv().await.unwrap();
        assert_eq!(from, 1);
        assert_eq!(
            msg,
            ShuffleMessage::Hello(222),
            "delivered to the restarted peer, not the dead one",
        );

        // Pool still holds exactly one connection — the fresh one.
        let pool = sender.pool.read().await;
        assert_eq!(pool.len(), 1);
        assert!(pool.get(&2).unwrap().is_alive());
    }
}