1use super::message::ShuffleMessage;
17use crate::checkpoint::barrier::CheckpointBarrier;
18
/// Capacity of the bounded inbound queue that carries `(peer, message)` pairs
/// from the transport to `ShuffleReceiver`.
const SHUFFLE_RECV_QUEUE: usize = 1024;

/// Numeric identifier of a shuffle peer.
pub type ShufflePeerId = u64;

/// Key under which a receiver's listen address is written to, and looked up
/// from, the cluster key-value store.
#[cfg(feature = "cluster")]
pub const SHUFFLE_ADDR_KEY: &str = "shuffle:addr";
35
/// Generated bindings for the `laminar.shuffle.v1` protobuf package.
///
/// The lint allowances apply to the code pulled in by `tonic::include_proto!`.
#[cfg(feature = "cluster")]
#[allow(
    clippy::doc_markdown,
    clippy::default_trait_access,
    clippy::missing_const_for_fn,
    clippy::must_use_candidate,
    clippy::too_many_lines,
    missing_docs
)]
pub(crate) mod shuffle_v1 {
    tonic::include_proto!("laminar.shuffle.v1");
}
48
/// Batches and barriers that were pulled off the inbound queue (or staged
/// directly) but have not yet been handed to a consumer.
#[derive(Default)]
struct Holdover {
    /// Record batches held back, grouped by stage name.
    staged: parking_lot::Mutex<rustc_hash::FxHashMap<String, Vec<arrow_array::RecordBatch>>>,
    /// Checkpoint barriers held back, each paired with the peer it came from.
    staged_barriers: parking_lot::Mutex<Vec<(ShufflePeerId, CheckpointBarrier)>>,
}
63
64#[cfg(feature = "cluster")]
69mod grpc {
70 use std::collections::hash_map::Entry;
71 use std::io;
72 use std::net::SocketAddr;
73 use std::sync::atomic::{AtomicBool, Ordering};
74 use std::sync::Arc;
75
76 use arrow_array::RecordBatch;
77 use crossfire::{mpsc, AsyncRx, MAsyncTx};
78 use futures::StreamExt as _;
79 use parking_lot::Mutex;
80 use rustc_hash::FxHashMap;
81 use tokio::task::JoinHandle;
82 use tonic::transport::{Channel, Server};
83 use tonic::Request;
84
85 use super::shuffle_v1::shuffle_frame;
86 use super::shuffle_v1::shuffle_transport_client::ShuffleTransportClient;
87 use super::shuffle_v1::shuffle_transport_server::{ShuffleTransport, ShuffleTransportServer};
88 use super::shuffle_v1::{Barrier, Close, Hello, ShuffleFrame, ShuffleSummary, VnodeData};
89 use super::{Holdover, ShuffleMessage, ShufflePeerId, SHUFFLE_ADDR_KEY, SHUFFLE_RECV_QUEUE};
90 use crate::checkpoint::barrier::CheckpointBarrier;
91 use crate::cluster::control::ClusterKv;
92 use crate::serialization::{BatchStreamDecoder, BatchStreamEncoder};
93
    /// Capacity of the bounded per-connection outbound queue created by
    /// `open_call`.
    const SHUFFLE_SEND_QUEUE: usize = 1024;

    /// Consumer half of the inbound `(peer, message)` queue.
    type InboundRx = AsyncRx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
    /// Producer half of the inbound `(peer, message)` queue.
    type InboundTx = MAsyncTx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
102
103 fn io_err<E: std::fmt::Display>(e: E) -> io::Error {
106 io::Error::other(e.to_string())
107 }
108
109 fn encode_message(
114 msg: &ShuffleMessage,
115 encoders: &mut FxHashMap<String, BatchStreamEncoder>,
116 ) -> Result<ShuffleFrame, tonic::Status> {
117 let kind = match msg {
118 ShuffleMessage::Hello(node_id) => {
119 shuffle_frame::Kind::Hello(Hello { node_id: *node_id })
120 }
121 ShuffleMessage::Barrier(b) => shuffle_frame::Kind::Barrier(Barrier {
122 checkpoint_id: b.checkpoint_id,
123 epoch: b.epoch,
124 flags: b.flags,
125 }),
126 ShuffleMessage::VnodeData(stage, vnode, batch) => {
127 let encoder = match encoders.entry(stage.clone()) {
128 Entry::Occupied(e) => {
129 let enc = e.into_mut();
130 let schema = batch.schema();
132 if !Arc::ptr_eq(enc.schema(), &schema) && *enc.schema() != schema {
133 return Err(tonic::Status::internal(format!(
134 "shuffle stage '{stage}' changed schema mid-connection",
135 )));
136 }
137 enc
138 }
139 Entry::Vacant(v) => {
140 v.insert(BatchStreamEncoder::new(&batch.schema()).map_err(|e| {
141 tonic::Status::internal(format!("shuffle ipc encoder init: {e}"))
142 })?)
143 }
144 };
145 let arrow_ipc = encoder
146 .encode(batch)
147 .map_err(|e| tonic::Status::internal(format!("shuffle ipc encode: {e}")))?;
148 shuffle_frame::Kind::VnodeData(VnodeData {
149 stage: stage.clone(),
150 vnode: *vnode,
151 arrow_ipc,
152 })
153 }
154 ShuffleMessage::Close(reason) => shuffle_frame::Kind::Close(Close {
155 reason: reason.clone(),
156 }),
157 };
158 Ok(ShuffleFrame { kind: Some(kind) })
159 }
160
161 struct PeerConn {
168 tx: MAsyncTx<mpsc::Array<ShuffleMessage>>,
169 alive: Arc<AtomicBool>,
170 driver: JoinHandle<()>,
171 }
172
173 impl PeerConn {
174 fn is_alive(&self) -> bool {
175 self.alive.load(Ordering::Acquire)
176 }
177 }
178
179 impl Drop for PeerConn {
180 fn drop(&mut self) {
181 self.driver.abort();
182 }
183 }
184
    /// Sends shuffle messages to peers over pooled outbound connections.
    pub struct ShuffleSender {
        /// Identifier announced to peers in the initial `Hello` frame.
        local_id: ShufflePeerId,
        /// Known peer addresses, filled by `register_peer` and by discovery.
        peers: Mutex<FxHashMap<ShufflePeerId, SocketAddr>>,
        /// Connections keyed by peer; dead entries are evicted on next use.
        pool: Mutex<FxHashMap<ShufflePeerId, Arc<PeerConn>>>,
        /// Optional cluster key-value store used to look up peer addresses.
        kv: Option<Arc<dyn ClusterKv>>,
    }
192
193 impl std::fmt::Debug for ShuffleSender {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("ShuffleSender")
196 .field("local_id", &self.local_id)
197 .finish_non_exhaustive()
198 }
199 }
200
201 impl ShuffleSender {
202 #[must_use]
205 pub fn new(local_id: ShufflePeerId) -> Self {
206 Self {
207 local_id,
208 peers: Mutex::new(FxHashMap::default()),
209 pool: Mutex::new(FxHashMap::default()),
210 kv: None,
211 }
212 }
213
214 #[must_use]
217 pub fn with_kv(local_id: ShufflePeerId, kv: Arc<dyn ClusterKv>) -> Self {
218 let mut s = Self::new(local_id);
219 s.kv = Some(kv);
220 s
221 }
222
223 #[allow(clippy::unused_async)]
227 pub async fn register_peer(&self, peer: ShufflePeerId, addr: SocketAddr) {
228 self.peers.lock().insert(peer, addr);
229 }
230
231 pub async fn send_to(&self, peer: ShufflePeerId, msg: &ShuffleMessage) -> io::Result<()> {
237 let conn = self.connection_for(peer).await?;
238 conn.tx.send(msg.clone()).await.map_err(|_| {
241 io::Error::new(
242 io::ErrorKind::BrokenPipe,
243 format!("shuffle stream to peer {peer} closed"),
244 )
245 })
246 }
247
248 pub async fn fan_out_barrier(
255 &self,
256 peers: &[ShufflePeerId],
257 barrier: CheckpointBarrier,
258 ) -> io::Result<()> {
259 let msg = ShuffleMessage::Barrier(barrier);
260 for &peer in peers {
261 self.send_to(peer, &msg).await?;
262 }
263 Ok(())
264 }
265
266 async fn discover_peer(&self, peer: ShufflePeerId) -> Option<SocketAddr> {
269 let kv = self.kv.as_ref()?;
270 let raw = kv
271 .read_from(crate::cluster::discovery::NodeId(peer), SHUFFLE_ADDR_KEY)
272 .await?;
273 let addr: SocketAddr = raw.parse().ok()?;
274 self.peers.lock().insert(peer, addr);
275 Some(addr)
276 }
277
        /// Returns a live connection to `peer`, opening a new one when the pool
        /// holds none or only a dead one.
        ///
        /// # Errors
        ///
        /// Returns `NotFound` when the peer's address can neither be discovered
        /// nor found among the registered peers, and propagates any error from
        /// `open_call`.
        async fn connection_for(&self, peer: ShufflePeerId) -> io::Result<Arc<PeerConn>> {
            // Fast path: reuse the pooled connection while it is still alive.
            if let Some(existing) = self.pool.lock().get(&peer).cloned() {
                if existing.is_alive() {
                    return Ok(existing);
                }
            }
            // Evict this peer's entry if it is dead; other peers are kept.
            self.pool.lock().retain(|p, c| *p != peer || c.is_alive());

            // A discovered address takes precedence; the registered address is
            // the fallback.
            let addr = match self.discover_peer(peer).await {
                Some(addr) => addr,
                None => self.peers.lock().get(&peer).copied().ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::NotFound,
                        format!("peer {peer} has no registered shuffle address"),
                    )
                })?,
            };

            let conn = Arc::new(open_call(self.local_id, addr)?);

            let mut pool = self.pool.lock();
            // Another caller may have pooled a connection in the meantime;
            // prefer theirs if it is alive. Returning here drops `conn`, and
            // `PeerConn`'s `Drop` aborts its driver task.
            if let Some(winner) = pool.get(&peer).cloned() {
                if winner.is_alive() {
                    return Ok(winner);
                }
            }
            pool.insert(peer, Arc::clone(&conn));
            Ok(conn)
        }
