1#![allow(clippy::disallowed_types)] use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::{self, Debug, Formatter};
11use std::sync::{Arc, OnceLock};
12
13use arrow::compute::take;
14use arrow_array::{RecordBatch, UInt32Array};
15use arrow_schema::SchemaRef;
16use datafusion::error::Result;
17use datafusion::execution::{SendableRecordBatchStream, TaskContext};
18use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
19use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
22use datafusion_common::DataFusionError;
23use futures::stream::{self, StreamExt};
24use laminar_core::checkpoint::barrier::CheckpointBarrier;
25use laminar_core::shuffle::{BarrierTracker, ShufflePeerId, ShuffleReceiver, ShuffleSender};
26use laminar_core::state::{owned_vnodes, NodeId, VnodeRegistry};
27use tokio::sync::{mpsc, watch, Mutex as AsyncMutex};
28use tokio::task::JoinHandle;
29
/// DataFusion operator that repartitions its input across a cluster by
/// hashing `hash_columns` to vnodes: vnodes owned by this node become local
/// output partitions, rows for remote vnodes are shuffled to the owning peer.
pub struct ClusterRepartitionExec {
    input: Arc<dyn ExecutionPlan>,
    /// Input-schema column indices hashed to choose each row's vnode.
    hash_columns: Vec<usize>,
    /// Shared vnode -> owner mapping consulted when routing rows.
    registry: Arc<VnodeRegistry>,
    /// Outbound shuffle transport (vnode data and barriers) to peers.
    sender: Arc<ShuffleSender>,
    /// Inbound shuffle transport from peers.
    receiver: Arc<ShuffleReceiver>,
    self_id: NodeId,
    /// Distinct peer ids owning at least one vnode (sorted, built from a BTreeSet).
    peers: Vec<ShufflePeerId>,
    /// Vnodes owned by `self_id`; output partition `i` carries `owned[i]`.
    owned: Vec<u32>,
    /// Inverse of `owned`: vnode id -> output partition index.
    vnode_to_partition: HashMap<u32, usize>,
    schema: SchemaRef,
    properties: PlanProperties,
    /// Background tasks + channels, lazily created on first `execute()`.
    runtime: OnceLock<Arc<RuntimeState>>,
}
49
/// Shared per-execution state, created once on the first `execute()` call.
struct RuntimeState {
    /// One single-use receiver per output partition; `execute()` takes it,
    /// leaving `None` so a second call for the same partition errors.
    receivers: AsyncMutex<Vec<Option<mpsc::Receiver<RecordBatch>>>>,
    /// Feeds locally injected checkpoint barriers into the router (port 0).
    inject_barrier_tx: mpsc::UnboundedSender<CheckpointBarrier>,
    /// Publishes the most recently aligned checkpoint id (initialized to 0).
    aligned_epoch_watch: watch::Receiver<u64>,
    // Handles to the spawned background tasks. Note: dropping a JoinHandle
    // detaches the task rather than aborting it.
    _router: JoinHandle<()>,
    _dispatcher: JoinHandle<()>,
}
59
impl ClusterRepartitionExec {
    /// Shared vnode registry used to resolve vnode ownership.
    #[must_use]
    pub fn registry(&self) -> &Arc<VnodeRegistry> {
        &self.registry
    }

