1use std::io;
11
12use arrow_array::RecordBatch;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14
15use crate::checkpoint::barrier::CheckpointBarrier;
16use crate::serialization::{deserialize_batch_stream, serialize_batch_stream};
17
/// Tag byte of a frame carrying a [`ShuffleMessage::Barrier`].
pub(crate) const TAG_BARRIER: u8 = 0x02;
/// Tag byte of a frame carrying a [`ShuffleMessage::Hello`].
pub(crate) const TAG_HELLO: u8 = 0x04;
/// Tag byte of a frame carrying a [`ShuffleMessage::VnodeData`].
pub(crate) const TAG_VNODE_DATA: u8 = 0x05;
/// Tag byte of a frame carrying a [`ShuffleMessage::Close`].
pub(crate) const TAG_CLOSE: u8 = 0xFF;
28
/// A single message carried by one shuffle frame, as encoded by
/// `write_message` and decoded by `read_message`.
#[derive(Debug, Clone, PartialEq)]
pub enum ShuffleMessage {
    /// A checkpoint barrier. On the wire its `checkpoint_id`, `epoch` and
    /// `flags` fields are written little-endian, in that order (24 bytes).
    Barrier(CheckpointBarrier),
    /// A hello frame carrying a `u64` (treated as a node id by the codec),
    /// written as 8 little-endian bytes.
    Hello(u64),
    /// A record batch tagged with a `u32` vnode. The payload is the vnode as
    /// 4 little-endian bytes followed by the serialized batch.
    VnodeData(u32, RecordBatch),
    /// A close frame; the payload is the UTF-8 bytes of the reason string.
    Close(String),
}
41
/// Largest payload, in bytes, accepted for a single frame (64 MiB).
///
/// The limit applies to the payload only, not to the 4-byte length prefix or
/// the tag byte. Both the writer and the reader in this file reject frames
/// whose payload exceeds it.
pub const MAX_PAYLOAD_BYTES: usize = 64 * 1024 * 1024;
45
46pub(crate) async fn write_message<W>(writer: &mut W, msg: &ShuffleMessage) -> io::Result<()>
52where
53 W: AsyncWriteExt + Unpin,
54{
55 let (tag, payload): (u8, Vec<u8>) = match msg {
56 ShuffleMessage::Barrier(b) => {
57 let mut buf = Vec::with_capacity(24);
58 buf.extend_from_slice(&b.checkpoint_id.to_le_bytes());
59 buf.extend_from_slice(&b.epoch.to_le_bytes());
60 buf.extend_from_slice(&b.flags.to_le_bytes());
61 (TAG_BARRIER, buf)
62 }
63 ShuffleMessage::Hello(node_id) => (TAG_HELLO, node_id.to_le_bytes().to_vec()),
64 ShuffleMessage::VnodeData(vnode, batch) => {
65 let ipc = serialize_batch_stream(batch)
66 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
67 let mut buf = Vec::with_capacity(4 + ipc.len());
68 buf.extend_from_slice(&vnode.to_le_bytes());
69 buf.extend_from_slice(&ipc);
70 (TAG_VNODE_DATA, buf)
71 }
72 ShuffleMessage::Close(reason) => (TAG_CLOSE, reason.as_bytes().to_vec()),
73 };
74
75 if payload.len() > MAX_PAYLOAD_BYTES {
76 return Err(io::Error::new(
77 io::ErrorKind::InvalidInput,
78 format!(
79 "shuffle payload {} exceeds MAX_PAYLOAD_BYTES ({MAX_PAYLOAD_BYTES})",
80 payload.len()
81 ),
82 ));
83 }
84
85 let total_len = u32::try_from(payload.len() + 1)
88 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "frame length overflow"))?;
89 writer.write_all(&total_len.to_be_bytes()).await?;
90 writer.write_all(&[tag]).await?;
91 writer.write_all(&payload).await?;
92 writer.flush().await?;
93 Ok(())
94}
95
96pub(crate) async fn read_message<R>(reader: &mut R) -> io::Result<ShuffleMessage>
107where
108 R: AsyncReadExt + Unpin,
109{
110 let mut len_buf = [0u8; 4];
111 reader.read_exact(&mut len_buf).await?;
112 let total_len = u32::from_be_bytes(len_buf) as usize;
113 if total_len == 0 {
114 return Err(io::Error::new(
115 io::ErrorKind::InvalidData,
116 "zero-length shuffle frame",
117 ));
118 }
119 let payload_len = total_len - 1;
120 if payload_len > MAX_PAYLOAD_BYTES {
121 return Err(io::Error::new(
122 io::ErrorKind::InvalidInput,
123 format!("shuffle frame {payload_len} exceeds MAX_PAYLOAD_BYTES"),
124 ));
125 }
126
127 let mut tag_buf = [0u8; 1];
128 reader.read_exact(&mut tag_buf).await?;
129 let tag = tag_buf[0];
130
131 let mut payload = vec![0u8; payload_len];
132 reader.read_exact(&mut payload).await?;
133
134 match tag {
135 TAG_BARRIER => {
136 if payload.len() != 24 {
137 return Err(io::Error::new(
138 io::ErrorKind::InvalidData,
139 format!("barrier payload {} bytes, expected 24", payload.len()),
140 ));
141 }
142 let checkpoint_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
143 let epoch = u64::from_le_bytes(payload[8..16].try_into().unwrap());
144 let flags = u64::from_le_bytes(payload[16..24].try_into().unwrap());
145 Ok(ShuffleMessage::Barrier(CheckpointBarrier {
146 checkpoint_id,
147 epoch,
148 flags,
149 }))
150 }
151 TAG_HELLO => {
152 if payload.len() != 8 {
153 return Err(io::Error::new(
154 io::ErrorKind::InvalidData,
155 format!("hello payload {} bytes, expected 8", payload.len()),
156 ));
157 }
158 let node_id = u64::from_le_bytes(payload[..].try_into().unwrap());
159 Ok(ShuffleMessage::Hello(node_id))
160 }
161 TAG_VNODE_DATA => {
162 if payload.len() < 4 {
163 return Err(io::Error::new(
164 io::ErrorKind::InvalidData,
165 format!("vnode-data payload {} bytes, expected ≥4", payload.len()),
166 ));
167 }
168 let vnode = u32::from_le_bytes(payload[..4].try_into().unwrap());
169 let batch = deserialize_batch_stream(&payload[4..])
170 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
171 Ok(ShuffleMessage::VnodeData(vnode, batch))
172 }
173 TAG_CLOSE => {
174 let reason = String::from_utf8(payload).map_err(|e| {
175 io::Error::new(io::ErrorKind::InvalidData, format!("close reason: {e}"))
176 })?;
177 Ok(ShuffleMessage::Close(reason))
178 }
179 other => Err(io::Error::new(
180 io::ErrorKind::InvalidData,
181 format!("unknown shuffle tag byte: {other:#04x}"),
182 )),
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use tokio::io::duplex;

    use crate::checkpoint::barrier::flags;

    /// Builds a three-row batch with a single non-nullable `Int64` column `x`.
    fn sample_batch() -> RecordBatch {
        let fields = vec![Field::new("x", DataType::Int64, false)];
        let values = Int64Array::from(vec![1i64, 2, 3]);
        RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(values)]).unwrap()
    }

    /// Writes `msg` into one end of an in-memory pipe of `capacity` bytes and
    /// reads one message back from the other end.
    async fn roundtrip(msg: &ShuffleMessage, capacity: usize) -> ShuffleMessage {
        let (mut tx, mut rx) = duplex(capacity);
        write_message(&mut tx, msg).await.unwrap();
        read_message(&mut rx).await.unwrap()
    }

    #[tokio::test]
    async fn barrier_roundtrip() {
        let sent = ShuffleMessage::Barrier(CheckpointBarrier {
            checkpoint_id: 17,
            epoch: 42,
            flags: flags::FULL_SNAPSHOT,
        });
        assert_eq!(roundtrip(&sent, 512).await, sent);
    }

    #[tokio::test]
    async fn vnode_data_roundtrip() {
        let batch = sample_batch();
        let sent = ShuffleMessage::VnodeData(42, batch.clone());
        match roundtrip(&sent, 64 * 1024).await {
            ShuffleMessage::VnodeData(vnode, decoded) => {
                assert_eq!(vnode, 42);
                assert_eq!(decoded, batch);
            }
            other => panic!("expected VnodeData, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn hello_roundtrip() {
        let sent = ShuffleMessage::Hello(0xDEAD_BEEF);
        assert_eq!(roundtrip(&sent, 64).await, sent);
    }

    #[tokio::test]
    async fn close_roundtrip() {
        let sent = ShuffleMessage::Close("graceful".into());
        assert_eq!(roundtrip(&sent, 128).await, sent);
    }

    /// Several frames written back-to-back must be read in the same order.
    #[tokio::test]
    async fn sequenced_messages_preserve_order() {
        let (mut tx, mut rx) = duplex(64 * 1024);
        let barrier = CheckpointBarrier::new(1, 1);

        let outgoing = [
            ShuffleMessage::VnodeData(0, sample_batch()),
            ShuffleMessage::Barrier(barrier),
            ShuffleMessage::Hello(9),
        ];
        for msg in &outgoing {
            write_message(&mut tx, msg).await.unwrap();
        }

        let first = read_message(&mut rx).await.unwrap();
        assert!(matches!(first, ShuffleMessage::VnodeData(0, _)));

        let second = read_message(&mut rx).await.unwrap();
        assert_eq!(second, ShuffleMessage::Barrier(barrier));

        let third = read_message(&mut rx).await.unwrap();
        assert_eq!(third, ShuffleMessage::Hello(9));
    }

    #[tokio::test]
    async fn unknown_tag_returns_error() {
        let (mut tx, mut rx) = duplex(64);
        // Frame length 1 (tag only, empty payload) followed by an unassigned tag.
        tx.write_all(&[0u8, 0, 0, 1]).await.unwrap();
        tx.write_all(&[0x7Fu8]).await.unwrap();
        drop(tx);

        let failure = read_message(&mut rx).await.unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn oversize_frame_rejected() {
        let (mut tx, mut rx) = duplex(64);
        // Declared length of MAX + 2 means a payload of MAX + 1 once the tag
        // byte is subtracted, which is one byte over the limit.
        let declared = u32::try_from(MAX_PAYLOAD_BYTES)
            .expect("MAX_PAYLOAD_BYTES fits in u32")
            .saturating_add(2);
        tx.write_all(&declared.to_be_bytes()).await.unwrap();
        drop(tx);

        let failure = read_message(&mut rx).await.unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::InvalidInput);
    }
}