311 }
312
313 fn open_call(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<PeerConn> {
317 let endpoint = crate::cluster::control::tls::client_endpoint(&addr.to_string())
318 .map_err(io_err)?
319 .tcp_nodelay(true);
320 let (tx, rx) = mpsc::bounded_async::<ShuffleMessage>(SHUFFLE_SEND_QUEUE);
321 let alive = Arc::new(AtomicBool::new(true));
322 let alive_for_driver = Arc::clone(&alive);
323
324 let hello = ShuffleFrame {
328 kind: Some(shuffle_frame::Kind::Hello(Hello { node_id: local_id })),
329 };
330 let encoders: FxHashMap<String, BatchStreamEncoder> = FxHashMap::default();
331 let outbound = futures::stream::once(async move { hello }).chain(futures::stream::unfold(
332 (rx, encoders),
333 |(rx, mut encoders)| async move {
334 let msg = rx.recv().await.ok()?;
335 match encode_message(&msg, &mut encoders) {
336 Ok(frame) => Some((frame, (rx, encoders))),
337 Err(e) => {
338 tracing::warn!(error = %e, "shuffle frame encode failed; closing stream");
341 None
342 }
343 }
344 },
345 ));
346
347 let driver = tokio::spawn(async move {
348 let Ok(channel) = endpoint.connect().await else {
349 alive_for_driver.store(false, Ordering::Release);
350 return;
351 };
352 let mut client = ShuffleTransportClient::<Channel>::new(channel);
353 let _ = client.shuffle(Request::new(outbound)).await;
356 alive_for_driver.store(false, Ordering::Release);
357 });
358
359 Ok(PeerConn { tx, alive, driver })
360 }
361
    /// Receiving end of the shuffle transport: runs the server task and exposes
    /// the inbound queue plus the holdover staging area.
    pub struct ShuffleReceiver {
        /// Identifier this receiver was bound with.
        local_id: ShufflePeerId,
        /// Address reported by the bound listener.
        local_addr: SocketAddr,
        /// Slot for the inbound queue consumer; `recv` takes it out while it
        /// waits and it is put back afterwards.
        rx: Mutex<Option<InboundRx>>,
        /// Signalled each time the consumer is put back into `rx`.
        rx_returned: Arc<tokio::sync::Notify>,
        /// Task running the transport server; aborted on drop.
        server: JoinHandle<()>,
        /// Batches and barriers held back for later draining.
        holdover: Arc<Holdover>,
    }
381
382 impl Drop for ShuffleReceiver {
383 fn drop(&mut self) {
384 self.server.abort();
387 }
388 }
389
390 impl std::fmt::Debug for ShuffleReceiver {
391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 f.debug_struct("ShuffleReceiver")
393 .field("local_id", &self.local_id)
394 .field("local_addr", &self.local_addr)
395 .finish_non_exhaustive()
396 }
397 }
398
399 impl ShuffleReceiver {
400 pub async fn bind(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
406 let listener = tokio::net::TcpListener::bind(addr).await?;
407 let local_addr = listener.local_addr()?;
408 let (tx, rx) =
409 mpsc::bounded_async::<(ShufflePeerId, ShuffleMessage)>(SHUFFLE_RECV_QUEUE);
410
411 let service = ShuffleService { tx };
412 let incoming = futures::stream::unfold(listener, |listener| async move {
416 let item = match listener.accept().await {
417 Ok((stream, _)) => {
418 let _ = stream.set_nodelay(true);
419 Ok(stream)
420 }
421 Err(e) => Err(e),
422 };
423 Some((item, listener))
424 });
425 let mut builder = Server::builder();
428 if let Some(tls) = crate::cluster::control::tls::server_tls() {
429 builder = builder
430 .tls_config(tls.clone())
431 .map_err(|e| io::Error::other(format!("cluster shuffle TLS config: {e}")))?;
432 }
433 let router = builder.add_service(ShuffleTransportServer::new(service));
434 let server = tokio::spawn(async move {
435 let _ = router.serve_with_incoming(incoming).await;
436 });
437
438 Ok(Self {
439 local_id,
440 local_addr,
441 rx: Mutex::new(Some(rx)),
442 rx_returned: Arc::new(tokio::sync::Notify::new()),
443 server,
444 holdover: Arc::new(Holdover::default()),
445 })
446 }
447
448 pub async fn bind_with_kv(
454 local_id: ShufflePeerId,
455 addr: SocketAddr,
456 kv: Arc<dyn ClusterKv>,
457 ) -> io::Result<Self> {
458 let recv = Self::bind(local_id, addr).await?;
459 kv.write(SHUFFLE_ADDR_KEY, recv.local_addr.to_string())
460 .await;
461 Ok(recv)
462 }
463
        /// Returns the address the listener was actually bound to.
        #[must_use]
        pub fn local_addr(&self) -> SocketAddr {
            self.local_addr
        }
469
        /// Waits for the next inbound `(peer, message)` pair.
        ///
        /// The queue consumer is taken out of its slot for the duration of the
        /// wait and put back by `RxReturnGuard` when this call returns or its
        /// future is dropped. A caller that finds the slot empty waits on
        /// `rx_returned` and then retries.
        ///
        /// Returns `None` when the underlying queue's `recv` fails.
        pub async fn recv(&self) -> Option<(ShufflePeerId, ShuffleMessage)> {
            loop {
                // The lock is released before any await below.
                let taken = { self.rx.lock().take() };
                let Some(rx) = taken else {
                    // Someone else holds the consumer; wait until it is returned.
                    self.rx_returned.notified().await;
                    continue;
                };
                let mut guard = RxReturnGuard {
                    slot: &self.rx,
                    notify: &self.rx_returned,
                    rx: Some(rx),
                };
                // `guard.rx` was just set to `Some`, so this `?` does not bail.
                let rx = guard.rx.as_mut()?;
                return rx.recv().await.ok();
            }
        }
491
492 #[must_use]
495 pub fn drain_available(&self) -> Vec<(ShufflePeerId, ShuffleMessage)> {
496 let mut out = Vec::new();
497 let slot = self.rx.lock();
498 if let Some(rx) = slot.as_ref() {
499 while let Ok(item) = rx.try_recv() {
500 out.push(item);
501 }
502 }
503 out
504 }
505
506 fn drain_inbound_into(&self, staged: &mut FxHashMap<String, Vec<RecordBatch>>) {
510 let slot = self.rx.lock();
511 if let Some(rx) = slot.as_ref() {
512 while let Ok((from, msg)) = rx.try_recv() {
513 match msg {
514 ShuffleMessage::VnodeData(s, _vnode, batch) => {
515 staged.entry(s).or_default().push(batch);
516 }
517 ShuffleMessage::Barrier(b) => {
518 self.holdover.staged_barriers.lock().push((from, b));
519 }
520 _ => {} }
522 }
523 }
524 }
525
526 #[must_use]
530 pub fn drain_vnode_data_for(&self, stage: &str) -> Vec<RecordBatch> {
531 let mut staged = self.holdover.staged.lock();
532 self.drain_inbound_into(&mut staged);
533 staged.remove(stage).unwrap_or_default()
534 }
535
536 #[must_use]
540 pub fn drain_staged_with_prefix(
541 &self,
542 prefix: &str,
543 ) -> FxHashMap<String, Vec<RecordBatch>> {
544 let mut staged = self.holdover.staged.lock();
545 self.drain_inbound_into(&mut staged);
546 let mut out: FxHashMap<String, Vec<RecordBatch>> = FxHashMap::default();
547 staged.retain(|stage, batches| {
548 if stage.starts_with(prefix) {
549 out.insert(stage.clone(), std::mem::take(batches));
550 false
551 } else {
552 true
553 }
554 });
555 out
556 }
557
558 pub fn stage_batch(&self, stage: String, batch: RecordBatch) {
562 self.holdover
563 .staged
564 .lock()
565 .entry(stage)
566 .or_default()
567 .push(batch);
568 }
569
570 #[must_use]
573 pub fn drain_staged_barriers(&self) -> Vec<(ShufflePeerId, CheckpointBarrier)> {
574 std::mem::take(&mut self.holdover.staged_barriers.lock())
575 }
576
577 #[must_use]
579 pub fn drain_all_staged(&self) -> Vec<(String, RecordBatch)> {
580 let mut staged = self.holdover.staged.lock();
581 staged
582 .drain()
583 .flat_map(|(stage, batches)| batches.into_iter().map(move |b| (stage.clone(), b)))
584 .collect()
585 }
586 }
587
    /// Holds the inbound queue consumer while `recv` waits on it and puts it
    /// back into its slot on drop.
    struct RxReturnGuard<'a> {
        /// Slot the consumer is returned to.
        slot: &'a Mutex<Option<InboundRx>>,
        /// Signalled after the consumer has been put back.
        notify: &'a tokio::sync::Notify,
        /// The borrowed consumer; `None` once it has been returned.
        rx: Option<InboundRx>,
    }

    impl Drop for RxReturnGuard<'_> {
        fn drop(&mut self) {
            if let Some(rx) = self.rx.take() {
                // Store first, then notify, so a woken waiter finds the slot
                // filled.
                *self.slot.lock() = Some(rx);
                self.notify.notify_one();
            }
        }
    }
