Skip to main content

laminar_sql/datafusion/
cluster_repartition.rs

1//! Cross-instance hash repartition. Replaces DataFusion's in-process
2//! `RepartitionExec::Hash` between `AggregateExec::Partial` and
3//! `AggregateExec::FinalPartitioned`. One output partition per owned
4//! vnode; remote rows ship via `ShuffleSender`.
5
6#![allow(clippy::disallowed_types)] // this is a cluster-only ExecutionPlan; not hot-path
7
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::{self, Debug, Formatter};
11use std::sync::{Arc, OnceLock};
12
13use arrow_array::RecordBatch;
14use arrow_schema::SchemaRef;
15use datafusion::error::Result;
16use datafusion::execution::{SendableRecordBatchStream, TaskContext};
17use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
18use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
21use datafusion_common::DataFusionError;
22use futures::stream::{self, StreamExt};
23use laminar_core::checkpoint::barrier::CheckpointBarrier;
24use laminar_core::shuffle::{BarrierTracker, ShufflePeerId, ShuffleReceiver, ShuffleSender};
25use laminar_core::state::{owned_vnodes, peer_owners, NodeId, VnodeRegistry};
26use tokio::sync::{mpsc, watch, Mutex as AsyncMutex};
27use tokio::task::JoinHandle;
28
/// Cross-instance hash repartition.
///
/// Batches read from `input` are split by vnode (derived from
/// `hash_columns`); slices for locally owned vnodes are delivered to this
/// node's output partitions, the rest are shipped to their owners through
/// `sender`. Frames arriving from peers on `receiver` are merged into the
/// same output partitions.
pub struct ClusterRepartitionExec {
    /// Upstream plan whose batches are routed.
    input: Arc<dyn ExecutionPlan>,
    /// Column indices hashed to pick a vnode. Only plain columns today;
    /// computed hash keys would need `PhysicalExpr` evaluation.
    hash_columns: Vec<usize>,
    /// Vnode registry consulted for the vnode count and vnode ownership.
    registry: Arc<VnodeRegistry>,
    /// Outbound side of the shuffle: remote data slices and barrier fan-out.
    sender: Arc<ShuffleSender>,
    /// Inbound side of the shuffle, drained by the dispatcher task.
    receiver: Arc<ShuffleReceiver>,
    /// Identity of this node, used to tell local vnodes from remote ones.
    self_id: NodeId,
    /// Peers we fan barriers out to; frozen at construction.
    peers: Vec<ShufflePeerId>,
    /// One output partition per owned vnode; the position of a vnode in
    /// this vector is its output partition index.
    owned: Vec<u32>,
    /// Reverse lookup of `owned`: vnode id -> output partition index.
    vnode_to_partition: HashMap<u32, usize>,
    /// Output schema, captured from `input` at construction.
    schema: SchemaRef,
    /// Cached plan properties returned by `properties()`.
    properties: PlanProperties,
    /// Channels and task handles created by the first `execute()` call;
    /// empty until then.
    runtime: OnceLock<Arc<RuntimeState>>,
}
48
/// Per-execution state shared by every output partition of one
/// `ClusterRepartitionExec`: the partition channels plus the handles of the
/// background tasks feeding them.
struct RuntimeState {
    /// One receiver per output partition; claimed once by `execute(p)`.
    receivers: AsyncMutex<Vec<Option<mpsc::Receiver<RecordBatch>>>>,
    /// Sending half of the external barrier trigger channel read by the
    /// router task.
    inject_barrier_tx: mpsc::Sender<CheckpointBarrier>,
    /// Carries the aligned checkpoint id for downstream wrappers.
    aligned_epoch_watch: watch::Receiver<u64>,
    /// Handle of the router task. Only retained here; nothing in this file
    /// awaits or aborts it.
    _router: JoinHandle<()>,
    /// Handle of the inbound dispatcher task. Only retained here; nothing
    /// in this file awaits or aborts it.
    _dispatcher: JoinHandle<()>,
}
58
59impl ClusterRepartitionExec {
60    /// Handle on the vnode registry. Callers read `assignment_version()`
61    /// to stamp state writes for the split-brain fence.
62    #[must_use]
63    pub fn registry(&self) -> &Arc<VnodeRegistry> {
64        &self.registry
65    }
66
67    /// Construct the exec.
68    ///
69    /// # Errors
70    /// Out-of-range hash columns or zero owned vnodes.
71    pub fn try_new(
72        input: Arc<dyn ExecutionPlan>,
73        hash_columns: Vec<usize>,
74        registry: Arc<VnodeRegistry>,
75        sender: Arc<ShuffleSender>,
76        receiver: Arc<ShuffleReceiver>,
77        self_id: NodeId,
78    ) -> Result<Self> {
79        let schema = input.schema();
80        if let Some(&bad) = hash_columns.iter().find(|&&i| i >= schema.fields().len()) {
81            return Err(DataFusionError::Plan(format!(
82                "ClusterRepartitionExec: hash column {bad} out of range (schema has {} fields)",
83                schema.fields().len()
84            )));
85        }
86
87        let owned = owned_vnodes(&registry, self_id);
88        if owned.is_empty() {
89            return Err(DataFusionError::Plan(format!(
90                "ClusterRepartitionExec: instance {:?} owns no vnodes",
91                self_id
92            )));
93        }
94        let vnode_to_partition = owned.iter().enumerate().map(|(i, &v)| (v, i)).collect();
95
96        // Distinct peers from the registry (everyone owning a vnode but us),
97        // frozen at construction — dynamic membership is deferred work.
98        let peers: Vec<ShufflePeerId> = peer_owners(&registry, self_id)
99            .iter()
100            .map(|n| n.0)
101            .collect();
102
103        let properties = PlanProperties::new(
104            EquivalenceProperties::new(Arc::clone(&schema)),
105            // We advertise hash partitioning. Satisfies
106            // `FinalPartitioned`'s required-input-distribution.
107            Partitioning::UnknownPartitioning(owned.len()),
108            EmissionType::Incremental,
109            Boundedness::Unbounded {
110                requires_infinite_memory: false,
111            },
112        );
113
114        Ok(Self {
115            input,
116            hash_columns,
117            registry,
118            sender,
119            receiver,
120            self_id,
121            peers,
122            owned,
123            vnode_to_partition,
124            schema,
125            properties,
126            runtime: OnceLock::new(),
127        })
128    }
129
130    /// Inject a checkpoint barrier, fanning out to peers.
131    ///
132    /// # Errors
133    /// Before `execute()` or after the router exits.
134    pub fn inject_barrier(&self, barrier: CheckpointBarrier) -> Result<()> {
135        let Some(runtime) = self.runtime.get() else {
136            return Err(DataFusionError::Execution(
137                "ClusterRepartitionExec::inject_barrier before execute()".into(),
138            ));
139        };
140        // Bounded control channel; barriers arrive at checkpoint cadence
141        // so the queue stays shallow. `try_send` keeps the sync signature
142        // and surfaces backpressure as a hard error rather than silently
143        // queueing arbitrary depth.
144        runtime
145            .inject_barrier_tx
146            .try_send(barrier)
147            .map_err(|e| match e {
148                mpsc::error::TrySendError::Full(_) => DataFusionError::Execution(
149                    "barrier inject channel full — router lagging".into(),
150                ),
151                mpsc::error::TrySendError::Closed(_) => {
152                    DataFusionError::Execution("router task exited".into())
153                }
154            })
155    }
156
157    /// `None` before `execute()`; otherwise a watch that fires each
158    /// time the tracker aligns, carrying the checkpoint id.
159    #[must_use]
160    pub fn aligned_epoch_watch(&self) -> Option<watch::Receiver<u64>> {
161        self.runtime.get().map(|rt| rt.aligned_epoch_watch.clone())
162    }
163
164    /// Spawn the router + dispatcher tasks on the first `execute()`.
165    fn init_runtime(&self, context: &Arc<TaskContext>) -> Result<Arc<RuntimeState>> {
166        // 256 is generous for what is at most "barriers in flight per
167        // checkpoint cadence × peers"; chosen to avoid coupling test
168        // tunables to per-deployment cluster sizes.
169        const BARRIER_QUEUE: usize = 256;
170        if let Some(existing) = self.runtime.get() {
171            return Ok(Arc::clone(existing));
172        }
173
174        let n_partitions = self.owned.len();
175        let mut partition_txs: Vec<mpsc::Sender<RecordBatch>> = Vec::with_capacity(n_partitions);
176        let mut receivers = Vec::with_capacity(n_partitions);
177        for _ in 0..n_partitions {
178            let (tx, rx) = mpsc::channel::<RecordBatch>(16);
179            partition_txs.push(tx);
180            receivers.push(Some(rx));
181        }
182
183        let input_stream = self.input.execute(0, Arc::clone(context))?;
184        let hash_columns = self.hash_columns.clone();
185        let registry = Arc::clone(&self.registry);
186        let self_id = self.self_id;
187        let sender = Arc::clone(&self.sender);
188        let receiver = Arc::clone(&self.receiver);
189        let vnode_to_partition = self.vnode_to_partition.clone();
190        let peers = self.peers.clone();
191
192        // Barrier plumbing:
193        //   - inject_barrier_(tx|rx): external trigger → router
194        //   - peer_barrier_(tx|rx):   dispatcher → router (after peer gossip)
195        //   - aligned_(tx|rx):        router → subscribers (watch channel)
196        let (inject_tx, inject_rx) = mpsc::channel::<CheckpointBarrier>(BARRIER_QUEUE);
197        let (peer_tx, peer_rx) = mpsc::channel::<(ShufflePeerId, CheckpointBarrier)>(BARRIER_QUEUE);
198        let (aligned_tx, aligned_rx) = watch::channel::<u64>(0);
199
200        let router_txs = partition_txs.clone();
201        let router_registry = Arc::clone(&registry);
202        let router_vtp = vnode_to_partition.clone();
203        let router_sender = Arc::clone(&sender);
204        let router_peers = peers.clone();
205        let router = tokio::spawn(async move {
206            route_input_stream(
207                input_stream,
208                hash_columns,
209                router_registry,
210                self_id,
211                router_vtp,
212                router_sender,
213                router_txs,
214                router_peers,
215                inject_rx,
216                peer_rx,
217                aligned_tx,
218            )
219            .await;
220        });
221
222        let dispatcher_txs = partition_txs;
223        let dispatcher = tokio::spawn(async move {
224            dispatch_inbound(receiver, vnode_to_partition, dispatcher_txs, peer_tx).await;
225        });
226
227        let state = Arc::new(RuntimeState {
228            receivers: AsyncMutex::new(receivers),
229            inject_barrier_tx: inject_tx,
230            aligned_epoch_watch: aligned_rx,
231            _router: router,
232            _dispatcher: dispatcher,
233        });
234        Ok(Arc::clone(self.runtime.get_or_init(|| state)))
235    }
236}
237
238impl Debug for ClusterRepartitionExec {
239    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
240        f.debug_struct("ClusterRepartitionExec")
241            .field("self_id", &self.self_id)
242            .field("hash_columns", &self.hash_columns)
243            .field("owned_vnodes", &self.owned.len())
244            .finish_non_exhaustive()
245    }
246}
247
248impl DisplayAs for ClusterRepartitionExec {
249    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
250        match t {
251            DisplayFormatType::Default | DisplayFormatType::Verbose => write!(
252                f,
253                "ClusterRepartitionExec: owned_vnodes={}, hash_columns={:?}",
254                self.owned.len(),
255                self.hash_columns
256            ),
257            DisplayFormatType::TreeRender => write!(f, "ClusterRepartitionExec"),
258        }
259    }
260}
261
262impl ExecutionPlan for ClusterRepartitionExec {
263    fn name(&self) -> &'static str {
264        "ClusterRepartitionExec"
265    }
266
267    fn as_any(&self) -> &dyn Any {
268        self
269    }
270
271    fn schema(&self) -> SchemaRef {
272        Arc::clone(&self.schema)
273    }
274
275    fn properties(&self) -> &PlanProperties {
276        &self.properties
277    }
278
279    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
280        vec![&self.input]
281    }
282
283    fn with_new_children(
284        self: Arc<Self>,
285        mut children: Vec<Arc<dyn ExecutionPlan>>,
286    ) -> Result<Arc<dyn ExecutionPlan>> {
287        if children.len() != 1 {
288            return Err(DataFusionError::Plan(
289                "ClusterRepartitionExec requires exactly one child".into(),
290            ));
291        }
292        Ok(Arc::new(Self {
293            input: children.swap_remove(0),
294            hash_columns: self.hash_columns.clone(),
295            registry: Arc::clone(&self.registry),
296            sender: Arc::clone(&self.sender),
297            receiver: Arc::clone(&self.receiver),
298            self_id: self.self_id,
299            peers: self.peers.clone(),
300            owned: self.owned.clone(),
301            vnode_to_partition: self.vnode_to_partition.clone(),
302            schema: Arc::clone(&self.schema),
303            properties: self.properties.clone(),
304            // Fresh runtime on re-plan.
305            runtime: OnceLock::new(),
306        }))
307    }
308
309    fn execute(
310        &self,
311        partition: usize,
312        context: Arc<TaskContext>,
313    ) -> Result<SendableRecordBatchStream> {
314        if partition >= self.owned.len() {
315            return Err(DataFusionError::Plan(format!(
316                "ClusterRepartitionExec: partition {partition} >= output partitions {}",
317                self.owned.len(),
318            )));
319        }
320
321        let runtime = self.init_runtime(&context)?;
322
323        // Claim this partition's receiver. DataFusion contracts execute
324        // each partition exactly once per plan execution; a second call
325        // would have previously returned an empty stream silently,
326        // masking plan-reuse bugs. Now we return a typed error so the
327        // caller sees something went wrong.
328        let schema = Arc::clone(&self.schema);
329        let fut = async move {
330            let mut guard = runtime.receivers.lock().await;
331            guard[partition].take().ok_or_else(|| {
332                datafusion_common::DataFusionError::Execution(format!(
333                    "ClusterRepartitionExec::execute called twice for \
334                         partition {partition}; receivers are single-use"
335                ))
336            })
337        };
338        let stream = stream::once(fut).flat_map(move |maybe_rx| match maybe_rx {
339            Ok(rx) => tokio_stream::wrappers::ReceiverStream::new(rx)
340                .map(Ok)
341                .boxed(),
342            Err(e) => stream::once(async move { Err(e) }).boxed(),
343        });
344        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
345    }
346}
347
348/// Drains data batches and barriers (local + peer), routes data to
349/// local partitions or peer senders, and publishes aligned-epoch ids
350/// on `aligned_tx`.
351#[allow(clippy::too_many_arguments)]
352async fn route_input_stream(
353    input: SendableRecordBatchStream,
354    hash_columns: Vec<usize>,
355    registry: Arc<VnodeRegistry>,
356    self_id: NodeId,
357    vnode_to_partition: HashMap<u32, usize>,
358    sender: Arc<ShuffleSender>,
359    partition_txs: Vec<mpsc::Sender<RecordBatch>>,
360    peers: Vec<ShufflePeerId>,
361    mut inject_rx: mpsc::Receiver<CheckpointBarrier>,
362    mut peer_rx: mpsc::Receiver<(ShufflePeerId, CheckpointBarrier)>,
363    aligned_tx: watch::Sender<u64>,
364) {
365    let vnode_count = registry.vnode_count();
366    let n_inputs = peers.len() + 1;
367    let tracker = BarrierTracker::new(n_inputs);
368    let peer_port: HashMap<ShufflePeerId, usize> =
369        peers.iter().enumerate().map(|(i, &p)| (p, i + 1)).collect();
370
371    // Once the local input stream ends we drop it so the select arm
372    // becomes `pending` forever. Partition senders are also dropped
373    // at input EOS so the downstream output streams terminate cleanly
374    // — the router stays alive purely for barrier coordination.
375    let mut input: Option<SendableRecordBatchStream> = Some(input);
376    let mut partition_txs: Option<Vec<mpsc::Sender<RecordBatch>>> = Some(partition_txs);
377
378    loop {
379        tokio::select! {
380            biased;
381            // 1. External barrier trigger.
382            Some(barrier) = inject_rx.recv() => {
383                let _ = sender.fan_out_barrier(&peers, barrier).await;
384                if let Some(aligned) = tracker.observe(0, barrier) {
385                    let _ = aligned_tx.send(aligned.checkpoint_id);
386                }
387            }
388            // 2. Peer barrier forwarded from the dispatcher.
389            Some((from, barrier)) = peer_rx.recv() => {
390                if let Some(&port) = peer_port.get(&from) {
391                    if let Some(aligned) = tracker.observe(port, barrier) {
392                        let _ = aligned_tx.send(aligned.checkpoint_id);
393                    }
394                }
395            }
396            // 3. Data from the local input. `pending` forever once the
397            // input is exhausted so this arm drops out of the select.
398            next = async {
399                match input.as_mut() {
400                    Some(s) => s.next().await,
401                    None => std::future::pending().await,
402                }
403            } => {
404                if let Some(Ok(batch)) = next {
405                    if batch.num_rows() == 0 { continue; }
406                    if partition_txs.is_none() { continue; }
407                    let row_vn =
408                        laminar_core::shuffle::row_vnodes(&batch, &hash_columns, vnode_count);
409                    let (local_slices, remote_slices) =
410                        laminar_core::shuffle::slice_batch_by_targets(&batch, &row_vn, &registry, self_id);
411                    let mut downstream_dropped = false;
412
413                    for (v, slice) in local_slices {
414                        if let Some(&idx) = vnode_to_partition.get(&v) {
415                            let send_res = partition_txs.as_ref().unwrap()[idx]
416                                .send(slice)
417                                .await;
418                            if send_res.is_err() {
419                                downstream_dropped = true;
420                                break;
421                            }
422                        }
423                    }
424
425                    if !downstream_dropped {
426                        for (owner, slice) in remote_slices {
427                            let vnode_col = slice
428                                .column(slice.num_columns() - 1)
429                                .as_any()
430                                .downcast_ref::<arrow::array::UInt32Array>()
431                                .expect("vnode col");
432                            let row_vnodes = vnode_col.values().to_vec();
433                            let sub_slices = laminar_core::shuffle::slice_batch_by_vnodes(&slice, &row_vnodes);
434                            for (vnode_id, sub_slice) in sub_slices {
435                                let schema = Arc::new(arrow_schema::Schema::new(
436                                    sub_slice.schema().fields()[..sub_slice.num_columns() - 1].to_vec(),
437                                ));
438                                let columns = sub_slice.columns()[..sub_slice.num_columns() - 1].to_vec();
439                                let sub_slice_clean = RecordBatch::try_new(schema, columns).expect("clean");
440                                let msg = laminar_core::shuffle::ShuffleMessage::VnodeData(
441                                    String::new(),
442                                    vnode_id,
443                                    sub_slice_clean,
444                                );
445                                let _ = sender.send_to(owner.0, &msg).await;
446                            }
447                        }
448                    }
449
450                    if downstream_dropped {
451                        partition_txs = None;
452                        input = None;
453                    }
454                } else {
455                    // EOS or error: drop senders so output streams
456                    // terminate; keep router alive for barriers.
457                    input = None;
458                    partition_txs = None;
459                }
460            }
461        }
462    }
463}
464
465/// Consumes inbound shuffle frames: `VnodeData` to local partitions,
466/// `Barrier` to the router's peer-barrier channel.
467async fn dispatch_inbound(
468    receiver: Arc<ShuffleReceiver>,
469    vnode_to_partition: HashMap<u32, usize>,
470    partition_txs: Vec<mpsc::Sender<RecordBatch>>,
471    peer_barrier_tx: mpsc::Sender<(ShufflePeerId, CheckpointBarrier)>,
472) {
473    use laminar_core::shuffle::ShuffleMessage;
474    while let Some((from, msg)) = receiver.recv().await {
475        match msg {
476            ShuffleMessage::VnodeData(_stage, vnode, batch) => {
477                if batch.num_rows() == 0 {
478                    continue;
479                }
480                let Some(&idx) = vnode_to_partition.get(&vnode) else {
481                    // Sender mis-routed to a vnode we don't own — drop
482                    // (a future observability hook can count these).
483                    continue;
484                };
485                if partition_txs[idx].send(batch).await.is_err() {
486                    return; // downstream dropped
487                }
488            }
489            ShuffleMessage::Barrier(b) if peer_barrier_tx.send((from, b)).await.is_err() => return,
490            // Hello is consumed by per_peer_loop; Close closes the
491            // reader. Barriers that succeed fall through here.
492            _ => {}
493        }
494    }
495}
496
497#[cfg(feature = "cluster")]
498use datafusion::physical_optimizer::PhysicalOptimizerRule;
499#[cfg(feature = "cluster")]
500use datafusion::physical_plan::joins::HashJoinExec;
501#[cfg(feature = "cluster")]
502use datafusion_common::config::ConfigOptions;
503
#[cfg(feature = "cluster")]
/// Process-wide cluster wiring consulted by `DistributedJoinRule::optimize`.
/// `None` until `set_cluster_context` stores a value; each later call
/// overwrites the previous one.
static CLUSTER_CONTEXT: parking_lot::RwLock<Option<ClusterContext>> =
    parking_lot::RwLock::new(None);
