1#![allow(clippy::disallowed_types)] use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::{self, Debug, Formatter};
11use std::sync::{Arc, OnceLock};
12
13use arrow_array::RecordBatch;
14use arrow_schema::SchemaRef;
15use datafusion::error::Result;
16use datafusion::execution::{SendableRecordBatchStream, TaskContext};
17use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
18use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
21use datafusion_common::DataFusionError;
22use futures::stream::{self, StreamExt};
23use laminar_core::checkpoint::barrier::CheckpointBarrier;
24use laminar_core::shuffle::{BarrierTracker, ShufflePeerId, ShuffleReceiver, ShuffleSender};
25use laminar_core::state::{owned_vnodes, peer_owners, NodeId, VnodeRegistry};
26use tokio::sync::{mpsc, watch, Mutex as AsyncMutex};
27use tokio::task::JoinHandle;
28
/// Physical plan node that hash-partitions its input by vnode and exchanges rows
/// with peer nodes through the shuffle sender/receiver pair.
///
/// The node exposes one output partition per vnode in `owned` (see `properties`,
/// which is built with `Partitioning::UnknownPartitioning(owned.len())`).
pub struct ClusterRepartitionExec {
    /// Child plan whose batches are routed.
    input: Arc<dyn ExecutionPlan>,
    /// Indices into the input schema of the columns used to compute each row's
    /// vnode; validated against the schema width in `try_new`.
    hash_columns: Vec<usize>,
    /// Vnode registry consulted for the vnode count and for slicing batches by target.
    registry: Arc<VnodeRegistry>,
    /// Outbound side of the shuffle: remote slices and barriers are sent through it.
    sender: Arc<ShuffleSender>,
    /// Inbound side of the shuffle: drained by the dispatcher task.
    receiver: Arc<ShuffleReceiver>,
    /// Identity of the local node.
    self_id: NodeId,
    /// Peer ids derived from `peer_owners(&registry, self_id)`; barriers are fanned
    /// out to these peers.
    peers: Vec<ShufflePeerId>,
    /// Vnodes returned by `owned_vnodes(&registry, self_id)`; never empty once
    /// `try_new` has succeeded.
    owned: Vec<u32>,
    /// Maps each vnode in `owned` to its position in `owned`, which is also its
    /// output partition index.
    vnode_to_partition: HashMap<u32, usize>,
    /// Output schema; taken from the input in `try_new`.
    schema: SchemaRef,
    /// Cached plan properties (incremental emission, unbounded).
    properties: PlanProperties,
    /// Shared runtime state, created on the first `execute()` call and reused by
    /// later calls on the same instance.
    runtime: OnceLock<Arc<RuntimeState>>,
}
48
/// State shared by every output partition of one `ClusterRepartitionExec` instance.
struct RuntimeState {
    /// One slot per output partition. `execute(p)` takes the receiver out of slot
    /// `p`, so each partition can be consumed only once.
    receivers: AsyncMutex<Vec<Option<mpsc::Receiver<RecordBatch>>>>,
    /// Queue through which `inject_barrier` hands barriers to the router task.
    inject_barrier_tx: mpsc::Sender<CheckpointBarrier>,
    /// Receives the `checkpoint_id` of each barrier the router's tracker reports
    /// as aligned; starts at 0.
    aligned_epoch_watch: watch::Receiver<u64>,
    /// Handle of the task running `route_input_stream`. It is only stored here;
    /// nothing in this file awaits it.
    _router: JoinHandle<()>,
    /// Handle of the task running `dispatch_inbound`. It is only stored here;
    /// nothing in this file awaits it.
    _dispatcher: JoinHandle<()>,
}
58
59impl ClusterRepartitionExec {
60 #[must_use]
63 pub fn registry(&self) -> &Arc<VnodeRegistry> {
64 &self.registry
65 }
66
67 pub fn try_new(
72 input: Arc<dyn ExecutionPlan>,
73 hash_columns: Vec<usize>,
74 registry: Arc<VnodeRegistry>,
75 sender: Arc<ShuffleSender>,
76 receiver: Arc<ShuffleReceiver>,
77 self_id: NodeId,
78 ) -> Result<Self> {
79 let schema = input.schema();
80 if let Some(&bad) = hash_columns.iter().find(|&&i| i >= schema.fields().len()) {
81 return Err(DataFusionError::Plan(format!(
82 "ClusterRepartitionExec: hash column {bad} out of range (schema has {} fields)",
83 schema.fields().len()
84 )));
85 }
86
87 let owned = owned_vnodes(®istry, self_id);
88 if owned.is_empty() {
89 return Err(DataFusionError::Plan(format!(
90 "ClusterRepartitionExec: instance {:?} owns no vnodes",
91 self_id
92 )));
93 }
94 let vnode_to_partition = owned.iter().enumerate().map(|(i, &v)| (v, i)).collect();
95
96 let peers: Vec<ShufflePeerId> = peer_owners(®istry, self_id)
99 .iter()
100 .map(|n| n.0)
101 .collect();
102
103 let properties = PlanProperties::new(
104 EquivalenceProperties::new(Arc::clone(&schema)),
105 Partitioning::UnknownPartitioning(owned.len()),
108 EmissionType::Incremental,
109 Boundedness::Unbounded {
110 requires_infinite_memory: false,
111 },
112 );
113
114 Ok(Self {
115 input,
116 hash_columns,
117 registry,
118 sender,
119 receiver,
120 self_id,
121 peers,
122 owned,
123 vnode_to_partition,
124 schema,
125 properties,
126 runtime: OnceLock::new(),
127 })
128 }
129
130 pub fn inject_barrier(&self, barrier: CheckpointBarrier) -> Result<()> {
135 let Some(runtime) = self.runtime.get() else {
136 return Err(DataFusionError::Execution(
137 "ClusterRepartitionExec::inject_barrier before execute()".into(),
138 ));
139 };
140 runtime
145 .inject_barrier_tx
146 .try_send(barrier)
147 .map_err(|e| match e {
148 mpsc::error::TrySendError::Full(_) => DataFusionError::Execution(
149 "barrier inject channel full — router lagging".into(),
150 ),
151 mpsc::error::TrySendError::Closed(_) => {
152 DataFusionError::Execution("router task exited".into())
153 }
154 })
155 }
156
157 #[must_use]
160 pub fn aligned_epoch_watch(&self) -> Option<watch::Receiver<u64>> {
161 self.runtime.get().map(|rt| rt.aligned_epoch_watch.clone())
162 }
163
164 fn init_runtime(&self, context: &Arc<TaskContext>) -> Result<Arc<RuntimeState>> {
166 const BARRIER_QUEUE: usize = 256;
170 if let Some(existing) = self.runtime.get() {
171 return Ok(Arc::clone(existing));
172 }
173
174 let n_partitions = self.owned.len();
175 let mut partition_txs: Vec<mpsc::Sender<RecordBatch>> = Vec::with_capacity(n_partitions);
176 let mut receivers = Vec::with_capacity(n_partitions);
177 for _ in 0..n_partitions {
178 let (tx, rx) = mpsc::channel::<RecordBatch>(16);
179 partition_txs.push(tx);
180 receivers.push(Some(rx));
181 }
182
183 let input_stream = self.input.execute(0, Arc::clone(context))?;
184 let hash_columns = self.hash_columns.clone();
185 let registry = Arc::clone(&self.registry);
186 let self_id = self.self_id;
187 let sender = Arc::clone(&self.sender);
188 let receiver = Arc::clone(&self.receiver);
189 let vnode_to_partition = self.vnode_to_partition.clone();
190 let peers = self.peers.clone();
191
192 let (inject_tx, inject_rx) = mpsc::channel::<CheckpointBarrier>(BARRIER_QUEUE);
197 let (peer_tx, peer_rx) = mpsc::channel::<(ShufflePeerId, CheckpointBarrier)>(BARRIER_QUEUE);
198 let (aligned_tx, aligned_rx) = watch::channel::<u64>(0);
199
200 let router_txs = partition_txs.clone();
201 let router_registry = Arc::clone(®istry);
202 let router_vtp = vnode_to_partition.clone();
203 let router_sender = Arc::clone(&sender);
204 let router_peers = peers.clone();
205 let router = tokio::spawn(async move {
206 route_input_stream(
207 input_stream,
208 hash_columns,
209 router_registry,
210 self_id,
211 router_vtp,
212 router_sender,
213 router_txs,
214 router_peers,
215 inject_rx,
216 peer_rx,
217 aligned_tx,
218 )
219 .await;
220 });
221
222 let dispatcher_txs = partition_txs;
223 let dispatcher = tokio::spawn(async move {
224 dispatch_inbound(receiver, vnode_to_partition, dispatcher_txs, peer_tx).await;
225 });
226
227 let state = Arc::new(RuntimeState {
228 receivers: AsyncMutex::new(receivers),
229 inject_barrier_tx: inject_tx,
230 aligned_epoch_watch: aligned_rx,
231 _router: router,
232 _dispatcher: dispatcher,
233 });
234 Ok(Arc::clone(self.runtime.get_or_init(|| state)))
235 }
236}
237
238impl Debug for ClusterRepartitionExec {
239 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
240 f.debug_struct("ClusterRepartitionExec")
241 .field("self_id", &self.self_id)
242 .field("hash_columns", &self.hash_columns)
243 .field("owned_vnodes", &self.owned.len())
244 .finish_non_exhaustive()
245 }
246}
247
248impl DisplayAs for ClusterRepartitionExec {
249 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
250 match t {
251 DisplayFormatType::Default | DisplayFormatType::Verbose => write!(
252 f,
253 "ClusterRepartitionExec: owned_vnodes={}, hash_columns={:?}",
254 self.owned.len(),
255 self.hash_columns
256 ),
257 DisplayFormatType::TreeRender => write!(f, "ClusterRepartitionExec"),
258 }
259 }
260}
261
262impl ExecutionPlan for ClusterRepartitionExec {
263 fn name(&self) -> &'static str {
264 "ClusterRepartitionExec"
265 }
266
267 fn as_any(&self) -> &dyn Any {
268 self
269 }
270
271 fn schema(&self) -> SchemaRef {
272 Arc::clone(&self.schema)
273 }
274
275 fn properties(&self) -> &PlanProperties {
276 &self.properties
277 }
278
279 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
280 vec![&self.input]
281 }
282
283 fn with_new_children(
284 self: Arc<Self>,
285 mut children: Vec<Arc<dyn ExecutionPlan>>,
286 ) -> Result<Arc<dyn ExecutionPlan>> {
287 if children.len() != 1 {
288 return Err(DataFusionError::Plan(
289 "ClusterRepartitionExec requires exactly one child".into(),
290 ));
291 }
292 Ok(Arc::new(Self {
293 input: children.swap_remove(0),
294 hash_columns: self.hash_columns.clone(),
295 registry: Arc::clone(&self.registry),
296 sender: Arc::clone(&self.sender),
297 receiver: Arc::clone(&self.receiver),
298 self_id: self.self_id,
299 peers: self.peers.clone(),
300 owned: self.owned.clone(),
301 vnode_to_partition: self.vnode_to_partition.clone(),
302 schema: Arc::clone(&self.schema),
303 properties: self.properties.clone(),
304 runtime: OnceLock::new(),
306 }))
307 }
308
309 fn execute(
310 &self,
311 partition: usize,
312 context: Arc<TaskContext>,
313 ) -> Result<SendableRecordBatchStream> {
314 if partition >= self.owned.len() {
315 return Err(DataFusionError::Plan(format!(
316 "ClusterRepartitionExec: partition {partition} >= output partitions {}",
317 self.owned.len(),
318 )));
319 }
320
321 let runtime = self.init_runtime(&context)?;
322
323 let schema = Arc::clone(&self.schema);
329 let fut = async move {
330 let mut guard = runtime.receivers.lock().await;
331 guard[partition].take().ok_or_else(|| {
332 datafusion_common::DataFusionError::Execution(format!(
333 "ClusterRepartitionExec::execute called twice for \
334 partition {partition}; receivers are single-use"
335 ))
336 })
337 };
338 let stream = stream::once(fut).flat_map(move |maybe_rx| match maybe_rx {
339 Ok(rx) => tokio_stream::wrappers::ReceiverStream::new(rx)
340 .map(Ok)
341 .boxed(),
342 Err(e) => stream::once(async move { Err(e) }).boxed(),
343 });
344 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
345 }
346}
347
348#[allow(clippy::too_many_arguments)]
352async fn route_input_stream(
353 input: SendableRecordBatchStream,
354 hash_columns: Vec<usize>,
355 registry: Arc<VnodeRegistry>,
356 self_id: NodeId,
357 vnode_to_partition: HashMap<u32, usize>,
358 sender: Arc<ShuffleSender>,
359 partition_txs: Vec<mpsc::Sender<RecordBatch>>,
360 peers: Vec<ShufflePeerId>,
361 mut inject_rx: mpsc::Receiver<CheckpointBarrier>,
362 mut peer_rx: mpsc::Receiver<(ShufflePeerId, CheckpointBarrier)>,
363 aligned_tx: watch::Sender<u64>,
364) {
365 let vnode_count = registry.vnode_count();
366 let n_inputs = peers.len() + 1;
367 let tracker = BarrierTracker::new(n_inputs);
368 let peer_port: HashMap<ShufflePeerId, usize> =
369 peers.iter().enumerate().map(|(i, &p)| (p, i + 1)).collect();
370
371 let mut input: Option<SendableRecordBatchStream> = Some(input);
376 let mut partition_txs: Option<Vec<mpsc::Sender<RecordBatch>>> = Some(partition_txs);
377
378 loop {
379 tokio::select! {
380 biased;
381 Some(barrier) = inject_rx.recv() => {
383 let _ = sender.fan_out_barrier(&peers, barrier).await;
384 if let Some(aligned) = tracker.observe(0, barrier) {
385 let _ = aligned_tx.send(aligned.checkpoint_id);
386 }
387 }
388 Some((from, barrier)) = peer_rx.recv() => {
390 if let Some(&port) = peer_port.get(&from) {
391 if let Some(aligned) = tracker.observe(port, barrier) {
392 let _ = aligned_tx.send(aligned.checkpoint_id);
393 }
394 }
395 }
396 next = async {
399 match input.as_mut() {
400 Some(s) => s.next().await,
401 None => std::future::pending().await,
402 }
403 } => {
404 if let Some(Ok(batch)) = next {
405 if batch.num_rows() == 0 { continue; }
406 if partition_txs.is_none() { continue; }
407 let row_vn =
408 laminar_core::shuffle::row_vnodes(&batch, &hash_columns, vnode_count);
409 let (local_slices, remote_slices) =
410 laminar_core::shuffle::slice_batch_by_targets(&batch, &row_vn, ®istry, self_id);
411 let mut downstream_dropped = false;
412
413 for (v, slice) in local_slices {
414 if let Some(&idx) = vnode_to_partition.get(&v) {
415 let send_res = partition_txs.as_ref().unwrap()[idx]
416 .send(slice)
417 .await;
418 if send_res.is_err() {
419 downstream_dropped = true;
420 break;
421 }
422 }
423 }
424
425 if !downstream_dropped {
426 for (owner, slice) in remote_slices {
427 let vnode_col = slice
428 .column(slice.num_columns() - 1)
429 .as_any()
430 .downcast_ref::<arrow::array::UInt32Array>()
431 .expect("vnode col");
432 let row_vnodes = vnode_col.values().to_vec();
433 let sub_slices = laminar_core::shuffle::slice_batch_by_vnodes(&slice, &row_vnodes);
434 for (vnode_id, sub_slice) in sub_slices {
435 let schema = Arc::new(arrow_schema::Schema::new(
436 sub_slice.schema().fields()[..sub_slice.num_columns() - 1].to_vec(),
437 ));
438 let columns = sub_slice.columns()[..sub_slice.num_columns() - 1].to_vec();
439 let sub_slice_clean = RecordBatch::try_new(schema, columns).expect("clean");
440 let msg = laminar_core::shuffle::ShuffleMessage::VnodeData(
441 String::new(),
442 vnode_id,
443 sub_slice_clean,
444 );
445 let _ = sender.send_to(owner.0, &msg).await;
446 }
447 }
448 }
449
450 if downstream_dropped {
451 partition_txs = None;
452 input = None;
453 }
454 } else {
455 input = None;
458 partition_txs = None;
459 }
460 }
461 }
462 }
463}
464
465async fn dispatch_inbound(
468 receiver: Arc<ShuffleReceiver>,
469 vnode_to_partition: HashMap<u32, usize>,
470 partition_txs: Vec<mpsc::Sender<RecordBatch>>,
471 peer_barrier_tx: mpsc::Sender<(ShufflePeerId, CheckpointBarrier)>,
472) {
473 use laminar_core::shuffle::ShuffleMessage;
474 while let Some((from, msg)) = receiver.recv().await {
475 match msg {
476 ShuffleMessage::VnodeData(_stage, vnode, batch) => {
477 if batch.num_rows() == 0 {
478 continue;
479 }
480 let Some(&idx) = vnode_to_partition.get(&vnode) else {
481 continue;
484 };
485 if partition_txs[idx].send(batch).await.is_err() {
486 return; }
488 }
489 ShuffleMessage::Barrier(b) if peer_barrier_tx.send((from, b)).await.is_err() => return,
490 _ => {}
493 }
494 }
495}
496
497#[cfg(feature = "cluster")]
498use datafusion::physical_optimizer::PhysicalOptimizerRule;
499#[cfg(feature = "cluster")]
500use datafusion::physical_plan::joins::HashJoinExec;
501#[cfg(feature = "cluster")]
502use datafusion_common::config::ConfigOptions;
503
/// Process-wide cluster wiring read by `DistributedJoinRule`.
///
/// `None` until `set_cluster_context` is called; while it is `None` the rule
/// returns plans unchanged.
#[cfg(feature = "cluster")]
static CLUSTER_CONTEXT: parking_lot::RwLock<Option<ClusterContext>> =
    parking_lot::RwLock::new(None);