605
606 struct ShuffleService {
609 tx: InboundTx,
610 }
611
612 #[tonic::async_trait]
613 impl ShuffleTransport for ShuffleService {
614 async fn shuffle(
615 &self,
616 request: Request<tonic::Streaming<ShuffleFrame>>,
617 ) -> Result<tonic::Response<ShuffleSummary>, tonic::Status> {
618 let summary = run_stream(self.tx.clone(), request.into_inner()).await?;
619 Ok(tonic::Response::new(summary))
620 }
621 }
622
623 async fn run_stream(
628 tx: InboundTx,
629 mut stream: tonic::Streaming<ShuffleFrame>,
630 ) -> Result<ShuffleSummary, tonic::Status> {
631 let first = stream
632 .message()
633 .await?
634 .ok_or_else(|| tonic::Status::invalid_argument("shuffle stream closed before Hello"))?;
635 let peer = match first.kind {
636 Some(shuffle_frame::Kind::Hello(h)) => h.node_id,
637 _ => {
638 return Err(tonic::Status::invalid_argument(
639 "first shuffle frame must be Hello",
640 ))
641 }
642 };
643
644 let mut decoders: FxHashMap<String, BatchStreamDecoder> = FxHashMap::default();
645 let mut frames_received = 0u64;
646 while let Some(frame) = stream.message().await? {
647 let kind = frame
648 .kind
649 .ok_or_else(|| tonic::Status::invalid_argument("empty shuffle frame"))?;
650 match kind {
651 shuffle_frame::Kind::Close(_) => break,
652 shuffle_frame::Kind::Hello(h) => {
653 frames_received += 1;
654 if tx
655 .send((peer, ShuffleMessage::Hello(h.node_id)))
656 .await
657 .is_err()
658 {
659 break;
660 }
661 }
662 shuffle_frame::Kind::Barrier(b) => {
663 frames_received += 1;
664 let msg = ShuffleMessage::Barrier(CheckpointBarrier {
665 checkpoint_id: b.checkpoint_id,
666 epoch: b.epoch,
667 flags: b.flags,
668 });
669 if tx.send((peer, msg)).await.is_err() {
670 break;
671 }
672 }
673 shuffle_frame::Kind::VnodeData(v) => {
674 frames_received += 1;
675 let batches = decoders
676 .entry(v.stage.clone())
677 .or_default()
678 .decode_chunk(v.arrow_ipc)
679 .map_err(|e| {
680 tonic::Status::invalid_argument(format!("shuffle ipc: {e}"))
681 })?;
682 let mut stream_broken = false;
683 for batch in batches {
684 if !forward_vnode_batch(&tx, peer, &v.stage, v.vnode, batch).await? {
685 stream_broken = true;
686 break;
687 }
688 }
689 if stream_broken {
690 break;
691 }
692 }
693 }
694 }
695 Ok(ShuffleSummary { frames_received })
696 }
697
698 async fn forward_vnode_batch(
703 tx: &InboundTx,
704 peer: ShufflePeerId,
705 stage: &str,
706 default_vnode: u32,
707 batch: RecordBatch,
708 ) -> Result<bool, tonic::Status> {
709 let schema = batch.schema();
710 let Some((col_idx, _field)) = schema.column_with_name("__laminar_vnode") else {
711 let msg = ShuffleMessage::VnodeData(stage.to_string(), default_vnode, batch);
712 return Ok(tx.send((peer, msg)).await.is_ok());
713 };
714
715 let vnode_array = batch
716 .column(col_idx)
717 .as_any()
718 .downcast_ref::<arrow_array::UInt32Array>()
719 .ok_or_else(|| {
720 tonic::Status::invalid_argument("vnode metadata column is not UInt32Array")
721 })?;
722 let row_vnodes: Vec<u32> = vnode_array.values().to_vec();
723
724 let mut projection: Vec<usize> = (0..schema.fields().len()).collect();
725 projection.remove(col_idx);
726 let batch_without_vnode = batch.project(&projection).map_err(|e| {
727 tonic::Status::internal(format!("Failed to project out vnode metadata: {e}"))
728 })?;
729
730 let slices =
731 crate::shuffle::routing::slice_batch_by_vnodes(&batch_without_vnode, &row_vnodes);
732 for (v, slice) in slices {
733 let sub_msg = ShuffleMessage::VnodeData(stage.to_string(), v, slice);
734 if tx.send((peer, sub_msg)).await.is_err() {
735 return Ok(false);
736 }
737 }
738 Ok(true)
739 }
740
741 #[cfg(test)]
742 mod encode_tests {
743 use super::*;
744 use arrow_array::Int64Array;
745 use arrow_schema::{DataType, Field, Schema};
746
747 #[test]
748 fn schema_change_on_a_stage_is_rejected() {
749 let batch = |name: &str| {
750 let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int64, false)]));
751 arrow_array::RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![1]))])
752 .unwrap()
753 };
754 let mut encoders = FxHashMap::default();
755 let msg = ShuffleMessage::VnodeData("s".into(), 0, batch("a"));
756 encode_message(&msg, &mut encoders).unwrap();
757
758 let changed = ShuffleMessage::VnodeData("s".into(), 0, batch("b"));
759 let err = encode_message(&changed, &mut encoders).unwrap_err();
760 assert!(err.message().contains("changed schema"), "{err}");
761
762 let other = ShuffleMessage::VnodeData("t".into(), 0, batch("b"));
764 encode_message(&other, &mut encoders).unwrap();
765 }
766 }
767}
768
769#[cfg(feature = "cluster")]
770pub use grpc::{ShuffleReceiver, ShuffleSender};
771
772#[cfg(not(feature = "cluster"))]
782mod shim {
783 use std::io;
784 use std::net::SocketAddr;
785 use std::sync::Arc;
786
787 use arrow_array::RecordBatch;
788 use crossfire::{mpsc, AsyncRx, MAsyncTx};
789 use parking_lot::Mutex;
790 use rustc_hash::FxHashMap;
791
792 use super::{Holdover, ShuffleMessage, ShufflePeerId, SHUFFLE_RECV_QUEUE};
793 use crate::checkpoint::barrier::CheckpointBarrier;
794
    /// Consumer half of the inbound `(peer, message)` queue.
    type InboundRx = AsyncRx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
    /// Producer half of the inbound `(peer, message)` queue.
    type InboundTx = MAsyncTx<mpsc::Array<(ShufflePeerId, ShuffleMessage)>>;
797
    /// Stand-in sender used when the `cluster` feature is disabled; it tracks
    /// registered peers but transmits nothing.
    pub struct ShuffleSender {
        /// Identifier this sender was created with.
        local_id: ShufflePeerId,
        /// Addresses recorded through `register_peer`.
        peers: Mutex<FxHashMap<ShufflePeerId, SocketAddr>>,
    }
804
805 impl std::fmt::Debug for ShuffleSender {
806 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
807 f.debug_struct("ShuffleSender")
808 .field("local_id", &self.local_id)
809 .finish_non_exhaustive()
810 }
811 }
812
813 impl ShuffleSender {
814 #[must_use]
816 pub fn new(local_id: ShufflePeerId) -> Self {
817 Self {
818 local_id,
819 peers: Mutex::new(FxHashMap::default()),
820 }
821 }
822
823 #[allow(clippy::unused_async)] pub async fn register_peer(&self, peer: ShufflePeerId, addr: SocketAddr) {
826 self.peers.lock().insert(peer, addr);
827 }
828
829 #[allow(clippy::unused_async)] pub async fn send_to(&self, peer: ShufflePeerId, _msg: &ShuffleMessage) -> io::Result<()> {
834 if self.peers.lock().contains_key(&peer) {
835 Ok(())
836 } else {
837 Err(io::Error::new(
838 io::ErrorKind::NotFound,
839 format!("peer {peer} has no registered shuffle address"),
840 ))
841 }
842 }
843
844 pub async fn fan_out_barrier(
847 &self,
848 peers: &[ShufflePeerId],
849 barrier: CheckpointBarrier,
850 ) -> io::Result<()> {
851 let msg = ShuffleMessage::Barrier(barrier);
852 for &peer in peers {
853 self.send_to(peer, &msg).await?;
854 }
855 Ok(())
856 }
857 }
858
    /// Stand-in receiver used when the `cluster` feature is disabled; it owns
    /// an inbound queue and the holdover area but runs no server.
    pub struct ShuffleReceiver {
        /// Identifier this receiver was bound with.
        local_id: ShufflePeerId,
        /// Address reported by the listener that `bind` opened and dropped.
        local_addr: SocketAddr,
        /// Producer half of the inbound queue. Nothing in this module sends on
        /// it — presumably retained so the queue stays open; TODO confirm.
        #[allow(dead_code)]
        tx: InboundTx,
        /// Slot for the inbound queue consumer; `recv` takes it out while it
        /// waits and it is put back afterwards.
        rx: Mutex<Option<InboundRx>>,
        /// Signalled each time the consumer is put back into `rx`.
        rx_returned: Arc<tokio::sync::Notify>,
        /// Batches and barriers held back for later draining.
        holdover: Arc<Holdover>,
    }