507
#[cfg(feature = "cluster")]
/// The handles needed to build a `ClusterRepartitionExec`, bundled so the
/// optimizer rule can obtain them from global state. Cloning only bumps
/// the `Arc` reference counts.
#[derive(Clone)]
struct ClusterContext {
    /// Vnode ownership registry.
    registry: Arc<VnodeRegistry>,
    /// Outbound shuffle handle.
    sender: Arc<ShuffleSender>,
    /// Inbound shuffle handle.
    receiver: Arc<ShuffleReceiver>,
    /// Identity of this node.
    self_id: NodeId,
}
516
#[cfg(feature = "cluster")]
/// Install the process-wide cluster context used by the distributed
/// physical optimizer rules, replacing any context stored earlier.
pub fn set_cluster_context(
    registry: Arc<VnodeRegistry>,
    sender: Arc<ShuffleSender>,
    receiver: Arc<ShuffleReceiver>,
    self_id: NodeId,
) {
    let context = ClusterContext {
        registry,
        sender,
        receiver,
        self_id,
    };
    let mut slot = CLUSTER_CONTEXT.write();
    *slot = Some(context);
}
532
#[cfg(feature = "cluster")]
#[derive(Debug)]
/// Physical optimizer rule that wraps HashJoinExec inputs in ClusterRepartitionExec.
///
/// The rule leaves the plan untouched while no cluster context has been
/// installed via `set_cluster_context`, and it does not wrap the inputs of
/// joins whose partition mode is `CollectLeft`.
pub struct DistributedJoinRule;
537
#[cfg(feature = "cluster")]
impl PhysicalOptimizerRule for DistributedJoinRule {
    /// Rewrite `plan` with the cluster context currently installed; return
    /// it unchanged when none has been set.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Snapshot the context inside its own scope so the read lock is
        // released before the plan tree is walked.
        let snapshot: Option<ClusterContext> = {
            let guard = CLUSTER_CONTEXT.read();
            (*guard).clone()
        };
        match snapshot {
            None => Ok(plan),
            Some(ctx) => {
                optimize_plan(plan, &ctx.registry, &ctx.sender, &ctx.receiver, ctx.self_id)
            }
        }
    }

    fn name(&self) -> &'static str {
        "DistributedJoinRule"
    }

    /// Returns `true`, asking the optimizer to verify that the plan schema
    /// is unchanged by this rule.
    fn schema_check(&self) -> bool {
        true
    }
}
560
#[cfg(feature = "cluster")]
/// Bottom-up rewrite: children are rewritten first, then every
/// `HashJoinExec` whose partition mode is not `CollectLeft` gets each of
/// its two inputs wrapped in a `ClusterRepartitionExec` keyed on that
/// side's join columns. All other nodes are rebuilt over their rewritten
/// children (leaves are returned as-is).
///
/// Returns an `Internal` error when a join key is not a plain `Column`,
/// and propagates errors from `ClusterRepartitionExec::try_new`.
///
/// NOTE(review): both wrappers of a join receive the same `receiver`
/// handle — confirm that inbound frames for the two sides cannot be mixed.
fn optimize_plan(
    plan: Arc<dyn ExecutionPlan>,
    registry: &Arc<VnodeRegistry>,
    sender: &Arc<ShuffleSender>,
    receiver: &Arc<ShuffleReceiver>,
    self_id: NodeId,
) -> Result<Arc<dyn ExecutionPlan>> {
    use datafusion::physical_expr::expressions::Column;
    use datafusion::physical_plan::joins::PartitionMode;

    // Rewrite the subtree below this node, stopping at the first error.
    let mut new_children: Vec<Arc<dyn ExecutionPlan>> = Vec::new();
    for child in plan.children() {
        let rewritten = optimize_plan(Arc::clone(child), registry, sender, receiver, self_id)?;
        new_children.push(rewritten);
    }

    if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
        let left = Arc::clone(&new_children[0]);
        let right = Arc::clone(&new_children[1]);

        // CollectLeft joins keep their (rewritten) inputs unwrapped.
        if matches!(hash_join.partition_mode(), PartitionMode::CollectLeft) {
            return Arc::clone(&plan).with_new_children(vec![left, right]);
        }

        // Only plain column references can be hashed by the repartition
        // exec, so every join key must downcast to `Column`.
        let mut left_keys = Vec::new();
        let mut right_keys = Vec::new();
        for (l_col, r_col) in hash_join.on() {
            let Some(l) = l_col.as_any().downcast_ref::<Column>() else {
                return Err(DataFusionError::Internal(
                    "HashJoinExec: left key is not a Column expression".to_string(),
                ));
            };
            left_keys.push(l.index());

            let Some(r) = r_col.as_any().downcast_ref::<Column>() else {
                return Err(DataFusionError::Internal(
                    "HashJoinExec: right key is not a Column expression".to_string(),
                ));
            };
            right_keys.push(r.index());
        }

        let wrap = |side: Arc<dyn ExecutionPlan>,
                    keys: Vec<usize>|
         -> Result<Arc<dyn ExecutionPlan>> {
            let exec = ClusterRepartitionExec::try_new(
                side,
                keys,
                Arc::clone(registry),
                Arc::clone(sender),
                Arc::clone(receiver),
                self_id,
            )?;
            Ok(Arc::new(exec))
        };
        let new_left = wrap(left, left_keys)?;
        let new_right = wrap(right, right_keys)?;

        return Arc::clone(&plan).with_new_children(vec![new_left, new_right]);
    }

    if new_children.is_empty() {
        return Ok(plan);
    }
    plan.with_new_children(new_children)
}
633        Ok(plan)
634    } else {
635        plan.with_new_children(new_children)
636    }
637}