Skip to main content

laminar_core/shuffle/
message.rs

//! Wire format for the shuffle: `[u32 length][u8 tag][payload]`.
//!
//! Each frame is self-contained. The receiver reads a 4-byte
//! big-endian length, a 1-byte tag, then exactly `length − 1` bytes of
//! payload. This avoids entangling the shuffle with Arrow IPC stream
//! framing: a `VnodeData` message's payload is a 4-byte little-endian
//! vnode id followed by an Arrow IPC single-batch stream (schema +
//! batch), so a schema roll on one message doesn't poison the
//! connection.
9
10use std::io;
11
12use arrow_array::RecordBatch;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14
15use crate::checkpoint::barrier::CheckpointBarrier;
16use crate::serialization::{deserialize_batch_stream, serialize_batch_stream};
17
/// Frame tag: a checkpoint barrier travelling in-band with the data.
pub(crate) const TAG_BARRIER: u8 = 0x02;
/// Frame tag: connection handshake. A single `Hello(node_id)` is the
/// first frame sent in each direction.
pub(crate) const TAG_HELLO: u8 = 0x04;
/// Frame tag: pre-routed data. The sender has already classified every
/// row of the batch as belonging to `vnode`, so the receiver does not
/// re-hash.
pub(crate) const TAG_VNODE_DATA: u8 = 0x05;
/// Frame tag: graceful close, carrying a short UTF-8 reason string.
pub(crate) const TAG_CLOSE: u8 = 0xFF;
28
29/// Logical message carried on a shuffle connection.
30#[derive(Debug, Clone, PartialEq)]
31pub enum ShuffleMessage {
32    /// A checkpoint barrier (Chandy-Lamport).
33    Barrier(CheckpointBarrier),
34    /// Peer identifying itself during the connection handshake.
35    Hello(u64),
36    /// A batch of rows pre-routed to `vnode`.
37    VnodeData(u32, RecordBatch),
38    /// Sender announcing graceful shutdown with a brief reason.
39    Close(String),
40}
41
/// Largest payload the codec accepts: 64 MiB. A receiver refuses any
/// frame above this size rather than allocating unbounded memory.
pub const MAX_PAYLOAD_BYTES: usize = 64 << 20;
45
46/// Serialize `msg` and write it as one frame on `writer`.
47///
48/// # Errors
49/// Returns [`io::Error`] on I/O failure, Arrow IPC encoding failure
50/// (wrapped as `InvalidData`), or payload size over [`MAX_PAYLOAD_BYTES`].
51pub(crate) async fn write_message<W>(writer: &mut W, msg: &ShuffleMessage) -> io::Result<()>
52where
53    W: AsyncWriteExt + Unpin,
54{
55    let (tag, payload): (u8, Vec<u8>) = match msg {
56        ShuffleMessage::Barrier(b) => {
57            let mut buf = Vec::with_capacity(24);
58            buf.extend_from_slice(&b.checkpoint_id.to_le_bytes());
59            buf.extend_from_slice(&b.epoch.to_le_bytes());
60            buf.extend_from_slice(&b.flags.to_le_bytes());
61            (TAG_BARRIER, buf)
62        }
63        ShuffleMessage::Hello(node_id) => (TAG_HELLO, node_id.to_le_bytes().to_vec()),
64        ShuffleMessage::VnodeData(vnode, batch) => {
65            let ipc = serialize_batch_stream(batch)
66                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
67            let mut buf = Vec::with_capacity(4 + ipc.len());
68            buf.extend_from_slice(&vnode.to_le_bytes());
69            buf.extend_from_slice(&ipc);
70            (TAG_VNODE_DATA, buf)
71        }
72        ShuffleMessage::Close(reason) => (TAG_CLOSE, reason.as_bytes().to_vec()),
73    };
74
75    if payload.len() > MAX_PAYLOAD_BYTES {
76        return Err(io::Error::new(
77            io::ErrorKind::InvalidInput,
78            format!(
79                "shuffle payload {} exceeds MAX_PAYLOAD_BYTES ({MAX_PAYLOAD_BYTES})",
80                payload.len()
81            ),
82        ));
83    }
84
85    // Frame: [u32 total_len BE][u8 tag][payload]
86    // total_len = 1 (tag) + payload.len()
87    let total_len = u32::try_from(payload.len() + 1)
88        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "frame length overflow"))?;
89    writer.write_all(&total_len.to_be_bytes()).await?;
90    writer.write_all(&[tag]).await?;
91    writer.write_all(&payload).await?;
92    writer.flush().await?;
93    Ok(())
94}
95
96/// Read one frame from `reader` and decode it to a `ShuffleMessage`.
97///
98/// # Errors
99/// Returns [`io::Error`] on truncated frames, unknown tags, or Arrow
100/// IPC decoding failures. An oversized frame (over
101/// [`MAX_PAYLOAD_BYTES`]) is rejected before allocation.
102///
103/// # Panics
104/// Panics only on internal invariants (fixed-width slice conversions
105/// after length checks). A correctly-framed stream cannot panic.
106pub(crate) async fn read_message<R>(reader: &mut R) -> io::Result<ShuffleMessage>
107where
108    R: AsyncReadExt + Unpin,
109{
110    let mut len_buf = [0u8; 4];
111    reader.read_exact(&mut len_buf).await?;
112    let total_len = u32::from_be_bytes(len_buf) as usize;
113    if total_len == 0 {
114        return Err(io::Error::new(
115            io::ErrorKind::InvalidData,
116            "zero-length shuffle frame",
117        ));
118    }
119    let payload_len = total_len - 1;
120    if payload_len > MAX_PAYLOAD_BYTES {
121        return Err(io::Error::new(
122            io::ErrorKind::InvalidInput,
123            format!("shuffle frame {payload_len} exceeds MAX_PAYLOAD_BYTES"),
124        ));
125    }
126
127    let mut tag_buf = [0u8; 1];
128    reader.read_exact(&mut tag_buf).await?;
129    let tag = tag_buf[0];
130
131    let mut payload = vec![0u8; payload_len];
132    reader.read_exact(&mut payload).await?;
133
134    match tag {
135        TAG_BARRIER => {
136            if payload.len() != 24 {
137                return Err(io::Error::new(
138                    io::ErrorKind::InvalidData,
139                    format!("barrier payload {} bytes, expected 24", payload.len()),
140                ));
141            }
142            let checkpoint_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
143            let epoch = u64::from_le_bytes(payload[8..16].try_into().unwrap());
144            let flags = u64::from_le_bytes(payload[16..24].try_into().unwrap());
145            Ok(ShuffleMessage::Barrier(CheckpointBarrier {
146                checkpoint_id,
147                epoch,
148                flags,
149            }))
150        }
151        TAG_HELLO => {
152            if payload.len() != 8 {
153                return Err(io::Error::new(
154                    io::ErrorKind::InvalidData,
155                    format!("hello payload {} bytes, expected 8", payload.len()),
156                ));
157            }
158            let node_id = u64::from_le_bytes(payload[..].try_into().unwrap());
159            Ok(ShuffleMessage::Hello(node_id))
160        }
161        TAG_VNODE_DATA => {
162            if payload.len() < 4 {
163                return Err(io::Error::new(
164                    io::ErrorKind::InvalidData,
165                    format!("vnode-data payload {} bytes, expected ≥4", payload.len()),
166                ));
167            }
168            let vnode = u32::from_le_bytes(payload[..4].try_into().unwrap());
169            let batch = deserialize_batch_stream(&payload[4..])
170                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
171            Ok(ShuffleMessage::VnodeData(vnode, batch))
172        }
173        TAG_CLOSE => {
174            let reason = String::from_utf8(payload).map_err(|e| {
175                io::Error::new(io::ErrorKind::InvalidData, format!("close reason: {e}"))
176            })?;
177            Ok(ShuffleMessage::Close(reason))
178        }
179        other => Err(io::Error::new(
180            io::ErrorKind::InvalidData,
181            format!("unknown shuffle tag byte: {other:#04x}"),
182        )),
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use tokio::io::duplex;

    use crate::checkpoint::barrier::flags;

    /// Three-row batch with a single non-nullable `Int64` column `x`.
    fn sample_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let col = Arc::new(Int64Array::from(vec![1i64, 2, 3]));
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    /// Encode `msg` into one end of an in-memory pipe and decode it from
    /// the other. The write finishes before the read starts, so
    /// `capacity` must be at least the size of the encoded frame.
    async fn roundtrip(capacity: usize, msg: &ShuffleMessage) -> ShuffleMessage {
        let (mut tx, mut rx) = duplex(capacity);
        write_message(&mut tx, msg).await.unwrap();
        read_message(&mut rx).await.unwrap()
    }

    /// Feed hand-rolled `bytes` (at most 64) to the decoder, closing the
    /// writing end afterwards so a short frame surfaces as EOF instead of
    /// hanging.
    async fn decode_raw(bytes: &[u8]) -> io::Result<ShuffleMessage> {
        let (mut tx, mut rx) = duplex(64);
        tx.write_all(bytes).await.unwrap();
        drop(tx);
        read_message(&mut rx).await
    }

    #[tokio::test]
    async fn barrier_roundtrip() {
        let msg = ShuffleMessage::Barrier(CheckpointBarrier {
            checkpoint_id: 17,
            epoch: 42,
            flags: flags::FULL_SNAPSHOT,
        });
        assert_eq!(roundtrip(512, &msg).await, msg);
    }

    #[tokio::test]
    async fn vnode_data_roundtrip() {
        let batch = sample_batch();
        let sent = ShuffleMessage::VnodeData(42, batch.clone());
        match roundtrip(64 * 1024, &sent).await {
            ShuffleMessage::VnodeData(v, got) => {
                assert_eq!(v, 42);
                assert_eq!(got, batch);
            }
            other => panic!("expected VnodeData, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn hello_roundtrip() {
        let msg = ShuffleMessage::Hello(0xDEAD_BEEF);
        assert_eq!(roundtrip(64, &msg).await, msg);
    }

    #[tokio::test]
    async fn close_roundtrip() {
        let msg = ShuffleMessage::Close("graceful".into());
        assert_eq!(roundtrip(128, &msg).await, msg);
    }

    #[tokio::test]
    async fn sequenced_messages_preserve_order() {
        // Multiple messages on the same connection round-trip in FIFO
        // order — the property per-key ordering relies on.
        let (mut a, mut b) = duplex(64 * 1024);
        let sent = [
            ShuffleMessage::VnodeData(0, sample_batch()),
            ShuffleMessage::Barrier(CheckpointBarrier::new(1, 1)),
            ShuffleMessage::Hello(9),
        ];

        for msg in &sent {
            write_message(&mut a, msg).await.unwrap();
        }
        for expected in &sent {
            assert_eq!(&read_message(&mut b).await.unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn unknown_tag_returns_error() {
        // total_len = 1 (tag only), followed by a tag no variant uses.
        let err = decode_raw(&[0, 0, 0, 1, 0x7F]).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn oversize_frame_rejected() {
        // MAX + 2 total bytes means a payload of MAX + 1: one over the cap.
        let bogus_len = u32::try_from(MAX_PAYLOAD_BYTES)
            .expect("MAX_PAYLOAD_BYTES fits in u32")
            .saturating_add(2);
        let err = decode_raw(&bogus_len.to_be_bytes()).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[tokio::test]
    async fn zero_length_frame_rejected() {
        // A length of 0 leaves no room even for the tag byte.
        let err = decode_raw(&[0, 0, 0, 0]).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn truncated_payload_is_unexpected_eof() {
        // Header promises 8 payload bytes but the peer hangs up after 3.
        let err = decode_raw(&[0, 0, 0, 9, TAG_HELLO, 1, 2, 3])
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn wrong_fixed_size_payload_rejected() {
        // Hello must carry exactly 8 bytes; this frame carries 4.
        let err = decode_raw(&[0, 0, 0, 5, TAG_HELLO, 0, 0, 0, 0])
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);

        // Barrier must carry exactly 24 bytes; this frame carries 8.
        let err = decode_raw(&[0, 0, 0, 9, TAG_BARRIER, 0, 0, 0, 0, 0, 0, 0, 0])
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn short_vnode_data_rejected() {
        // Two payload bytes cannot hold the 4-byte vnode id.
        let err = decode_raw(&[0, 0, 0, 3, TAG_VNODE_DATA, 0, 0])
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn close_with_invalid_utf8_rejected() {
        // 0xFF can never appear in well-formed UTF-8.
        let err = decode_raw(&[0, 0, 0, 3, TAG_CLOSE, 0xFF, 0xFE])
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}