872
873 impl std::fmt::Debug for ShuffleReceiver {
874 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875 f.debug_struct("ShuffleReceiver")
876 .field("local_id", &self.local_id)
877 .field("local_addr", &self.local_addr)
878 .finish_non_exhaustive()
879 }
880 }
881
882 impl ShuffleReceiver {
883 pub async fn bind(local_id: ShufflePeerId, addr: SocketAddr) -> io::Result<Self> {
886 let listener = tokio::net::TcpListener::bind(addr).await?;
888 let local_addr = listener.local_addr()?;
889 drop(listener);
890 let (tx, rx) =
891 mpsc::bounded_async::<(ShufflePeerId, ShuffleMessage)>(SHUFFLE_RECV_QUEUE);
892 Ok(Self {
893 local_id,
894 local_addr,
895 tx,
896 rx: Mutex::new(Some(rx)),
897 rx_returned: Arc::new(tokio::sync::Notify::new()),
898 holdover: Arc::new(Holdover::default()),
899 })
900 }
901
        /// Returns the address resolved by `bind`.
        #[must_use]
        pub fn local_addr(&self) -> SocketAddr {
            self.local_addr
        }
907
        /// Waits for the next inbound `(peer, message)` pair.
        ///
        /// The queue consumer is taken out of its slot for the duration of the
        /// wait and put back by `RxReturnGuard` when this call returns or its
        /// future is dropped. A caller that finds the slot empty waits on
        /// `rx_returned` and then retries.
        ///
        /// Returns `None` when the underlying queue's `recv` fails.
        pub async fn recv(&self) -> Option<(ShufflePeerId, ShuffleMessage)> {
            loop {
                // The lock is released before any await below.
                let taken = { self.rx.lock().take() };
                let Some(rx) = taken else {
                    // Someone else holds the consumer; wait until it is returned.
                    self.rx_returned.notified().await;
                    continue;
                };
                let mut guard = RxReturnGuard {
                    slot: &self.rx,
                    notify: &self.rx_returned,
                    rx: Some(rx),
                };
                // `guard.rx` was just set to `Some`, so this `?` does not bail.
                let rx = guard.rx.as_mut()?;
                return rx.recv().await.ok();
            }
        }
925
926 #[must_use]
928 pub fn drain_available(&self) -> Vec<(ShufflePeerId, ShuffleMessage)> {
929 let mut out = Vec::new();
930 let slot = self.rx.lock();
931 if let Some(rx) = slot.as_ref() {
932 while let Ok(item) = rx.try_recv() {
933 out.push(item);
934 }
935 }
936 out
937 }
938
939 #[must_use]
942 pub fn drain_vnode_data_for(&self, stage: &str) -> Vec<RecordBatch> {
943 let mut staged = self.holdover.staged.lock();
944 {
945 let slot = self.rx.lock();
946 if let Some(rx) = slot.as_ref() {
947 while let Ok((from, msg)) = rx.try_recv() {
948 match msg {
949 ShuffleMessage::VnodeData(s, _vnode, batch) => {
950 staged.entry(s).or_default().push(batch);
951 }
952 ShuffleMessage::Barrier(b) => {
953 self.holdover.staged_barriers.lock().push((from, b));
954 }
955 _ => {}
956 }
957 }
958 }
959 }
960 staged.remove(stage).unwrap_or_default()
961 }
962
963 #[must_use]
968 pub fn drain_staged_with_prefix(
969 &self,
970 prefix: &str,
971 ) -> FxHashMap<String, Vec<RecordBatch>> {
972 let mut staged = self.holdover.staged.lock();
973 {
974 let slot = self.rx.lock();
975 if let Some(rx) = slot.as_ref() {
976 while let Ok((from, msg)) = rx.try_recv() {
977 match msg {
978 ShuffleMessage::VnodeData(s, _vnode, batch) => {
979 staged.entry(s).or_default().push(batch);
980 }
981 ShuffleMessage::Barrier(b) => {
982 self.holdover.staged_barriers.lock().push((from, b));
983 }
984 _ => {}
985 }
986 }
987 }
988 }
989 let mut out: FxHashMap<String, Vec<RecordBatch>> = FxHashMap::default();
990 staged.retain(|stage, batches| {
991 if stage.starts_with(prefix) {
992 out.insert(stage.clone(), std::mem::take(batches));
993 false
994 } else {
995 true
996 }
997 });
998 out
999 }
1000
1001 pub fn stage_batch(&self, stage: String, batch: RecordBatch) {
1003 self.holdover
1004 .staged
1005 .lock()
1006 .entry(stage)
1007 .or_default()
1008 .push(batch);
1009 }
1010
1011 #[must_use]
1013 pub fn drain_staged_barriers(&self) -> Vec<(ShufflePeerId, CheckpointBarrier)> {
1014 std::mem::take(&mut self.holdover.staged_barriers.lock())
1015 }
1016
1017 #[must_use]
1019 pub fn drain_all_staged(&self) -> Vec<(String, RecordBatch)> {
1020 let mut staged = self.holdover.staged.lock();
1021 staged
1022 .drain()
1023 .flat_map(|(stage, batches)| batches.into_iter().map(move |b| (stage.clone(), b)))
1024 .collect()
1025 }
1026 }
1027
    /// Holds the inbound queue consumer while `recv` waits on it and puts it
    /// back into its slot on drop.
    struct RxReturnGuard<'a> {
        /// Slot the consumer is returned to.
        slot: &'a Mutex<Option<InboundRx>>,
        /// Signalled after the consumer has been put back.
        notify: &'a tokio::sync::Notify,
        /// The borrowed consumer; `None` once it has been returned.
        rx: Option<InboundRx>,
    }

    impl Drop for RxReturnGuard<'_> {
        fn drop(&mut self) {
            if let Some(rx) = self.rx.take() {
                // Store first, then notify, so a woken waiter finds the slot
                // filled.
                *self.slot.lock() = Some(rx);
                self.notify.notify_one();
            }
        }
    }