507
/// Handles needed to construct a `ClusterRepartitionExec`; stored in
/// `CLUSTER_CONTEXT` and cloned out by the optimizer rule.
#[cfg(feature = "cluster")]
#[derive(Clone)]
struct ClusterContext {
    /// Vnode registry passed to every repartition node the rule creates.
    registry: Arc<VnodeRegistry>,
    /// Shuffle sender shared by every repartition node the rule creates.
    sender: Arc<ShuffleSender>,
    /// Shuffle receiver shared by every repartition node the rule creates.
    receiver: Arc<ShuffleReceiver>,
    /// Identity of the local node.
    self_id: NodeId,
}
516
/// Installs the process-wide cluster context consumed by `DistributedJoinRule`,
/// replacing any context that was set earlier.
#[cfg(feature = "cluster")]
pub fn set_cluster_context(
    registry: Arc<VnodeRegistry>,
    sender: Arc<ShuffleSender>,
    receiver: Arc<ShuffleReceiver>,
    self_id: NodeId,
) {
    let fresh = ClusterContext {
        registry,
        sender,
        receiver,
        self_id,
    };
    let mut slot = CLUSTER_CONTEXT.write();
    *slot = Some(fresh);
}
532
/// Physical optimizer rule that wraps both inputs of every `HashJoinExec` whose
/// partition mode is not `CollectLeft` in a `ClusterRepartitionExec`.
///
/// The rule is a no-op until `set_cluster_context` has been called.
#[cfg(feature = "cluster")]
#[derive(Debug)]
pub struct DistributedJoinRule;
537
#[cfg(feature = "cluster")]
impl PhysicalOptimizerRule for DistributedJoinRule {
    /// Rewrites `plan` using the globally installed cluster context; returns the
    /// plan untouched when no context has been installed.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Clone the context out in its own statement so the read lock is released
        // before the (recursive) rewrite starts.
        let snapshot: Option<ClusterContext> = CLUSTER_CONTEXT.read().clone();
        match snapshot {
            None => Ok(plan),
            Some(ctx) => {
                optimize_plan(plan, &ctx.registry, &ctx.sender, &ctx.receiver, ctx.self_id)
            }
        }
    }

    /// Rule name reported to the optimizer.
    fn name(&self) -> &'static str {
        "DistributedJoinRule"
    }

    /// The rewrite keeps plan schemas intact, so schema checking stays enabled.
    fn schema_check(&self) -> bool {
        true
    }
}
560
/// Recursively rewrites `plan` bottom-up.
///
/// Children are rewritten first. A `HashJoinExec` in `CollectLeft` mode is rebuilt
/// over its rewritten children unchanged; any other `HashJoinExec` gets each
/// rewritten child wrapped in a `ClusterRepartitionExec` keyed on that side's join
/// columns. Every other node is rebuilt over its rewritten children (leaf nodes
/// are returned as-is).
///
/// NOTE(review): both wrappers of a join receive the same `sender`/`receiver`
/// handles — confirm that inbound shuffle traffic for the two sides cannot be
/// mixed up.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when a join key is not a plain `Column`
/// expression, and propagates errors from `ClusterRepartitionExec::try_new` and
/// `with_new_children`.
#[cfg(feature = "cluster")]
fn optimize_plan(
    plan: Arc<dyn ExecutionPlan>,
    registry: &Arc<VnodeRegistry>,
    sender: &Arc<ShuffleSender>,
    receiver: &Arc<ShuffleReceiver>,
    self_id: NodeId,
) -> Result<Arc<dyn ExecutionPlan>> {
    use datafusion::physical_expr::expressions::Column;
    use datafusion::physical_plan::joins::PartitionMode;

    // Rewrite the subtrees first; the first failure aborts the whole rewrite.
    let mut new_children: Vec<Arc<dyn ExecutionPlan>> = Vec::new();
    for child in plan.children() {
        let rewritten = optimize_plan(Arc::clone(child), registry, sender, receiver, self_id)?;
        new_children.push(rewritten);
    }

    if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
        let left = Arc::clone(&new_children[0]);
        let right = Arc::clone(&new_children[1]);

        // CollectLeft joins are left alone apart from their rewritten children.
        if matches!(hash_join.partition_mode(), PartitionMode::CollectLeft) {
            return Arc::clone(&plan).with_new_children(vec![left, right]);
        }

        // Translate the equi-join keys into column indices for each side.
        let mut left_keys = Vec::new();
        let mut right_keys = Vec::new();
        for (l_expr, r_expr) in hash_join.on() {
            let Some(l_col) = l_expr.as_any().downcast_ref::<Column>() else {
                return Err(DataFusionError::Internal(
                    "HashJoinExec: left key is not a Column expression".to_string(),
                ));
            };
            left_keys.push(l_col.index());

            let Some(r_col) = r_expr.as_any().downcast_ref::<Column>() else {
                return Err(DataFusionError::Internal(
                    "HashJoinExec: right key is not a Column expression".to_string(),
                ));
            };
            right_keys.push(r_col.index());
        }

        let wrap = |side: Arc<dyn ExecutionPlan>,
                    keys: Vec<usize>|
         -> Result<Arc<dyn ExecutionPlan>> {
            let exec = ClusterRepartitionExec::try_new(
                side,
                keys,
                Arc::clone(registry),
                Arc::clone(sender),
                Arc::clone(receiver),
                self_id,
            )?;
            Ok(Arc::new(exec))
        };
        let new_left = wrap(left, left_keys)?;
        let new_right = wrap(right, right_keys)?;

        return Arc::clone(&plan).with_new_children(vec![new_left, new_right]);
    }

    if new_children.is_empty() {
        return Ok(plan);
    }
    plan.with_new_children(new_children)
}