Skip to main content

laminar_core/shuffle/
transport.rs

1//! Cross-node shuffle transport over Tonic gRPC client-streaming.
2//!
3//! Each sender opens a client-streaming `Shuffle` RPC per peer and pushes a
4//! forward-only stream of `ShuffleFrame`s; the receiver runs the
5//! `ShuffleTransport` service, attributes every stream to the peer announced in
6//! its leading `Hello`, and surfaces decoded [`ShuffleMessage`]s on a bounded
7//! crossfire MPSC queue. Backpressure is the HTTP/2 flow-control window plus
8//! that bounded queue. See [`super::message`] for the per-frame payloads and
9//! [`crate::serialization`] for the Arrow IPC (de)serialization of `VnodeData`.
10//!
11//! The real gRPC path is compiled under the `cluster` feature (which
12//! pulls in `tonic`/`prost`). A default build keeps the same public API via a
13//! networking-free shim so the types referenced by `laminar-db`/`laminar-server`
14//! signatures still compile without the cluster dependencies.
15
16use super::message::ShuffleMessage;
17use crate::checkpoint::barrier::CheckpointBarrier;
18
/// Bounded capacity for the inbound shuffle queue. One consumer per
/// [`ShuffleReceiver`] (the cluster repartition dispatcher) drains it; a slow
/// consumer parks the per-stream service handler on the bounded `send`, so
/// backpressure flows back over HTTP/2 to the sender.
///
/// The value is the argument handed to the bounded queue constructor for
/// `(peer, message)` items in `ShuffleReceiver::bind`.
const SHUFFLE_RECV_QUEUE: usize = 1024;
24
/// Peer-local identifier on the wire. Matches `cluster::discovery::NodeId`'s
/// inner type for seamless conversion.
///
/// This is the value a sender announces as `Hello.node_id` and the value every
/// received message is attributed to.
pub type ShufflePeerId = u64;
28
/// Gossip KV key used by [`ShuffleReceiver::bind_with_kv`] to publish the
/// listener's socket address, and by [`ShuffleSender`] to discover peer
/// addresses on first contact. Value: the bound socket address formatted via
/// `SocketAddr::to_string()`.
///
/// Readers parse the value back into a `SocketAddr`; an unparseable value is
/// treated the same as a missing entry.
#[cfg(feature = "cluster")]
pub const SHUFFLE_ADDR_KEY: &str = "shuffle:addr";
35
// Wire types and client/server stubs for the `laminar.shuffle.v1` proto
// package, pulled in through `tonic::include_proto!`. The lints below are
// allowed because the included code is generated rather than hand-written, so
// it cannot be adjusted to satisfy them.
#[cfg(feature = "cluster")]
#[allow(
    clippy::doc_markdown,
    clippy::default_trait_access,
    clippy::missing_const_for_fn,
    clippy::must_use_candidate,
    clippy::too_many_lines,
    missing_docs
)]
pub(crate) mod shuffle_v1 {
    tonic::include_proto!("laminar.shuffle.v1");
}
48
49// ---------------------------------------------------------------------------
50// Per-stage / per-barrier holdover shared by both builds.
51// ---------------------------------------------------------------------------
52
/// Inbound-side holdover state lifted out of [`ShuffleReceiver`] so both the
/// gRPC and default builds share the staging semantics that barrier alignment
/// depends on: frames pulled for another stage are bucketed for that stage's own
/// drainer, and barriers pulled mid-cycle are stashed (never dropped) for the
/// aligning checkpoint.
///
/// Each field has its own lock, so staging batches and stashing barriers do not
/// contend on a single mutex.
#[derive(Default)]
struct Holdover {
    // Batches waiting for their stage's drainer, keyed by stage name.
    staged: parking_lot::Mutex<rustc_hash::FxHashMap<String, Vec<arrow_array::RecordBatch>>>,
    // Barriers pulled off the inbound queue, paired with the peer that sent them.
    staged_barriers: parking_lot::Mutex<Vec<(ShufflePeerId, CheckpointBarrier)>>,
}
63
64// ===========================================================================
65// gRPC implementation (cluster).
66// ===========================================================================
67
68#[cfg(feature = "cluster")]
69mod grpc {
70    use std::collections::hash_map::Entry;
71    use std::io;
72    use std::net::SocketAddr;
73    use std::sync::atomic::{AtomicBool, Ordering};
74    use std::sync::Arc;
75
76    use arrow_array::RecordBatch;
77    use crossfire::{mpsc, AsyncRx, MAsyncTx};
78    use futures::StreamExt as _;
79    use parking_lot::Mutex;
80    use rustc_hash::FxHashMap;
81    use tokio::task::JoinHandle;
82    use tonic::transport::{Channel, Server};
83    use tonic::Request;
84
85    use super::shuffle_v1::shuffle_frame;
86    use super::shuffle_v1::shuffle_transport_client::ShuffleTransportClient;
87    use super::shuffle_v1::shuffle_transport_server::{ShuffleTransport, ShuffleTransportServer};
88    use super::shuffle_v1::{Barrier, Close, Hello, ShuffleFrame, ShuffleSummary, VnodeData};
89    use super::{Holdover, ShuffleMessage, ShufflePeerId, SHUFFLE_ADDR_KEY, SHUFFLE_RECV_QUEUE};
90    use crate::checkpoint::barrier::CheckpointBarrier;
91    use crate::cluster::control::ClusterKv;
92    use crate::serialization::{BatchStreamDecoder, BatchStreamEncoder};
93
    /// Outbound queue capacity per peer. Bounds per-peer buffering before the
    /// HTTP/2 window applies its own backpressure.
    ///
    /// The queue holds un-serialized `ShuffleMessage`s; once it is full,
    /// `send_to` waits on the bounded `send` until the driver drains an item.
    const SHUFFLE_SEND_QUEUE: usize = 1024;
97
    /// Consumer end of the inbound queue of `(sending peer, message)` items
    /// (kept as a `type` so the parked-behind-mutex receiver field doesn't trip
    /// clippy's `type_complexity`).
    type InboundRx = AsyncRx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
    /// Producer end of the same queue; cloned once per incoming peer stream.
    type InboundTx = MAsyncTx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
102
103    /// Map a `tonic::Status` / `tonic::transport::Error` (or any `Display`) into
104    /// `io::Error` so the public API keeps its `io::Result` shape.
105    fn io_err<E: std::fmt::Display>(e: E) -> io::Error {
106        io::Error::other(e.to_string())
107    }
108
109    /// Encode a [`ShuffleMessage`] into the wire [`ShuffleFrame`]. The per-stage
110    /// [`BatchStreamEncoder`] writes the Arrow schema only on a stage's first
111    /// `VnodeData`; later batches are schema-less. Runs in the connection driver
112    /// task, off the Ring 0 compute thread.
113    fn encode_message(
114        msg: &ShuffleMessage,
115        encoders: &mut FxHashMap<String, BatchStreamEncoder>,
116    ) -> Result<ShuffleFrame, tonic::Status> {
117        let kind = match msg {
118            ShuffleMessage::Hello(node_id) => {
119                shuffle_frame::Kind::Hello(Hello { node_id: *node_id })
120            }
121            ShuffleMessage::Barrier(b) => shuffle_frame::Kind::Barrier(Barrier {
122                checkpoint_id: b.checkpoint_id,
123                epoch: b.epoch,
124                flags: b.flags,
125            }),
126            ShuffleMessage::VnodeData(stage, vnode, batch) => {
127                let encoder = match encoders.entry(stage.clone()) {
128                    Entry::Occupied(e) => {
129                        let enc = e.into_mut();
130                        // Fail loudly rather than desync the peer's IPC decoder.
131                        let schema = batch.schema();
132                        if !Arc::ptr_eq(enc.schema(), &schema) && *enc.schema() != schema {
133                            return Err(tonic::Status::internal(format!(
134                                "shuffle stage '{stage}' changed schema mid-connection",
135                            )));
136                        }
137                        enc
138                    }
139                    Entry::Vacant(v) => {
140                        v.insert(BatchStreamEncoder::new(&batch.schema()).map_err(|e| {
141                            tonic::Status::internal(format!("shuffle ipc encoder init: {e}"))
142                        })?)
143                    }
144                };
145                let arrow_ipc = encoder
146                    .encode(batch)
147                    .map_err(|e| tonic::Status::internal(format!("shuffle ipc encode: {e}")))?;
148                shuffle_frame::Kind::VnodeData(VnodeData {
149                    stage: stage.clone(),
150                    vnode: *vnode,
151                    arrow_ipc,
152                })
153            }
154            ShuffleMessage::Close(reason) => shuffle_frame::Kind::Close(Close {
155                reason: reason.clone(),
156            }),
157        };
158        Ok(ShuffleFrame { kind: Some(kind) })
159    }
160
    /// One lazily-opened client-streaming call to a peer. The driver task pulls
    /// [`ShuffleMessage`]s from `tx`'s queue, serializes them to wire frames, and
    /// feeds the gRPC request stream; it flips `alive=false` on the first transport
    /// error (or connect failure), so the next `send_to` purges this entry and
    /// reconnects. Buffering messages (not frames) keeps the CPU-heavy Arrow IPC
    /// serialization off the caller's (Ring 0 compute) thread.
    struct PeerConn {
        // Producer end of this peer's bounded outbound message queue.
        tx: MAsyncTx<mpsc::Array<ShuffleMessage>>,
        // Shared with the driver task, which clears it when the call is finished.
        alive: Arc<AtomicBool>,
        // Handle to the driver task; aborted when this connection is dropped.
        driver: JoinHandle<()>,
    }
172
    impl PeerConn {
        /// `true` until the driver task has stored `false` into `alive`, which it
        /// does after a connect failure or once the streaming call returns.
        fn is_alive(&self) -> bool {
            // Acquire pairs with the driver's Release store.
            self.alive.load(Ordering::Acquire)
        }
    }
178
    /// Dropping a connection aborts its driver task, so an entry evicted from
    /// the pool (or a freshly opened call that lost the insertion race) does not
    /// leave a task running in the background.
    impl Drop for PeerConn {
        fn drop(&mut self) {
            self.driver.abort();
        }
    }
184
    /// Lazy pool of outbound client-streaming calls, keyed by peer id.
    pub struct ShuffleSender {
        // This node's id; sent as the leading `Hello` of every call it opens.
        local_id: ShufflePeerId,
        // Known shuffle addresses, filled by `register_peer` and by KV discovery.
        peers: Mutex<FxHashMap<ShufflePeerId, SocketAddr>>,
        // Currently pooled calls; dead entries are evicted on the next lookup.
        pool: Mutex<FxHashMap<ShufflePeerId, Arc<PeerConn>>>,
        // Optional KV used to look up a peer's published address.
        kv: Option<Arc<dyn ClusterKv>>,
    }
192
193    impl std::fmt::Debug for ShuffleSender {
194        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195            f.debug_struct("ShuffleSender")
196                .field("local_id", &self.local_id)
197                .finish_non_exhaustive()
198        }
199    }
200
201    impl ShuffleSender {
202        /// Empty sender. Peers are added via [`Self::register_peer`] or discovered
203        /// via the KV (in [`Self::with_kv`]) before any `send_to`.
204        #[must_use]
205        pub fn new(local_id: ShufflePeerId) -> Self {
206            Self {
207                local_id,
208                peers: Mutex::new(FxHashMap::default()),
209                pool: Mutex::new(FxHashMap::default()),
210                kv: None,
211            }
212        }
213
214        /// Sender that falls back to `kv` (key [`SHUFFLE_ADDR_KEY`] on the peer's
215        /// own state) when `send_to` targets a peer not previously registered.
216        #[must_use]
217        pub fn with_kv(local_id: ShufflePeerId, kv: Arc<dyn ClusterKv>) -> Self {
218            let mut s = Self::new(local_id);
219            s.kv = Some(kv);
220            s
221        }
222
223        /// Register (or update) a peer's shuffle address.
224        // Body is sync, but the signature stays async to match the contract
225        // callers `.await`.
226        #[allow(clippy::unused_async)]
227        pub async fn register_peer(&self, peer: ShufflePeerId, addr: SocketAddr) {
228            self.peers.lock().insert(peer, addr);
229        }
230
231        /// Send `msg` to `peer`, opening a client-streaming call if necessary.
232        ///
233        /// # Errors
234        /// Returns `io::Error` when the peer is unregistered/undiscoverable, the
235        /// endpoint cannot be built, or the per-peer stream has shut down.
236        pub async fn send_to(&self, peer: ShufflePeerId, msg: &ShuffleMessage) -> io::Result<()> {
237            let conn = self.connection_for(peer).await?;
238            // The clone is cheap (`RecordBatch` is an Arc bump); the driver
239            // task serializes to Arrow IPC off this thread.
240            conn.tx.send(msg.clone()).await.map_err(|_| {
241                io::Error::new(
242                    io::ErrorKind::BrokenPipe,
243                    format!("shuffle stream to peer {peer} closed"),
244                )
245            })
246        }
247
248        /// Ship `barrier` to every peer in order, short-circuiting on the first
249        /// failure (the gossip side-channel is authoritative, so a partial
250        /// fan-out is tolerable).
251        ///
252        /// # Errors
253        /// Returns the first `io::Error` from any peer's `send_to`.
254        pub async fn fan_out_barrier(
255            &self,
256            peers: &[ShufflePeerId],
257            barrier: CheckpointBarrier,
258        ) -> io::Result<()> {
259            let msg = ShuffleMessage::Barrier(barrier);
260            for &peer in peers {
261                self.send_to(peer, &msg).await?;
262            }
263            Ok(())
264        }
265
266        /// Resolve `peer`'s address from the KV (`SHUFFLE_ADDR_KEY` on the peer's
267        /// own state) and cache it. `None` when no KV, no entry, or unparseable.
268        async fn discover_peer(&self, peer: ShufflePeerId) -> Option<SocketAddr> {
269            let kv = self.kv.as_ref()?;
270            let raw = kv
271                .read_from(crate::cluster::discovery::NodeId(peer), SHUFFLE_ADDR_KEY)
272                .await?;
273            let addr: SocketAddr = raw.parse().ok()?;
274            self.peers.lock().insert(peer, addr);
275            Some(addr)
276        }
277
278        async fn connection_for(&self, peer: ShufflePeerId) -> io::Result<Arc<PeerConn>> {
279            if let Some(existing) = self.pool.lock().get(&peer).cloned() {
280                if existing.is_alive() {
281                    return Ok(existing);
282                }
283            }
284            // Purge a dead entry so we reopen the call below.
285            self.pool.lock().retain(|p, c| *p != peer || c.is_alive());
286
287            // Re-resolve on reconnect (peers may restart on a new port); fall
288            // back to a statically registered address when there's no KV.
289            let addr = match self.discover_peer(peer).await {
290                Some(addr) => addr,
291                None => self.peers.lock().get(&peer).copied().ok_or_else(|| {
292                    io::Error::new(
293                        io::ErrorKind::NotFound,
294                        format!("peer {peer} has no registered shuffle address"),
295                    )
296                })?,
297            };
298
299            let conn = Arc::new(open_call(self.local_id, addr)?);
300
301            // Race: another task may have opened a live call meanwhile.
302            let mut pool = self.pool.lock();
303            if let Some(winner) = pool.get(&peer).cloned() {
304                if winner.is_alive() {
305                    return Ok(winner);
306                }
307            }
308            pool.insert(peer, Arc::clone(&conn));
309            Ok(conn)
310        }
311    }
312
313    /// Open a client-streaming `Shuffle` call to `addr`, sending `Hello(local_id)`
314    /// as the first frame. Connecting happens inside the driver task so this stays
315    /// non-blocking; a connect failure flips `alive` so the next `send_to` retries.
316    fn open_call(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<PeerConn> {
317        let endpoint = crate::cluster::control::tls::client_endpoint(&addr.to_string())
318            .map_err(io_err)?
319            .tcp_nodelay(true);
320        let (tx, rx) = mpsc::bounded_async::<ShuffleMessage>(SHUFFLE_SEND_QUEUE);
321        let alive = Arc::new(AtomicBool::new(true));
322        let alive_for_driver = Arc::clone(&alive);
323
324        // Request stream: a single `Hello` chained onto an unfold over the
325        // per-peer crossfire receiver (no `async-stream` dependency needed).
326        // The unfold serializes dequeued messages here, in the driver task.
327        let hello = ShuffleFrame {
328            kind: Some(shuffle_frame::Kind::Hello(Hello { node_id: local_id })),
329        };
330        let encoders: FxHashMap<String, BatchStreamEncoder> = FxHashMap::default();
331        let outbound = futures::stream::once(async move { hello }).chain(futures::stream::unfold(
332            (rx, encoders),
333            |(rx, mut encoders)| async move {
334                let msg = rx.recv().await.ok()?;
335                match encode_message(&msg, &mut encoders) {
336                    Ok(frame) => Some((frame, (rx, encoders))),
337                    Err(e) => {
338                        // An unencodable batch would desync the stage's IPC
339                        // stream; half-close so the peer reconnects fresh.
340                        tracing::warn!(error = %e, "shuffle frame encode failed; closing stream");
341                        None
342                    }
343                }
344            },
345        ));
346
347        let driver = tokio::spawn(async move {
348            let Ok(channel) = endpoint.connect().await else {
349                alive_for_driver.store(false, Ordering::Release);
350                return;
351            };
352            let mut client = ShuffleTransportClient::<Channel>::new(channel);
353            // The call returns when the server responds to half-close or the
354            // transport breaks; either way the peer connection is finished.
355            let _ = client.shuffle(Request::new(outbound)).await;
356            alive_for_driver.store(false, Ordering::Release);
357        });
358
359        Ok(PeerConn { tx, alive, driver })
360    }
361
    /// Inbound side of the shuffle fabric: a Tonic `ShuffleTransport` server
    /// surfacing every received frame, attributed to its sending peer, on the
    /// bounded crossfire queue.
    pub struct ShuffleReceiver {
        // This node's id; only used for `Debug` output in the code visible here.
        local_id: ShufflePeerId,
        // Address the listener is actually bound to (ephemeral port resolved).
        local_addr: SocketAddr,
        // crossfire's `AsyncRx` is `Send` but `!Sync` (it holds a
        // `PhantomData<Cell<()>>`), yet `Arc<ShuffleReceiver>` must be `Sync` — it
        // is embedded in DataFusion's `ClusterRepartitionExec`, whose
        // `ExecutionPlan` impl requires `Send + Sync`. Park the receiver behind a
        // `Mutex<Option<_>>` (which is `Sync` for any `Send` inner) and hand it out
        // via a take/return guard so the single async consumer never holds the
        // guard across `.await`. `rx_returned` wakes the next waiter; the guard
        // restores the receiver on drop so a cancelled `recv` can't strand it.
        rx: Mutex<Option<InboundRx>>,
        // Signalled each time the receiver is put back into `rx`.
        rx_returned: Arc<tokio::sync::Notify>,
        // Handle to the task running the gRPC server; aborted on drop.
        server: JoinHandle<()>,
        // Per-stage batches and stashed barriers held between drains.
        holdover: Arc<Holdover>,
    }
381
    /// Dropping the receiver stops its server task.
    impl Drop for ShuffleReceiver {
        fn drop(&mut self) {
            // Abort the server task so the listener closes and in-flight peer
            // streams break — senders then observe the error and reconnect.
            self.server.abort();
        }
    }
389
390    impl std::fmt::Debug for ShuffleReceiver {
391        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392            f.debug_struct("ShuffleReceiver")
393                .field("local_id", &self.local_id)
394                .field("local_addr", &self.local_addr)
395                .finish_non_exhaustive()
396        }
397    }
398
399    impl ShuffleReceiver {
400        /// Bind on `addr` and start serving. The bound address (with any ephemeral
401        /// port resolved) is exposed via [`Self::local_addr`].
402        ///
403        /// # Errors
404        /// Returns `io::Error` on bind failure.
405        pub async fn bind(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
406            let listener = tokio::net::TcpListener::bind(addr).await?;
407            let local_addr = listener.local_addr()?;
408            let (tx, rx) =
409                mpsc::bounded_async::<(ShufflePeerId, ShuffleMessage)>(SHUFFLE_RECV_QUEUE);
410
411            let service = ShuffleService { tx };
412            // Accept loop as a stream of `Result<TcpStream, io::Error>` for
413            // `serve_with_incoming` — avoids the tokio-stream `net` feature.
414            // nodelay is set per accepted connection.
415            let incoming = futures::stream::unfold(listener, |listener| async move {
416                let item = match listener.accept().await {
417                    Ok((stream, _)) => {
418                        let _ = stream.set_nodelay(true);
419                        Ok(stream)
420                    }
421                    Err(e) => Err(e),
422                };
423                Some((item, listener))
424            });
425            // Apply TLS synchronously so a bad cert fails bind() rather than
426            // silently never serving.
427            let mut builder = Server::builder();
428            if let Some(tls) = crate::cluster::control::tls::server_tls() {
429                builder = builder
430                    .tls_config(tls.clone())
431                    .map_err(|e| io::Error::other(format!("cluster shuffle TLS config: {e}")))?;
432            }
433            let router = builder.add_service(ShuffleTransportServer::new(service));
434            let server = tokio::spawn(async move {
435                let _ = router.serve_with_incoming(incoming).await;
436            });
437
438            Ok(Self {
439                local_id,
440                local_addr,
441                rx: Mutex::new(Some(rx)),
442                rx_returned: Arc::new(tokio::sync::Notify::new()),
443                server,
444                holdover: Arc::new(Holdover::default()),
445            })
446        }
447
448        /// Bind and publish the listener's address into `kv` under
449        /// [`SHUFFLE_ADDR_KEY`] for peer discovery.
450        ///
451        /// # Errors
452        /// Returns `io::Error` on bind failure.
453        pub async fn bind_with_kv(
454            local_id: ShufflePeerId,
455            addr: SocketAddr,
456            kv: Arc<dyn ClusterKv>,
457        ) -> io::Result<Self> {
458            let recv = Self::bind(local_id, addr).await?;
459            kv.write(SHUFFLE_ADDR_KEY, recv.local_addr.to_string())
460                .await;
461            Ok(recv)
462        }
463
        /// Local socket address the server is bound to, as reported by the
        /// listener after binding (so an ephemeral port request shows the port
        /// actually assigned).
        #[must_use]
        pub fn local_addr(&self) -> SocketAddr {
            self.local_addr
        }
469
470        /// Await the next `(peer_id, msg)`. `None` once the server task has stopped
471        /// and every queued item is drained. Single-owner; concurrent callers
472        /// serialise via `rx_returned`. Cancellation-safe — a dropped `recv()`
473        /// future returns the receiver to its slot via the RAII guard.
474        pub async fn recv(&self) -> Option<(ShufflePeerId, ShuffleMessage)> {
475            loop {
476                // Take the receiver out under a short lock dropped before `.await`.
477                let taken = { self.rx.lock().take() };
478                let Some(rx) = taken else {
479                    self.rx_returned.notified().await;
480                    continue;
481                };
482                let mut guard = RxReturnGuard {
483                    slot: &self.rx,
484                    notify: &self.rx_returned,
485                    rx: Some(rx),
486                };
487                let rx = guard.rx.as_mut()?;
488                return rx.recv().await.ok();
489            }
490        }
491
492        /// Drain every currently-available `(peer_id, msg)` without blocking. Empty
493        /// when the queue is empty or a `recv()` currently holds the receiver.
494        #[must_use]
495        pub fn drain_available(&self) -> Vec<(ShufflePeerId, ShuffleMessage)> {
496            let mut out = Vec::new();
497            let slot = self.rx.lock();
498            if let Some(rx) = slot.as_ref() {
499                while let Ok(item) = rx.try_recv() {
500                    out.push(item);
501                }
502            }
503            out
504        }
505
506        /// Drain the inbound queue into `staged`: bucket `VnodeData` by stage,
507        /// stash `Barrier`s for the aligning checkpoint (never dropped — see
508        /// `Holdover`), discard `Hello`/`Close`.
509        fn drain_inbound_into(&self, staged: &mut FxHashMap<String, Vec<RecordBatch>>) {
510            let slot = self.rx.lock();
511            if let Some(rx) = slot.as_ref() {
512                while let Ok((from, msg)) = rx.try_recv() {
513                    match msg {
514                        ShuffleMessage::VnodeData(s, _vnode, batch) => {
515                            staged.entry(s).or_default().push(batch);
516                        }
517                        ShuffleMessage::Barrier(b) => {
518                            self.holdover.staged_barriers.lock().push((from, b));
519                        }
520                        _ => {} // Hello / Close
521                    }
522                }
523            }
524        }
525
526        /// Non-blocking drain of the [`ShuffleMessage::VnodeData`] batches for
527        /// `stage`; other stages stay bucketed for their own drainer. Empty if the
528        /// queue is empty or a `recv()` holds it.
529        #[must_use]
530        pub fn drain_vnode_data_for(&self, stage: &str) -> Vec<RecordBatch> {
531            let mut staged = self.holdover.staged.lock();
532            self.drain_inbound_into(&mut staged);
533            staged.remove(stage).unwrap_or_default()
534        }
535
536        /// Single lock-cycle drain of every staged stage whose key starts with
537        /// `prefix`, lifting those out and leaving operator stages untouched. Lets
538        /// the subscription router pull all `__sub::` batches in one pass.
539        #[must_use]
540        pub fn drain_staged_with_prefix(
541            &self,
542            prefix: &str,
543        ) -> FxHashMap<String, Vec<RecordBatch>> {
544            let mut staged = self.holdover.staged.lock();
545            self.drain_inbound_into(&mut staged);
546            let mut out: FxHashMap<String, Vec<RecordBatch>> = FxHashMap::default();
547            staged.retain(|stage, batches| {
548                if stage.starts_with(prefix) {
549                    out.insert(stage.clone(), std::mem::take(batches));
550                    false
551                } else {
552                    true
553                }
554            });
555            out
556        }
557
558        /// Stage `batch` under `stage` for a later [`Self::drain_vnode_data_for`] /
559        /// [`Self::drain_all_staged`] — used when no operator for `stage` exists yet
560        /// at drain time.
561        pub fn stage_batch(&self, stage: String, batch: RecordBatch) {
562            self.holdover
563                .staged
564                .lock()
565                .entry(stage)
566                .or_default()
567                .push(batch);
568        }
569
570        /// Take the barriers stashed by [`Self::drain_vnode_data_for`] (peers that
571        /// fanned out before this node began aligning).
572        #[must_use]
573        pub fn drain_staged_barriers(&self) -> Vec<(ShufflePeerId, CheckpointBarrier)> {
574            std::mem::take(&mut self.holdover.staged_barriers.lock())
575        }
576
577        /// Empty the per-stage holdover, returning every buffered `(stage, batch)`.
578        #[must_use]
579        pub fn drain_all_staged(&self) -> Vec<(String, RecordBatch)> {
580            let mut staged = self.holdover.staged.lock();
581            staged
582                .drain()
583                .flat_map(|(stage, batches)| batches.into_iter().map(move |b| (stage.clone(), b)))
584                .collect()
585        }
586    }
587
    /// Returns the receiver to the slot on drop so a cancelled `recv()` future
    /// doesn't strand it; wakes the next parked waiter.
    struct RxReturnGuard<'a> {
        // Slot the receiver was taken from and is restored into.
        slot: &'a Mutex<Option<InboundRx>>,
        // Notified after the receiver has been restored.
        notify: &'a tokio::sync::Notify,
        // The borrowed-out receiver; `None` only after it has been restored.
        rx: Option<InboundRx>,
    }