1044}
1045
1046#[cfg(not(feature = "cluster"))]
1047pub use shim::{ShuffleReceiver, ShuffleSender};
1048
1049#[cfg(all(test, feature = "cluster"))]
1050mod tests {
1051 use std::io;
1052 use std::sync::Arc;
1053
1054 use super::*;
1055
1056 async fn bind_on_loopback(local_id: ShufflePeerId) -> ShuffleReceiver {
1057 ShuffleReceiver::bind(local_id, "127.0.0.1:0".parse().unwrap())
1058 .await
1059 .expect("bind")
1060 }
1061
    /// A message sent by sender 1 arrives at the receiver tagged with the
    /// sender's id, not with the id carried inside the message.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn sender_to_receiver_delivers_with_peer_attribution() {
        let recv = bind_on_loopback(2).await;
        let recv_addr = recv.local_addr();

        let sender = ShuffleSender::new(1);
        sender.register_peer(2, recv_addr).await;
        sender
            .send_to(2, &ShuffleMessage::Hello(1234))
            .await
            .unwrap();

        let (from, msg) = recv.recv().await.unwrap();
        assert_eq!(from, 1, "receiver attributes frame to sender id");
        assert_eq!(msg, ShuffleMessage::Hello(1234));
    }
1078
    /// Several sends to the same peer are received in the order they were
    /// sent.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn sender_reuses_stream_across_sends() {
        let recv = bind_on_loopback(2).await;
        let sender = ShuffleSender::new(1);
        sender.register_peer(2, recv.local_addr()).await;

        for delta in [10u64, 20, 30, 40] {
            sender
                .send_to(2, &ShuffleMessage::Hello(delta))
                .await
                .unwrap();
        }

        let mut got = Vec::new();
        for _ in 0..4 {
            got.push(recv.recv().await.unwrap().1);
        }
        assert_eq!(
            got,
            vec![
                ShuffleMessage::Hello(10),
                ShuffleMessage::Hello(20),
                ShuffleMessage::Hello(30),
                ShuffleMessage::Hello(40),
            ]
        );
    }