    /// Builds the operator, validating `hash_columns` against the input schema
    /// and precomputing this node's owned vnodes, partition map, and peer set.
    ///
    /// # Errors
    /// Returns a planning error if a hash column index is out of range, or if
    /// `self_id` owns no vnodes (the operator would have zero output partitions).
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        hash_columns: Vec<usize>,
        registry: Arc<VnodeRegistry>,
        sender: Arc<ShuffleSender>,
        receiver: Arc<ShuffleReceiver>,
        self_id: NodeId,
    ) -> Result<Self> {
        let schema = input.schema();
        // Reject any hash-column index outside the input schema.
        if let Some(&bad) = hash_columns.iter().find(|&&i| i >= schema.fields().len()) {
            return Err(DataFusionError::Plan(format!(
                "ClusterRepartitionExec: hash column {bad} out of range (schema has {} fields)",
                schema.fields().len()
            )));
        }

        let owned = owned_vnodes(&registry, self_id);
        if owned.is_empty() {
            return Err(DataFusionError::Plan(format!(
                "ClusterRepartitionExec: instance {:?} owns no vnodes",
                self_id
            )));
        }
        // Output partition i serves vnode owned[i].
        let vnode_to_partition = owned.iter().enumerate().map(|(i, &v)| (v, i)).collect();

        // Collect every distinct assigned owner other than ourselves; the
        // BTreeSet gives a deterministic, sorted peer ordering.
        let mut peer_set: std::collections::BTreeSet<u64> = std::collections::BTreeSet::new();
        for v in 0..registry.vnode_count() {
            let o = registry.owner(v);
            if !o.is_unassigned() && o != self_id {
                peer_set.insert(o.0);
            }
        }
        let peers: Vec<ShufflePeerId> = peer_set.into_iter().collect();

        // One output partition per owned vnode; the stream is unbounded and
        // emits incrementally (streaming operator).
        let properties = PlanProperties::new(
            EquivalenceProperties::new(Arc::clone(&schema)),
            Partitioning::UnknownPartitioning(owned.len()),
            EmissionType::Incremental,
            Boundedness::Unbounded {
                requires_infinite_memory: false,
            },
        );

        Ok(Self {
            input,
            hash_columns,
            registry,
            sender,
            receiver,
            self_id,
            peers,
            owned,
            vnode_to_partition,
            schema,
            properties,
            runtime: OnceLock::new(),
        })
    }

    /// Injects a checkpoint barrier on the local input port (port 0).
    ///
    /// The router fans the barrier out to all peers and records it in its own
    /// alignment tracker.
    ///
    /// # Errors
    /// Fails if called before `execute()` (runtime not initialized) or if the
    /// router task has already exited.
    pub fn inject_barrier(&self, barrier: CheckpointBarrier) -> Result<()> {
        let Some(runtime) = self.runtime.get() else {
            return Err(DataFusionError::Execution(
                "ClusterRepartitionExec::inject_barrier before execute()".into(),
            ));
        };
        runtime
            .inject_barrier_tx
            .send(barrier)
            .map_err(|_| DataFusionError::Execution("router task exited".into()))
    }

    /// Watch channel carrying the latest aligned checkpoint id, or `None` if
    /// `execute()` has not run yet.
    #[must_use]
    pub fn aligned_epoch_watch(&self) -> Option<watch::Receiver<u64>> {
        self.runtime.get().map(|rt| rt.aligned_epoch_watch.clone())
    }

    /// Lazily builds the shared runtime on first use: per-partition channels,
    /// the router task (local input + barrier alignment), and the dispatcher
    /// task (inbound shuffle traffic).
    ///
    /// NOTE(review): `get()` followed by `get_or_init()` is not atomic. Two
    /// concurrent first calls could both execute the input and spawn task
    /// pairs, with one RuntimeState discarded (its tasks detached). This is
    /// presumably safe because DataFusion drives `execute()` for all
    /// partitions from one thread — confirm before relying on concurrent
    /// execution.
    fn init_runtime(&self, context: &Arc<TaskContext>) -> Result<Arc<RuntimeState>> {
        if let Some(existing) = self.runtime.get() {
            return Ok(Arc::clone(existing));
        }

        // One bounded channel per owned vnode / output partition.
        let n_partitions = self.owned.len();
        let mut partition_txs: Vec<mpsc::Sender<RecordBatch>> = Vec::with_capacity(n_partitions);
        let mut receivers = Vec::with_capacity(n_partitions);
        for _ in 0..n_partitions {
            let (tx, rx) = mpsc::channel::<RecordBatch>(16);
            partition_txs.push(tx);
            receivers.push(Some(rx));
        }

        // The child is always executed as its partition 0.
        // NOTE(review): assumes the input has a single partition — confirm.
        let input_stream = self.input.execute(0, Arc::clone(context))?;
        let hash_columns = self.hash_columns.clone();
        let registry = Arc::clone(&self.registry);
        let self_id = self.self_id;
        let sender = Arc::clone(&self.sender);
        let receiver = Arc::clone(&self.receiver);
        let vnode_to_partition = self.vnode_to_partition.clone();
        let peers = self.peers.clone();

        // Control channels: locally injected barriers, peer barriers relayed
        // by the dispatcher, and the aligned-epoch broadcast.
        let (inject_tx, inject_rx) = mpsc::unbounded_channel::<CheckpointBarrier>();
        let (peer_tx, peer_rx) = mpsc::unbounded_channel::<(ShufflePeerId, CheckpointBarrier)>();
        let (aligned_tx, aligned_rx) = watch::channel::<u64>(0);

        let router_txs = partition_txs.clone();
        let router_registry = Arc::clone(&registry);
        let router_vtp = vnode_to_partition.clone();
        let router_sender = Arc::clone(&sender);
        let router_peers = peers.clone();
        let router = tokio::spawn(async move {
            route_input_stream(
                input_stream,
                hash_columns,
                router_registry,
                self_id,
                router_vtp,
                router_sender,
                router_txs,
                router_peers,
                inject_rx,
                peer_rx,
                aligned_tx,
            )
            .await;
        });

        let dispatcher_txs = partition_txs;
        let dispatcher = tokio::spawn(async move {
            dispatch_inbound(receiver, vnode_to_partition, dispatcher_txs, peer_tx).await;
        });

        let state = Arc::new(RuntimeState {
            receivers: AsyncMutex::new(receivers),
            inject_barrier_tx: inject_tx,
            aligned_epoch_watch: aligned_rx,
            _router: router,
            _dispatcher: dispatcher,
        });
        Ok(Arc::clone(self.runtime.get_or_init(|| state)))
    }
}
228
229impl Debug for ClusterRepartitionExec {
230 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
231 f.debug_struct("ClusterRepartitionExec")
232 .field("self_id", &self.self_id)
233 .field("hash_columns", &self.hash_columns)
234 .field("owned_vnodes", &self.owned.len())
235 .finish_non_exhaustive()
236 }
237}
238
239impl DisplayAs for ClusterRepartitionExec {
240 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
241 match t {
242 DisplayFormatType::Default | DisplayFormatType::Verbose => write!(
243 f,
244 "ClusterRepartitionExec: owned_vnodes={}, hash_columns={:?}",
245 self.owned.len(),
246 self.hash_columns
247 ),
248 DisplayFormatType::TreeRender => write!(f, "ClusterRepartitionExec"),
249 }
250 }
251}
252
impl ExecutionPlan for ClusterRepartitionExec {
    fn name(&self) -> &'static str {
        "ClusterRepartitionExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuilds the node around a new (single) child. The runtime is reset to
    /// an empty `OnceLock`, so the new instance re-initializes its background
    /// tasks on its next `execute()`.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "ClusterRepartitionExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            hash_columns: self.hash_columns.clone(),
            registry: Arc::clone(&self.registry),
            sender: Arc::clone(&self.sender),
            receiver: Arc::clone(&self.receiver),
            self_id: self.self_id,
            peers: self.peers.clone(),
            owned: self.owned.clone(),
            vnode_to_partition: self.vnode_to_partition.clone(),
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            runtime: OnceLock::new(),
        }))
    }

    /// Returns the output stream for `partition` (one per owned vnode).
    ///
    /// The first call initializes the shared runtime (router + dispatcher
    /// tasks). Each partition's receiver is single-use: calling `execute`
    /// twice for the same partition yields a stream that emits one error.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        if partition >= self.owned.len() {
            return Err(DataFusionError::Plan(format!(
                "ClusterRepartitionExec: partition {partition} >= output partitions {}",
                self.owned.len(),
            )));
        }

        let runtime = self.init_runtime(&context)?;

        let schema = Arc::clone(&self.schema);
        // Take the partition's receiver lazily, when the stream is first
        // polled, leaving `None` behind to detect double-execution.
        let fut = async move {
            let mut guard = runtime.receivers.lock().await;
            guard[partition].take().ok_or_else(|| {
                datafusion_common::DataFusionError::Execution(format!(
                    "ClusterRepartitionExec::execute called twice for \
                     partition {partition}; receivers are single-use"
                ))
            })
        };
        // Flatten into either the receiver's batches (wrapped in Ok) or a
        // single-item error stream.
        let stream = stream::once(fut).flat_map(move |maybe_rx| match maybe_rx {
            Ok(rx) => tokio_stream::wrappers::ReceiverStream::new(rx)
                .map(Ok)
                .boxed(),
            Err(e) => stream::once(async move { Err(e) }).boxed(),
        });
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
    }
}
338
339fn row_vnodes(batch: &RecordBatch, hash_columns: &[usize], vnode_count: u32) -> Vec<u32> {
342 use arrow::row::{RowConverter, SortField};
343 use laminar_core::state::key_hash;
344
345 let cols: Vec<_> = hash_columns
346 .iter()
347 .map(|&i| Arc::clone(batch.column(i)))
348 .collect();
349 let fields: Vec<_> = cols
350 .iter()
351 .map(|c| SortField::new(c.data_type().clone()))
352 .collect();
353 let converter = RowConverter::new(fields).expect("row converter");
354 let rows = converter.convert_columns(&cols).expect("convert rows");
355
356 (0..batch.num_rows())
357 .map(|row| {
358 #[allow(clippy::cast_possible_truncation)]
359 let v = (key_hash(rows.row(row).as_ref()) % u64::from(vnode_count)) as u32;
360 v
361 })
362 .collect()
363}
364
365fn slice_for_vnode(batch: &RecordBatch, vnodes: &[u32], target_vnode: u32) -> Option<RecordBatch> {
366 let indices: UInt32Array = vnodes
367 .iter()
368 .enumerate()
369 .filter_map(|(i, &v)| {
370 if v == target_vnode {
371 u32::try_from(i).ok()
372 } else {
373 None
374 }
375 })
376 .collect();
377 if indices.is_empty() {
378 return None;
379 }
380 let new_cols = batch
381 .columns()
382 .iter()
383 .map(|c| take(c, &indices, None).expect("take"))
384 .collect::<Vec<_>>();
385 Some(RecordBatch::try_new(batch.schema(), new_cols).expect("rebuild"))
386}
387
/// Router task for the local input stream.
///
/// Each input batch is hashed row-by-row to a vnode; the slice for a vnode
/// owned by `self_id` goes to the matching local partition channel, slices
/// for peer-owned vnodes are shuffled to the owner, and slices for unassigned
/// vnodes are dropped. The task also aligns checkpoint barriers across
/// `peers.len() + 1` input ports (port 0 = locally injected barriers, ports
/// 1.. = peers) and publishes each aligned checkpoint id on `aligned_tx`.
/// The loop never returns; it keeps aligning barriers even after the data
/// path shuts down.
#[allow(clippy::too_many_arguments)]
async fn route_input_stream(
    input: SendableRecordBatchStream,
    hash_columns: Vec<usize>,
    registry: Arc<VnodeRegistry>,
    self_id: NodeId,
    vnode_to_partition: HashMap<u32, usize>,
    sender: Arc<ShuffleSender>,
    partition_txs: Vec<mpsc::Sender<RecordBatch>>,
    peers: Vec<ShufflePeerId>,
    mut inject_rx: mpsc::UnboundedReceiver<CheckpointBarrier>,
    mut peer_rx: mpsc::UnboundedReceiver<(ShufflePeerId, CheckpointBarrier)>,
    aligned_tx: watch::Sender<u64>,
) {
    let vnode_count = registry.vnode_count();
    // Port 0 is the local injection port; peer i maps to port i + 1.
    let n_inputs = peers.len() + 1;
    let tracker = BarrierTracker::new(n_inputs);
    let peer_port: HashMap<ShufflePeerId, usize> =
        peers.iter().enumerate().map(|(i, &p)| (p, i + 1)).collect();

    // Wrapped in Options so the data path can be torn down (input exhausted
    // or a downstream consumer dropped) while barrier handling continues.
    let mut input: Option<SendableRecordBatchStream> = Some(input);
    let mut partition_txs: Option<Vec<mpsc::Sender<RecordBatch>>> = Some(partition_txs);

    loop {
        tokio::select! {
            // `biased` polls arms top-down: barriers take priority over data.
            biased;
            Some(barrier) = inject_rx.recv() => {
                // Fan the barrier out to all peers (best-effort: errors are
                // ignored), then record it on the local port.
                let _ = sender.fan_out_barrier(&peers, barrier).await;
                if let Some(aligned) = tracker.observe(0, barrier) {
                    let _ = aligned_tx.send(aligned.checkpoint_id);
                }
            }
            Some((from, barrier)) = peer_rx.recv() => {
                // Barriers from unknown senders are silently ignored.
                if let Some(&port) = peer_port.get(&from) {
                    if let Some(aligned) = tracker.observe(port, barrier) {
                        let _ = aligned_tx.send(aligned.checkpoint_id);
                    }
                }
            }
            // Poll the input only while it still exists; pending() makes this
            // arm inert once the stream has been dropped.
            next = async {
                match input.as_mut() {
                    Some(s) => s.next().await,
                    None => std::future::pending().await,
                }
            } => {
                if let Some(Ok(batch)) = next {
                    if batch.num_rows() == 0 { continue; }
                    if partition_txs.is_none() { continue; }
                    let row_vn = row_vnodes(&batch, &hash_columns, vnode_count);
                    // Distinct vnodes present in this batch.
                    let mut seen = row_vn.clone();
                    seen.sort_unstable();
                    seen.dedup();
                    let mut downstream_dropped = false;
                    for v in seen {
                        let Some(slice) = slice_for_vnode(&batch, &row_vn, v) else {
                            continue;
                        };
                        let owner = registry.owner(v);
                        if owner == self_id {
                            // Locally owned vnode: hand the slice to its
                            // output partition channel.
                            if let Some(&idx) = vnode_to_partition.get(&v) {
                                let send_res = partition_txs.as_ref().unwrap()[idx]
                                    .send(slice)
                                    .await;
                                if send_res.is_err() {
                                    downstream_dropped = true;
                                    break;
                                }
                            }
                        } else if !owner.is_unassigned() {
                            // Peer-owned vnode: shuffle the slice to the owner.
                            // NOTE(review): send errors are dropped here —
                            // confirm delivery guarantees are handled at the
                            // transport layer.
                            let msg = laminar_core::shuffle::ShuffleMessage::VnodeData(
                                v, slice,
                            );
                            let _ = sender.send_to(owner.0, &msg).await;
                        }
                    }
                    if downstream_dropped {
                        // A consumer is gone: stop routing data, keep barriers.
                        partition_txs = None;
                        input = None;
                    }
                } else {
                    // Input exhausted (or yielded an error): drop the data
                    // path so the partition channels close, but keep aligning
                    // barriers.
                    input = None;
                    partition_txs = None;
                }
            }
        }
    }
}
489
490async fn dispatch_inbound(
493 receiver: Arc<ShuffleReceiver>,
494 vnode_to_partition: HashMap<u32, usize>,
495 partition_txs: Vec<mpsc::Sender<RecordBatch>>,
496 peer_barrier_tx: mpsc::UnboundedSender<(ShufflePeerId, CheckpointBarrier)>,
497) {
498 use laminar_core::shuffle::ShuffleMessage;
499 while let Some((from, msg)) = receiver.recv().await {
500 match msg {
501 ShuffleMessage::VnodeData(vnode, batch) => {
502 if batch.num_rows() == 0 {
503 continue;
504 }
505 let Some(&idx) = vnode_to_partition.get(&vnode) else {
506 continue;
509 };
510 if partition_txs[idx].send(batch).await.is_err() {
511 return; }
513 }
514 ShuffleMessage::Barrier(b) if peer_barrier_tx.send((from, b)).is_err() => return,
515 _ => {}
518 }
519 }
520}