595
    impl Drop for RxReturnGuard<'_> {
        fn drop(&mut self) {
            if let Some(rx) = self.rx.take() {
                // Restore first, then notify, so a woken waiter finds the
                // receiver already back in the slot.
                *self.slot.lock() = Some(rx);
                // notify_one stores a permit; notify_waiters can lose wakeups.
                self.notify.notify_one();
            }
        }
    }
605
    /// The `ShuffleTransport` service object: holds the producer end of the inbound
    /// queue shared by every peer stream.
    struct ShuffleService {
        // Cloned for each incoming `shuffle` call.
        tx: InboundTx,
    }
611
612    #[tonic::async_trait]
613    impl ShuffleTransport for ShuffleService {
614        async fn shuffle(
615            &self,
616            request: Request<tonic::Streaming<ShuffleFrame>>,
617        ) -> Result<tonic::Response<ShuffleSummary>, tonic::Status> {
618            let summary = run_stream(self.tx.clone(), request.into_inner()).await?;
619            Ok(tonic::Response::new(summary))
620        }
621    }
622
623    /// Read the leading `Hello`, then forward each decoded frame onto the bounded
624    /// inbound queue, returning a summary when the client half-closes. `VnodeData`
625    /// is decoded with per-stage [`BatchStreamDecoder`]s mirroring the sender's
626    /// per-stage encoders (schema on a stage's first chunk only).
627    async fn run_stream(
628        tx: InboundTx,
629        mut stream: tonic::Streaming<ShuffleFrame>,
630    ) -> Result<ShuffleSummary, tonic::Status> {
631        let first = stream
632            .message()
633            .await?
634            .ok_or_else(|| tonic::Status::invalid_argument("shuffle stream closed before Hello"))?;
635        let peer = match first.kind {
636            Some(shuffle_frame::Kind::Hello(h)) => h.node_id,
637            _ => {
638                return Err(tonic::Status::invalid_argument(
639                    "first shuffle frame must be Hello",
640                ))
641            }
642        };
643
644        let mut decoders: FxHashMap<String, BatchStreamDecoder> = FxHashMap::default();
645        let mut frames_received = 0u64;
646        while let Some(frame) = stream.message().await? {
647            let kind = frame
648                .kind
649                .ok_or_else(|| tonic::Status::invalid_argument("empty shuffle frame"))?;
650            match kind {
651                shuffle_frame::Kind::Close(_) => break,
652                shuffle_frame::Kind::Hello(h) => {
653                    frames_received += 1;
654                    if tx
655                        .send((peer, ShuffleMessage::Hello(h.node_id)))
656                        .await
657                        .is_err()
658                    {
659                        break;
660                    }
661                }
662                shuffle_frame::Kind::Barrier(b) => {
663                    frames_received += 1;
664                    let msg = ShuffleMessage::Barrier(CheckpointBarrier {
665                        checkpoint_id: b.checkpoint_id,
666                        epoch: b.epoch,
667                        flags: b.flags,
668                    });
669                    if tx.send((peer, msg)).await.is_err() {
670                        break;
671                    }
672                }
673                shuffle_frame::Kind::VnodeData(v) => {
674                    frames_received += 1;
675                    let batches = decoders
676                        .entry(v.stage.clone())
677                        .or_default()
678                        .decode_chunk(v.arrow_ipc)
679                        .map_err(|e| {
680                            tonic::Status::invalid_argument(format!("shuffle ipc: {e}"))
681                        })?;
682                    let mut stream_broken = false;
683                    for batch in batches {
684                        if !forward_vnode_batch(&tx, peer, &v.stage, v.vnode, batch).await? {
685                            stream_broken = true;
686                            break;
687                        }
688                    }
689                    if stream_broken {
690                        break;
691                    }
692                }
693            }
694        }
695        Ok(ShuffleSummary { frames_received })
696    }
697
698    /// Forward one decoded `stage` batch onto the inbound queue. If the batch
699    /// carries the `__laminar_vnode` metadata column, split it per vnode and emit
700    /// a slice each; otherwise emit it whole under `default_vnode`. Returns
701    /// `Ok(false)` when the inbound queue has closed (the caller stops reading).
702    async fn forward_vnode_batch(
703        tx: &InboundTx,
704        peer: ShufflePeerId,
705        stage: &str,
706        default_vnode: u32,
707        batch: RecordBatch,
708    ) -> Result<bool, tonic::Status> {
709        let schema = batch.schema();
710        let Some((col_idx, _field)) = schema.column_with_name("__laminar_vnode") else {
711            let msg = ShuffleMessage::VnodeData(stage.to_string(), default_vnode, batch);
712            return Ok(tx.send((peer, msg)).await.is_ok());
713        };
714
715        let vnode_array = batch
716            .column(col_idx)
717            .as_any()
718            .downcast_ref::<arrow_array::UInt32Array>()
719            .ok_or_else(|| {
720                tonic::Status::invalid_argument("vnode metadata column is not UInt32Array")
721            })?;
722        let row_vnodes: Vec<u32> = vnode_array.values().to_vec();
723
724        let mut projection: Vec<usize> = (0..schema.fields().len()).collect();
725        projection.remove(col_idx);
726        let batch_without_vnode = batch.project(&projection).map_err(|e| {
727            tonic::Status::internal(format!("Failed to project out vnode metadata: {e}"))
728        })?;
729
730        let slices =
731            crate::shuffle::routing::slice_batch_by_vnodes(&batch_without_vnode, &row_vnodes);
732        for (v, slice) in slices {
733            let sub_msg = ShuffleMessage::VnodeData(stage.to_string(), v, slice);
734            if tx.send((peer, sub_msg)).await.is_err() {
735                return Ok(false);
736            }
737        }
738        Ok(true)
739    }
740
741    #[cfg(test)]
742    mod encode_tests {
743        use super::*;
744        use arrow_array::Int64Array;
745        use arrow_schema::{DataType, Field, Schema};
746
747        #[test]
748        fn schema_change_on_a_stage_is_rejected() {
749            let batch = |name: &str| {
750                let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int64, false)]));
751                arrow_array::RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1]))])
752                    .unwrap()
753            };
754            let mut encoders = FxHashMap::default();
755            let msg = ShuffleMessage::VnodeData("s".into(), 0, batch("a"));
756            encode_message(&msg, &mut encoders).unwrap();
757
758            let changed = ShuffleMessage::VnodeData("s".into(), 0, batch("b"));
759            let err = encode_message(&changed, &mut encoders).unwrap_err();
760            assert!(err.message().contains("changed schema"), "{err}");
761
762            // A different stage with its own schema is fine.
763            let other = ShuffleMessage::VnodeData("t".into(), 0, batch("b"));
764            encode_message(&other, &mut encoders).unwrap();
765        }
766    }
767}
768
769#[cfg(feature = "cluster")]
770pub use grpc::{ShuffleReceiver, ShuffleSender};
771
772// ===========================================================================
773// Default build: networking-free shim preserving the public API.
774//
775// The cluster shuffle is only exercised under `cluster`; a default
776// build references these types only in signatures. The shim keeps the inbound
777// crossfire queue + holdover staging so the surface compiles and behaves sanely
778// (local-only) without pulling in tonic.
779// ===========================================================================
780
781#[cfg(not(feature = "cluster"))]
782mod shim {
783    use std::io;
784    use std::net::SocketAddr;
785    use std::sync::Arc;
786
787    use arrow_array::RecordBatch;
788    use crossfire::{mpsc, AsyncRx, MAsyncTx};
789    use parking_lot::Mutex;
790    use rustc_hash::FxHashMap;
791
792    use super::{Holdover, ShuffleMessage, ShufflePeerId, SHUFFLE_RECV_QUEUE};
793    use crate::checkpoint::barrier::CheckpointBarrier;
794
795    type InboundRx = AsyncRx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
796    type InboundTx = MAsyncTx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
797
798    /// Outbound shuffle handle. Without the cluster feature there is no peer
799    /// fabric, so sends to a non-local peer report the peer as unregistered.
800    pub struct ShuffleSender {
801        local_id: ShufflePeerId,
802        peers: Mutex<FxHashMap<ShufflePeerId, SocketAddr>>,
803    }
804
805    impl std::fmt::Debug for ShuffleSender {
806        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
807            f.debug_struct("ShuffleSender")
808                .field("local_id", &self.local_id)
809                .finish_non_exhaustive()
810        }
811    }
812
813    impl ShuffleSender {
814        /// Empty sender (no peer fabric without the cluster feature).
815        #[must_use]
816        pub fn new(local_id: ShufflePeerId) -> Self {
817            Self {
818                local_id,
819                peers: Mutex::new(FxHashMap::default()),
820            }
821        }
822
823        /// Register (or update) a peer's shuffle address.
824        #[allow(clippy::unused_async)] // async to match the cluster build's API.
825        pub async fn register_peer(&self, peer: ShufflePeerId, addr: SocketAddr) {
826            self.peers.lock().insert(peer, addr);
827        }
828
829        /// # Errors
830        /// Errors for an unregistered peer; the no-cluster build has no transport,
831        /// so registered peers are accepted as a no-op delivery.
832        #[allow(clippy::unused_async)] // async to match the cluster build's API.
833        pub async fn send_to(&self, peer: ShufflePeerId, _msg: &ShuffleMessage) -> io::Result<()> {
834            if self.peers.lock().contains_key(&peer) {
835                Ok(())
836            } else {
837                Err(io::Error::new(
838                    io::ErrorKind::NotFound,
839                    format!("peer {peer} has no registered shuffle address"),
840                ))
841            }
842        }
843
844        /// # Errors
845        /// Returns the first `io::Error` from any peer's `send_to`.
846        pub async fn fan_out_barrier(
847            &self,
848            peers: &[ShufflePeerId],
849            barrier: CheckpointBarrier,
850        ) -> io::Result<()> {
851            let msg = ShuffleMessage::Barrier(barrier);
852            for &peer in peers {
853                self.send_to(peer, &msg).await?;
854            }
855            Ok(())
856        }
857    }
858
859    /// Inbound shuffle handle. Holds the bounded crossfire queue + holdover so the
860    /// drain/stage API compiles and behaves locally without a network. The
861    /// receiver is parked behind a `Mutex<Option<_>>` for the same `Sync` reason as
862    /// the gRPC build.
863    pub struct ShuffleReceiver {
864        local_id: ShufflePeerId,
865        local_addr: SocketAddr,
866        #[allow(dead_code)]
867        tx: InboundTx,
868        rx: Mutex<Option<InboundRx>>,
869        rx_returned: Arc<tokio::sync::Notify>,
870        holdover: Arc<Holdover>,
871    }
872
873    impl std::fmt::Debug for ShuffleReceiver {
874        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875            f.debug_struct("ShuffleReceiver")
876                .field("local_id", &self.local_id)
877                .field("local_addr", &self.local_addr)
878                .finish_non_exhaustive()
879        }
880    }
881
882    impl ShuffleReceiver {
883        /// # Errors
884        /// Returns `io::Error` on bind failure.
885        pub async fn bind(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
886            // Resolve the address (incl. ephemeral port) by binding momentarily.
887            let listener = tokio::net::TcpListener::bind(addr).await?;
888            let local_addr = listener.local_addr()?;
889            drop(listener);
890            let (tx, rx) =
891                mpsc::bounded_async::<(ShufflePeerId, ShuffleMessage)>(SHUFFLE_RECV_QUEUE);
892            Ok(Self {
893                local_id,
894                local_addr,
895                tx,
896                rx: Mutex::new(Some(rx)),
897                rx_returned: Arc::new(tokio::sync::Notify::new()),
898                holdover: Arc::new(Holdover::default()),
899            })
900        }
901
902        /// Local socket address resolved at bind time.
903        #[must_use]
904        pub fn local_addr(&self) -> SocketAddr {
905            self.local_addr
906        }
907
908        /// Await the next `(peer_id, msg)`. `None` once all senders drop.
909        pub async fn recv(&self) -> Option<(ShufflePeerId, ShuffleMessage)> {
910            loop {
911                let taken = { self.rx.lock().take() };
912                let Some(rx) = taken else {
913                    self.rx_returned.notified().await;
914                    continue;
915                };
916                let mut guard = RxReturnGuard {
917                    slot: &self.rx,
918                    notify: &self.rx_returned,
919                    rx: Some(rx),
920                };
921                let rx = guard.rx.as_mut()?;
922                return rx.recv().await.ok();
923            }
924        }
925
926        /// Drain every currently-available `(peer_id, msg)` without blocking.
927        #[must_use]
928        pub fn drain_available(&self) -> Vec<(ShufflePeerId, ShuffleMessage)> {
929            let mut out = Vec::new();
930            let slot = self.rx.lock();
931            if let Some(rx) = slot.as_ref() {
932                while let Ok(item) = rx.try_recv() {
933                    out.push(item);
934                }
935            }
936            out
937        }
938
939        /// Non-blocking drain of the `VnodeData` batches for `stage`; other-stage
940        /// frames are bucketed and barriers stashed (never dropped).
941        #[must_use]
942        pub fn drain_vnode_data_for(&self, stage: &str) -> Vec<RecordBatch> {
943            let mut staged = self.holdover.staged.lock();
944            {
945                let slot = self.rx.lock();
946                if let Some(rx) = slot.as_ref() {
947                    while let Ok((from, msg)) = rx.try_recv() {
948                        match msg {
949                            ShuffleMessage::VnodeData(s, _vnode, batch) => {
950                                staged.entry(s).or_default().push(batch);
951                            }
952                            ShuffleMessage::Barrier(b) => {
953                                self.holdover.staged_barriers.lock().push((from, b));
954                            }
955                            _ => {}
956                        }
957                    }
958                }
959            }
960            staged.remove(stage).unwrap_or_default()
961        }
962
963        /// Single lock-cycle drain of every staged stage whose key starts with
964        /// `prefix`; other-stage frames are bucketed and barriers stashed (never
965        /// dropped), matching [`Self::drain_vnode_data_for`]. Operator stages are
966        /// left in `staged`; only the matching entries are returned.
967        #[must_use]
968        pub fn drain_staged_with_prefix(
969            &self,
970            prefix: &str,
971        ) -> FxHashMap<String, Vec<RecordBatch>> {
972            let mut staged = self.holdover.staged.lock();
973            {
974                let slot = self.rx.lock();
975                if let Some(rx) = slot.as_ref() {
976                    while let Ok((from, msg)) = rx.try_recv() {
977                        match msg {
978                            ShuffleMessage::VnodeData(s, _vnode, batch) => {
979                                staged.entry(s).or_default().push(batch);
980                            }
981                            ShuffleMessage::Barrier(b) => {
982                                self.holdover.staged_barriers.lock().push((from, b));
983                            }
984                            _ => {}
985                        }
986                    }
987                }
988            }
989            let mut out: FxHashMap<String, Vec<RecordBatch>> = FxHashMap::default();
990            staged.retain(|stage, batches| {
991                if stage.starts_with(prefix) {
992                    out.insert(stage.clone(), std::mem::take(batches));
993                    false
994                } else {
995                    true
996                }
997            });
998            out
999        }
1000
1001        /// Stage `batch` under `stage` for a later drain.
1002        pub fn stage_batch(&self, stage: String, batch: RecordBatch) {
1003            self.holdover
1004                .staged
1005                .lock()
1006                .entry(stage)
1007                .or_default()
1008                .push(batch);
1009        }
1010
1011        /// Take the barriers stashed by [`Self::drain_vnode_data_for`].
1012        #[must_use]
1013        pub fn drain_staged_barriers(&self) -> Vec<(ShufflePeerId, CheckpointBarrier)> {
1014            std::mem::take(&mut self.holdover.staged_barriers.lock())
1015        }
1016
1017        /// Empty the per-stage holdover, returning every buffered `(stage, batch)`.
1018        #[must_use]
1019        pub fn drain_all_staged(&self) -> Vec<(String, RecordBatch)> {
1020            let mut staged = self.holdover.staged.lock();
1021            staged
1022                .drain()
1023                .flat_map(|(stage, batches)| batches.into_iter().map(move |b| (stage.clone(), b)))
1024                .collect()
1025        }
1026    }
1027
1028    /// Returns the receiver to the slot on drop so a cancelled `recv()` future
1029    /// doesn't strand it; wakes the next parked waiter.
1030    struct RxReturnGuard<'a> {
1031        slot: &'a Mutex<Option<InboundRx>>,
1032        notify: &'a tokio::sync::Notify,
1033        rx: Option<InboundRx>,
1034    }
1035
1036    impl Drop for RxReturnGuard<'_> {
1037        fn drop(&mut self) {
1038            if let Some(rx) = self.rx.take() {
1039                *self.slot.lock() = Some(rx);
1040                self.notify.notify_one();
1041            }
1042        }
1043    }
1044}
1045
1046#[cfg(not(feature = "cluster"))]
1047pub use shim::{ShuffleReceiver, ShuffleSender};
1048
1049#[cfg(all(test, feature = "cluster"))]
1050mod tests {
1051    use std::io;
1052    use std::sync::Arc;
1053
1054    use super::*;
1055
1056    async fn bind_on_loopback(local_id: ShufflePeerId) -> ShuffleReceiver {
1057        ShuffleReceiver::bind(local_id, "127.0.0.1:0".parse().unwrap())
1058            .await
1059            .expect("bind")
1060    }
1061
1062    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1063    async fn sender_to_receiver_delivers_with_peer_attribution() {
1064        let recv = bind_on_loopback(2).await;
1065        let recv_addr = recv.local_addr();
1066
1067        let sender = ShuffleSender::new(1);
1068        sender.register_peer(2, recv_addr).await;
1069        sender
1070            .send_to(2, &ShuffleMessage::Hello(1234))
1071            .await
1072            .unwrap();
1073
1074        let (from, msg) = recv.recv().await.unwrap();
1075        assert_eq!(from, 1, "receiver attributes frame to sender id");
1076        assert_eq!(msg, ShuffleMessage::Hello(1234));
1077    }
1078
1079    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1080    async fn sender_reuses_stream_across_sends() {
1081        let recv = bind_on_loopback(2).await;
1082        let sender = ShuffleSender::new(1);
1083        sender.register_peer(2, recv.local_addr()).await;
1084
1085        for delta in [10u64, 20, 30, 40] {
1086            sender
1087                .send_to(2, &ShuffleMessage::Hello(delta))
1088                .await
1089                .unwrap();
1090        }
1091
1092        let mut got = Vec::new();
1093        for _ in 0..4 {
1094            got.push(recv.recv().await.unwrap().1);
1095        }
1096        assert_eq!(
1097            got,
1098            vec![
1099                ShuffleMessage::Hello(10),
1100                ShuffleMessage::Hello(20),
1101                ShuffleMessage::Hello(30),
1102                ShuffleMessage::Hello(40),
1103            ]
1104        );
1105    }
1106
1107    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1108    async fn send_to_unregistered_peer_errors() {
1109        let sender = ShuffleSender::new(1);
1110        let err = sender
1111            .send_to(99, &ShuffleMessage::Hello(1))
1112            .await
1113            .unwrap_err();
1114        assert_eq!(err.kind(), io::ErrorKind::NotFound);
1115    }
1116
1117    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1118    async fn send_discovers_peer_address_from_kv() {
1119        use crate::cluster::control::{ClusterKv, InMemoryKv};
1120        use crate::cluster::discovery::NodeId;
1121
1122        // Peer 2 binds for real; its address is seeded into peer 1's KV so the
1123        // KV-backed sender resolves it on first send without an explicit
1124        // `register_peer`. End-to-end delivery proves the discovery glue.
1125        let recv = bind_on_loopback(2).await;
1126        let kv = Arc::new(InMemoryKv::new(NodeId(1)));
1127        kv.seed(NodeId(2), SHUFFLE_ADDR_KEY, recv.local_addr().to_string());
1128        let sender = ShuffleSender::with_kv(1, kv as Arc<dyn ClusterKv>);
1129
1130        sender.send_to(2, &ShuffleMessage::Hello(7)).await.unwrap();
1131        let (from, msg) = recv.recv().await.unwrap();
1132        assert_eq!(from, 1);
1133        assert_eq!(msg, ShuffleMessage::Hello(7));
1134    }
1135
1136    /// A peer restarting at a new address: the cached stream breaks, the next
1137    /// `send_to` reconnects against the freshly-registered address. Windows-only
1138    /// skip — the FIN-after-abort wakeup chain is not time-bounded under nextest
1139    /// parallelism there.
1140    #[cfg(not(windows))]
1141    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1142    async fn send_reconnects_after_peer_restart_at_new_address() {
1143        let recv_v1 = bind_on_loopback(2).await;
1144        let addr_v1 = recv_v1.local_addr();
1145
1146        let sender = ShuffleSender::new(1);
1147        sender.register_peer(2, addr_v1).await;
1148        sender
1149            .send_to(2, &ShuffleMessage::Hello(111))
1150            .await
1151            .unwrap();
1152        let (from, msg) = recv_v1.recv().await.unwrap();
1153        assert_eq!(from, 1);
1154        assert_eq!(msg, ShuffleMessage::Hello(111));
1155
1156        // Crash the peer.
1157        drop(recv_v1);
1158
1159        // Peer restarts on a fresh ephemeral port.
1160        let recv_v2 = bind_on_loopback(2).await;
1161        let addr_v2 = recv_v2.local_addr();
1162        assert_ne!(addr_v1, addr_v2, "ephemeral rebind must pick a new port");
1163        sender.register_peer(2, addr_v2).await;
1164
1165        // Reconnect + deliver to the restarted peer. Retry to absorb the time it
1166        // takes the old stream to flip dead after the server aborted.
1167        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
1168        loop {
1169            let _ = sender.send_to(2, &ShuffleMessage::Hello(222)).await;
1170            if let Some((from, ShuffleMessage::Hello(222))) =
1171                tokio::time::timeout(std::time::Duration::from_millis(200), recv_v2.recv())
1172                    .await
1173                    .ok()
1174                    .flatten()
1175            {
1176                assert_eq!(from, 1);
1177                return;
1178            }
1179            assert!(
1180                std::time::Instant::now() < deadline,
1181                "did not deliver to restarted peer within 30s",
1182            );
1183        }
1184    }
1185
1186    /// `drain_staged_with_prefix` lifts `__sub::` stages in one pass while
1187    /// leaving operator stages staged for their own drainer.
1188    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1189    async fn drain_staged_with_prefix_lifts_subs_and_keeps_operator_stages() {
1190        use arrow_array::{Int64Array, RecordBatch};
1191        use arrow_schema::{DataType, Field, Schema};
1192        use rustc_hash::FxHashMap;
1193
1194        use crate::checkpoint::barrier::CheckpointBarrier;
1195
1196        fn batch(values: Vec<i64>) -> RecordBatch {
1197            let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
1198            RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap()
1199        }
1200        fn col(b: &RecordBatch) -> Vec<i64> {
1201            b.column(0)
1202                .as_any()
1203                .downcast_ref::<Int64Array>()
1204                .unwrap()
1205                .values()
1206                .to_vec()
1207        }
1208
1209        let recv = bind_on_loopback(2).await;
1210        let sender = ShuffleSender::new(1);
1211        sender.register_peer(2, recv.local_addr()).await;
1212
1213        // FIFO over one stream: two subscription stages, one operator stage, then
1214        // a trailing barrier. Once the barrier is observed, every prior frame has
1215        // been received and bucketed.
1216        for (stage, vals) in [
1217            ("__sub::alpha", vec![1, 2, 3]),
1218            ("__sub::beta", vec![4, 5, 6]),
1219            ("op_stage", vec![7, 8, 9]),
1220        ] {
1221            sender
1222                .send_to(2, &ShuffleMessage::VnodeData(stage.into(), 0, batch(vals)))
1223                .await
1224                .unwrap();
1225        }
1226        sender
1227            .send_to(
1228                2,
1229                &ShuffleMessage::Barrier(CheckpointBarrier {
1230                    checkpoint_id: 7,
1231                    epoch: 3,
1232                    flags: 0,
1233                }),
1234            )
1235            .await
1236            .unwrap();
1237
1238        // Poll the single-lock-cycle drain until both sub stages and the trailing
1239        // barrier have arrived (loopback is near-instant; 2s is a wide margin).
1240        let mut subs: FxHashMap<String, Vec<RecordBatch>> = FxHashMap::default();
1241        let mut barriers = Vec::new();
1242        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
1243        while subs.len() < 2 || barriers.is_empty() {
1244            for (k, v) in recv.drain_staged_with_prefix("__sub::") {
1245                subs.entry(k).or_default().extend(v);
1246            }
1247            barriers.extend(recv.drain_staged_barriers());
1248            assert!(
1249                std::time::Instant::now() < deadline,
1250                "frames not delivered within 2s",
1251            );
1252            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
1253        }
1254
1255        // Both subscription stages lifted, with their batches intact.
1256        assert_eq!(subs.len(), 2, "only the two __sub:: stages are returned");
1257        assert_eq!(col(&subs["__sub::alpha"][0]), vec![1, 2, 3]);
1258        assert_eq!(col(&subs["__sub::beta"][0]), vec![4, 5, 6]);
1259
1260        // The barrier was stashed, not dropped, and attributed to its sender.
1261        assert_eq!(barriers.len(), 1);
1262        assert_eq!(barriers[0].0, 1, "barrier attributed to sender peer 1");
1263        assert_eq!(barriers[0].1.checkpoint_id, 7);
1264
1265        // The operator stage was left intact for its own drainer.
1266        let op = recv.drain_vnode_data_for("op_stage");
1267        assert_eq!(op.len(), 1);
1268        assert_eq!(col(&op[0]), vec![7, 8, 9]);
1269
1270        // A second prefix drain finds nothing new.
1271        assert!(recv.drain_staged_with_prefix("__sub::").is_empty());
1272    }
1273}