1106
1107 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1108 async fn send_to_unregistered_peer_errors() {
1109 let sender = ShuffleSender::new(1);
1110 let err = sender
1111 .send_to(99, &ShuffleMessage::Hello(1))
1112 .await
1113 .unwrap_err();
1114 assert_eq!(err.kind(), io::ErrorKind::NotFound);
1115 }
1116
    /// A sender with a key-value store finds the peer's address there without
    /// any `register_peer` call.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_discovers_peer_address_from_kv() {
        use crate::cluster::control::{ClusterKv, InMemoryKv};
        use crate::cluster::discovery::NodeId;

        let recv = bind_on_loopback(2).await;
        // Seed node 2's advertised address into the store held by node 1.
        let kv = Arc::new(InMemoryKv::new(NodeId(1)));
        kv.seed(NodeId(2), SHUFFLE_ADDR_KEY, recv.local_addr().to_string());
        let sender = ShuffleSender::with_kv(1, kv as Arc<dyn ClusterKv>);

        sender.send_to(2, &ShuffleMessage::Hello(7)).await.unwrap();
        let (from, msg) = recv.recv().await.unwrap();
        assert_eq!(from, 1);
        assert_eq!(msg, ShuffleMessage::Hello(7));
    }
1135
    /// After the peer is dropped and re-bound on a different port, re-registering
    /// the new address lets the sender deliver again.
    #[cfg(not(windows))]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_reconnects_after_peer_restart_at_new_address() {
        let recv_v1 = bind_on_loopback(2).await;
        let addr_v1 = recv_v1.local_addr();

        let sender = ShuffleSender::new(1);
        sender.register_peer(2, addr_v1).await;
        sender
            .send_to(2, &ShuffleMessage::Hello(111))
            .await
            .unwrap();
        let (from, msg) = recv_v1.recv().await.unwrap();
        assert_eq!(from, 1);
        assert_eq!(msg, ShuffleMessage::Hello(111));

        // Dropping the receiver aborts its server task.
        drop(recv_v1);

        let recv_v2 = bind_on_loopback(2).await;
        let addr_v2 = recv_v2.local_addr();
        assert_ne!(addr_v1, addr_v2, "ephemeral rebind must pick a new port");
        sender.register_peer(2, addr_v2).await;

        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
        // Retry until a send lands on the new receiver; individual sends may
        // fail or be lost, so their results are ignored.
        loop {
            let _ = sender.send_to(2, &ShuffleMessage::Hello(222)).await;
            if let Some((from, ShuffleMessage::Hello(222))) =
                tokio::time::timeout(std::time::Duration::from_millis(200), recv_v2.recv())
                    .await
                    .ok()
                    .flatten()
            {
                assert_eq!(from, 1);
                return;
            }
            assert!(
                std::time::Instant::now() < deadline,
                "did not deliver to restarted peer within 30s",
            );
        }
    }
1185
    /// `drain_staged_with_prefix` returns only the stages matching the prefix,
    /// stages barriers seen along the way, and leaves other stages available
    /// to `drain_vnode_data_for`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn drain_staged_with_prefix_lifts_subs_and_keeps_operator_stages() {
        use arrow_array::{Int64Array, RecordBatch};
        use arrow_schema::{DataType, Field, Schema};
        use rustc_hash::FxHashMap;

        use crate::checkpoint::barrier::CheckpointBarrier;

        // Builds a batch with a single non-null Int64 column named "x".
        fn batch(values: Vec<i64>) -> RecordBatch {
            let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
            RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap()
        }
        // Reads column 0 of a batch back as plain i64 values.
        fn col(b: &RecordBatch) -> Vec<i64> {
            b.column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap()
                .values()
                .to_vec()
        }

        let recv = bind_on_loopback(2).await;
        let sender = ShuffleSender::new(1);
        sender.register_peer(2, recv.local_addr()).await;

        // Two stages under the "__sub::" prefix and one outside it.
        for (stage, vals) in [
            ("__sub::alpha", vec![1, 2, 3]),
            ("__sub::beta", vec![4, 5, 6]),
            ("op_stage", vec![7, 8, 9]),
        ] {
            sender
                .send_to(2, &ShuffleMessage::VnodeData(stage.into(), 0, batch(vals)))
                .await
                .unwrap();
        }
        sender
            .send_to(
                2,
                &ShuffleMessage::Barrier(CheckpointBarrier {
                    checkpoint_id: 7,
                    epoch: 3,
                    flags: 0,
                }),
            )
            .await
            .unwrap();

        let mut subs: FxHashMap<String, Vec<RecordBatch>> = FxHashMap::default();
        let mut barriers = Vec::new();
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
        // Delivery is asynchronous, so poll until both sub stages and the
        // barrier have shown up.
        while subs.len() < 2 || barriers.is_empty() {
            for (k, v) in recv.drain_staged_with_prefix("__sub::") {
                subs.entry(k).or_default().extend(v);
            }
            barriers.extend(recv.drain_staged_barriers());
            assert!(
                std::time::Instant::now() < deadline,
                "frames not delivered within 2s",
            );
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }

        assert_eq!(subs.len(), 2, "only the two __sub:: stages are returned");
        assert_eq!(col(&subs["__sub::alpha"][0]), vec![1, 2, 3]);
        assert_eq!(col(&subs["__sub::beta"][0]), vec![4, 5, 6]);

        assert_eq!(barriers.len(), 1);
        assert_eq!(barriers[0].0, 1, "barrier attributed to sender peer 1");
        assert_eq!(barriers[0].1.checkpoint_id, 7);

        // The non-matching stage is still held and can be drained by name.
        let op = recv.drain_vnode_data_for("op_stage");
        assert_eq!(op.len(), 1);
        assert_eq!(col(&op[0]), vec![7, 8, 9]);

        assert!(recv.drain_staged_with_prefix("__sub::").is_empty());
    }
1273}