Skip to main content

laminar_sql/datafusion/
cluster_repartition.rs

//! Cross-instance hash repartition. Replaces DataFusion's in-process
//! `RepartitionExec::Hash` between `AggregateExec::Partial` and
//! `AggregateExec::FinalPartitioned`. One output partition per owned
//! vnode; remote rows ship via `ShuffleSender`.

6#![allow(clippy::disallowed_types)] // this is a cluster-only ExecutionPlan; not hot-path
7
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::{self, Debug, Formatter};
11use std::sync::{Arc, OnceLock};
12
13use arrow::compute::take;
14use arrow_array::{RecordBatch, UInt32Array};
15use arrow_schema::SchemaRef;
16use datafusion::error::Result;
17use datafusion::execution::{SendableRecordBatchStream, TaskContext};
18use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
19use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
22use datafusion_common::DataFusionError;
23use futures::stream::{self, StreamExt};
24use laminar_core::checkpoint::barrier::CheckpointBarrier;
25use laminar_core::shuffle::{BarrierTracker, ShufflePeerId, ShuffleReceiver, ShuffleSender};
26use laminar_core::state::{owned_vnodes, NodeId, VnodeRegistry};
27use tokio::sync::{mpsc, watch, Mutex as AsyncMutex};
28use tokio::task::JoinHandle;
29
/// Cross-instance hash repartition.
///
/// Rows are hashed to a vnode; vnodes owned by this instance become
/// output partitions, rows for remote vnodes are shipped through the
/// `ShuffleSender`. Background tasks are spawned lazily on the first
/// `execute()` (see `runtime`).
pub struct ClusterRepartitionExec {
    /// Child plan; always executed as its partition 0.
    input: Arc<dyn ExecutionPlan>,
    /// Column indices hashed to pick a vnode. Only plain columns today;
    /// computed hash keys would need `PhysicalExpr` evaluation.
    hash_columns: Vec<usize>,
    /// Vnode → owner mapping; also supplies `vnode_count()`.
    registry: Arc<VnodeRegistry>,
    /// Outbound transport for remote-owned rows and barrier fan-out.
    sender: Arc<ShuffleSender>,
    /// Inbound transport drained by the dispatcher task.
    receiver: Arc<ShuffleReceiver>,
    /// This instance's node id; used to detect locally owned vnodes.
    self_id: NodeId,
    /// Peers we fan barriers out to; frozen at construction.
    peers: Vec<ShufflePeerId>,
    /// One output partition per owned vnode.
    owned: Vec<u32>,
    /// Inverse of `owned`: vnode id → output partition index.
    vnode_to_partition: HashMap<u32, usize>,
    schema: SchemaRef,
    properties: PlanProperties,
    /// Router/dispatcher tasks and per-partition receivers; created
    /// once, on the first `execute()` call.
    runtime: OnceLock<Arc<RuntimeState>>,
}
49
/// Per-execution state created lazily by `init_runtime`.
struct RuntimeState {
    /// One receiver per output partition; claimed once by `execute(p)`.
    receivers: AsyncMutex<Vec<Option<mpsc::Receiver<RecordBatch>>>>,
    /// External barrier trigger → router task.
    inject_barrier_tx: mpsc::UnboundedSender<CheckpointBarrier>,
    /// Carries the aligned checkpoint id for downstream wrappers.
    aligned_epoch_watch: watch::Receiver<u64>,
    // Handles kept for ownership/diagnostics only; they are never
    // awaited, and dropping a JoinHandle detaches rather than cancels.
    _router: JoinHandle<()>,
    _dispatcher: JoinHandle<()>,
}
59
60impl ClusterRepartitionExec {
61    /// Handle on the vnode registry. Callers read `assignment_version()`
62    /// to stamp state writes for the split-brain fence.
63    #[must_use]
64    pub fn registry(&self) -> &Arc<VnodeRegistry> {
65        &self.registry
66    }
67
68    /// Construct the exec.
69    ///
70    /// # Errors
71    /// Out-of-range hash columns or zero owned vnodes.
72    pub fn try_new(
73        input: Arc<dyn ExecutionPlan>,
74        hash_columns: Vec<usize>,
75        registry: Arc<VnodeRegistry>,
76        sender: Arc<ShuffleSender>,
77        receiver: Arc<ShuffleReceiver>,
78        self_id: NodeId,
79    ) -> Result<Self> {
80        let schema = input.schema();
81        if let Some(&bad) = hash_columns.iter().find(|&&i| i >= schema.fields().len()) {
82            return Err(DataFusionError::Plan(format!(
83                "ClusterRepartitionExec: hash column {bad} out of range (schema has {} fields)",
84                schema.fields().len()
85            )));
86        }
87
88        let owned = owned_vnodes(&registry, self_id);
89        if owned.is_empty() {
90            return Err(DataFusionError::Plan(format!(
91                "ClusterRepartitionExec: instance {:?} owns no vnodes",
92                self_id
93            )));
94        }
95        let vnode_to_partition = owned.iter().enumerate().map(|(i, &v)| (v, i)).collect();
96
97        // Enumerate distinct peers from the registry (everyone who
98        // owns at least one vnode and isn't us). Frozen at construction
99        // — dynamic membership is deferred work.
100        let mut peer_set: std::collections::BTreeSet<u64> = std::collections::BTreeSet::new();
101        for v in 0..registry.vnode_count() {
102            let o = registry.owner(v);
103            if !o.is_unassigned() && o != self_id {
104                peer_set.insert(o.0);
105            }
106        }
107        let peers: Vec<ShufflePeerId> = peer_set.into_iter().collect();
108
109        let properties = PlanProperties::new(
110            EquivalenceProperties::new(Arc::clone(&schema)),
111            // We advertise hash partitioning. Satisfies
112            // `FinalPartitioned`'s required-input-distribution.
113            Partitioning::UnknownPartitioning(owned.len()),
114            EmissionType::Incremental,
115            Boundedness::Unbounded {
116                requires_infinite_memory: false,
117            },
118        );
119
120        Ok(Self {
121            input,
122            hash_columns,
123            registry,
124            sender,
125            receiver,
126            self_id,
127            peers,
128            owned,
129            vnode_to_partition,
130            schema,
131            properties,
132            runtime: OnceLock::new(),
133        })
134    }
135
136    /// Inject a checkpoint barrier, fanning out to peers.
137    ///
138    /// # Errors
139    /// Before `execute()` or after the router exits.
140    pub fn inject_barrier(&self, barrier: CheckpointBarrier) -> Result<()> {
141        let Some(runtime) = self.runtime.get() else {
142            return Err(DataFusionError::Execution(
143                "ClusterRepartitionExec::inject_barrier before execute()".into(),
144            ));
145        };
146        runtime
147            .inject_barrier_tx
148            .send(barrier)
149            .map_err(|_| DataFusionError::Execution("router task exited".into()))
150    }
151
152    /// `None` before `execute()`; otherwise a watch that fires each
153    /// time the tracker aligns, carrying the checkpoint id.
154    #[must_use]
155    pub fn aligned_epoch_watch(&self) -> Option<watch::Receiver<u64>> {
156        self.runtime.get().map(|rt| rt.aligned_epoch_watch.clone())
157    }
158
159    /// Spawn the router + dispatcher tasks on the first `execute()`.
160    fn init_runtime(&self, context: &Arc<TaskContext>) -> Result<Arc<RuntimeState>> {
161        if let Some(existing) = self.runtime.get() {
162            return Ok(Arc::clone(existing));
163        }
164
165        let n_partitions = self.owned.len();
166        let mut partition_txs: Vec<mpsc::Sender<RecordBatch>> = Vec::with_capacity(n_partitions);
167        let mut receivers = Vec::with_capacity(n_partitions);
168        for _ in 0..n_partitions {
169            let (tx, rx) = mpsc::channel::<RecordBatch>(16);
170            partition_txs.push(tx);
171            receivers.push(Some(rx));
172        }
173
174        let input_stream = self.input.execute(0, Arc::clone(context))?;
175        let hash_columns = self.hash_columns.clone();
176        let registry = Arc::clone(&self.registry);
177        let self_id = self.self_id;
178        let sender = Arc::clone(&self.sender);
179        let receiver = Arc::clone(&self.receiver);
180        let vnode_to_partition = self.vnode_to_partition.clone();
181        let peers = self.peers.clone();
182
183        // Barrier plumbing:
184        //   - inject_barrier_(tx|rx): external trigger → router
185        //   - peer_barrier_(tx|rx):   dispatcher → router (after peer gossip)
186        //   - aligned_(tx|rx):        router → subscribers (watch channel)
187        let (inject_tx, inject_rx) = mpsc::unbounded_channel::<CheckpointBarrier>();
188        let (peer_tx, peer_rx) = mpsc::unbounded_channel::<(ShufflePeerId, CheckpointBarrier)>();
189        let (aligned_tx, aligned_rx) = watch::channel::<u64>(0);
190
191        let router_txs = partition_txs.clone();
192        let router_registry = Arc::clone(&registry);
193        let router_vtp = vnode_to_partition.clone();
194        let router_sender = Arc::clone(&sender);
195        let router_peers = peers.clone();
196        let router = tokio::spawn(async move {
197            route_input_stream(
198                input_stream,
199                hash_columns,
200                router_registry,
201                self_id,
202                router_vtp,
203                router_sender,
204                router_txs,
205                router_peers,
206                inject_rx,
207                peer_rx,
208                aligned_tx,
209            )
210            .await;
211        });
212
213        let dispatcher_txs = partition_txs;
214        let dispatcher = tokio::spawn(async move {
215            dispatch_inbound(receiver, vnode_to_partition, dispatcher_txs, peer_tx).await;
216        });
217
218        let state = Arc::new(RuntimeState {
219            receivers: AsyncMutex::new(receivers),
220            inject_barrier_tx: inject_tx,
221            aligned_epoch_watch: aligned_rx,
222            _router: router,
223            _dispatcher: dispatcher,
224        });
225        Ok(Arc::clone(self.runtime.get_or_init(|| state)))
226    }
227}
228
229impl Debug for ClusterRepartitionExec {
230    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
231        f.debug_struct("ClusterRepartitionExec")
232            .field("self_id", &self.self_id)
233            .field("hash_columns", &self.hash_columns)
234            .field("owned_vnodes", &self.owned.len())
235            .finish_non_exhaustive()
236    }
237}
238
239impl DisplayAs for ClusterRepartitionExec {
240    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
241        match t {
242            DisplayFormatType::Default | DisplayFormatType::Verbose => write!(
243                f,
244                "ClusterRepartitionExec: owned_vnodes={}, hash_columns={:?}",
245                self.owned.len(),
246                self.hash_columns
247            ),
248            DisplayFormatType::TreeRender => write!(f, "ClusterRepartitionExec"),
249        }
250    }
251}
252
253impl ExecutionPlan for ClusterRepartitionExec {
254    fn name(&self) -> &'static str {
255        "ClusterRepartitionExec"
256    }
257
258    fn as_any(&self) -> &dyn Any {
259        self
260    }
261
262    fn schema(&self) -> SchemaRef {
263        Arc::clone(&self.schema)
264    }
265
266    fn properties(&self) -> &PlanProperties {
267        &self.properties
268    }
269
270    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
271        vec![&self.input]
272    }
273
274    fn with_new_children(
275        self: Arc<Self>,
276        mut children: Vec<Arc<dyn ExecutionPlan>>,
277    ) -> Result<Arc<dyn ExecutionPlan>> {
278        if children.len() != 1 {
279            return Err(DataFusionError::Plan(
280                "ClusterRepartitionExec requires exactly one child".into(),
281            ));
282        }
283        Ok(Arc::new(Self {
284            input: children.swap_remove(0),
285            hash_columns: self.hash_columns.clone(),
286            registry: Arc::clone(&self.registry),
287            sender: Arc::clone(&self.sender),
288            receiver: Arc::clone(&self.receiver),
289            self_id: self.self_id,
290            peers: self.peers.clone(),
291            owned: self.owned.clone(),
292            vnode_to_partition: self.vnode_to_partition.clone(),
293            schema: Arc::clone(&self.schema),
294            properties: self.properties.clone(),
295            // Fresh runtime on re-plan.
296            runtime: OnceLock::new(),
297        }))
298    }
299
300    fn execute(
301        &self,
302        partition: usize,
303        context: Arc<TaskContext>,
304    ) -> Result<SendableRecordBatchStream> {
305        if partition >= self.owned.len() {
306            return Err(DataFusionError::Plan(format!(
307                "ClusterRepartitionExec: partition {partition} >= output partitions {}",
308                self.owned.len(),
309            )));
310        }
311
312        let runtime = self.init_runtime(&context)?;
313
314        // Claim this partition's receiver. DataFusion contracts execute
315        // each partition exactly once per plan execution; a second call
316        // would have previously returned an empty stream silently,
317        // masking plan-reuse bugs. Now we return a typed error so the
318        // caller sees something went wrong.
319        let schema = Arc::clone(&self.schema);
320        let fut = async move {
321            let mut guard = runtime.receivers.lock().await;
322            guard[partition].take().ok_or_else(|| {
323                datafusion_common::DataFusionError::Execution(format!(
324                    "ClusterRepartitionExec::execute called twice for \
325                         partition {partition}; receivers are single-use"
326                ))
327            })
328        };
329        let stream = stream::once(fut).flat_map(move |maybe_rx| match maybe_rx {
330            Ok(rx) => tokio_stream::wrappers::ReceiverStream::new(rx)
331                .map(Ok)
332                .boxed(),
333            Err(e) => stream::once(async move { Err(e) }).boxed(),
334        });
335        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
336    }
337}
338
339/// Per-row vnode, hashed via `arrow::row::RowConverter` on the key
340/// columns (works for any Arrow type).
341fn row_vnodes(batch: &RecordBatch, hash_columns: &[usize], vnode_count: u32) -> Vec<u32> {
342    use arrow::row::{RowConverter, SortField};
343    use laminar_core::state::key_hash;
344
345    let cols: Vec<_> = hash_columns
346        .iter()
347        .map(|&i| Arc::clone(batch.column(i)))
348        .collect();
349    let fields: Vec<_> = cols
350        .iter()
351        .map(|c| SortField::new(c.data_type().clone()))
352        .collect();
353    let converter = RowConverter::new(fields).expect("row converter");
354    let rows = converter.convert_columns(&cols).expect("convert rows");
355
356    (0..batch.num_rows())
357        .map(|row| {
358            #[allow(clippy::cast_possible_truncation)]
359            let v = (key_hash(rows.row(row).as_ref()) % u64::from(vnode_count)) as u32;
360            v
361        })
362        .collect()
363}
364
365fn slice_for_vnode(batch: &RecordBatch, vnodes: &[u32], target_vnode: u32) -> Option<RecordBatch> {
366    let indices: UInt32Array = vnodes
367        .iter()
368        .enumerate()
369        .filter_map(|(i, &v)| {
370            if v == target_vnode {
371                u32::try_from(i).ok()
372            } else {
373                None
374            }
375        })
376        .collect();
377    if indices.is_empty() {
378        return None;
379    }
380    let new_cols = batch
381        .columns()
382        .iter()
383        .map(|c| take(c, &indices, None).expect("take"))
384        .collect::<Vec<_>>();
385    Some(RecordBatch::try_new(batch.schema(), new_cols).expect("rebuild"))
386}
387
/// Drains data batches and barriers (local + peer), routes data to
/// local partitions or peer senders, and publishes aligned-epoch ids
/// on `aligned_tx`.
///
/// Barrier alignment model: the tracker sees `peers.len() + 1` logical
/// inputs. Port 0 is the local injection channel; ports `1..=N` map to
/// peers via `peer_port`. Once every port has observed a barrier for
/// an epoch, the aligned checkpoint id goes out on the watch channel.
///
/// NOTE(review): the select loop never returns — after input EOS it
/// idles on barrier traffic forever, relying on task abort/runtime
/// shutdown to die. Confirm this is intended.
#[allow(clippy::too_many_arguments)]
async fn route_input_stream(
    input: SendableRecordBatchStream,
    hash_columns: Vec<usize>,
    registry: Arc<VnodeRegistry>,
    self_id: NodeId,
    vnode_to_partition: HashMap<u32, usize>,
    sender: Arc<ShuffleSender>,
    partition_txs: Vec<mpsc::Sender<RecordBatch>>,
    peers: Vec<ShufflePeerId>,
    mut inject_rx: mpsc::UnboundedReceiver<CheckpointBarrier>,
    mut peer_rx: mpsc::UnboundedReceiver<(ShufflePeerId, CheckpointBarrier)>,
    aligned_tx: watch::Sender<u64>,
) {
    let vnode_count = registry.vnode_count();
    // Local input counts as one port alongside every peer.
    let n_inputs = peers.len() + 1;
    let tracker = BarrierTracker::new(n_inputs);
    // Peer → tracker port (port 0 is reserved for the local input).
    let peer_port: HashMap<ShufflePeerId, usize> =
        peers.iter().enumerate().map(|(i, &p)| (p, i + 1)).collect();

    // Once the local input stream ends we drop it so the select arm
    // becomes `pending` forever. Partition senders are also dropped
    // at input EOS so the downstream output streams terminate cleanly
    // — the router stays alive purely for barrier coordination.
    let mut input: Option<SendableRecordBatchStream> = Some(input);
    let mut partition_txs: Option<Vec<mpsc::Sender<RecordBatch>>> = Some(partition_txs);

    loop {
        // `biased` makes arm order the priority order: barriers are
        // handled before more data is pulled from the input.
        tokio::select! {
            biased;
            // 1. External barrier trigger.
            Some(barrier) = inject_rx.recv() => {
                // Best-effort fan-out; unreachable peers are ignored.
                let _ = sender.fan_out_barrier(&peers, barrier).await;
                if let Some(aligned) = tracker.observe(0, barrier) {
                    let _ = aligned_tx.send(aligned.checkpoint_id);
                }
            }
            // 2. Peer barrier forwarded from the dispatcher.
            Some((from, barrier)) = peer_rx.recv() => {
                // Unknown senders (not in the frozen peer set) are dropped.
                if let Some(&port) = peer_port.get(&from) {
                    if let Some(aligned) = tracker.observe(port, barrier) {
                        let _ = aligned_tx.send(aligned.checkpoint_id);
                    }
                }
            }
            // 3. Data from the local input. `pending` forever once the
            // input is exhausted so this arm drops out of the select.
            next = async {
                match input.as_mut() {
                    Some(s) => s.next().await,
                    None => std::future::pending().await,
                }
            } => {
                if let Some(Ok(batch)) = next {
                    if batch.num_rows() == 0 { continue; }
                    if partition_txs.is_none() { continue; }
                    let row_vn = row_vnodes(&batch, &hash_columns, vnode_count);
                    // Distinct vnodes present in this batch.
                    let mut seen = row_vn.clone();
                    seen.sort_unstable();
                    seen.dedup();
                    let mut downstream_dropped = false;
                    for v in seen {
                        let Some(slice) = slice_for_vnode(&batch, &row_vn, v) else {
                            continue;
                        };
                        // Ownership is re-read per batch, so a registry
                        // reassignment takes effect on the next batch.
                        let owner = registry.owner(v);
                        if owner == self_id {
                            if let Some(&idx) = vnode_to_partition.get(&v) {
                                let send_res = partition_txs.as_ref().unwrap()[idx]
                                    .send(slice)
                                    .await;
                                if send_res.is_err() {
                                    downstream_dropped = true;
                                    break;
                                }
                            }
                        } else if !owner.is_unassigned() {
                            let msg = laminar_core::shuffle::ShuffleMessage::VnodeData(
                                v, slice,
                            );
                            // Drop on unreachable peer.
                            let _ = sender.send_to(owner.0, &msg).await;
                        }
                    }
                    if downstream_dropped {
                        // A closed partition stream means no consumer is
                        // left; stop routing data entirely.
                        partition_txs = None;
                        input = None;
                    }
                } else {
                    // EOS or error: drop senders so output streams
                    // terminate; keep router alive for barriers.
                    // NOTE(review): a `Some(Err(_))` from the input is
                    // swallowed here, indistinguishable from clean EOS
                    // — confirm errors are surfaced elsewhere.
                    input = None;
                    partition_txs = None;
                }
            }
        }
    }
}
489
490/// Consumes inbound shuffle frames: `VnodeData` to local partitions,
491/// `Barrier` to the router's peer-barrier channel.
492async fn dispatch_inbound(
493    receiver: Arc<ShuffleReceiver>,
494    vnode_to_partition: HashMap<u32, usize>,
495    partition_txs: Vec<mpsc::Sender<RecordBatch>>,
496    peer_barrier_tx: mpsc::UnboundedSender<(ShufflePeerId, CheckpointBarrier)>,
497) {
498    use laminar_core::shuffle::ShuffleMessage;
499    while let Some((from, msg)) = receiver.recv().await {
500        match msg {
501            ShuffleMessage::VnodeData(vnode, batch) => {
502                if batch.num_rows() == 0 {
503                    continue;
504                }
505                let Some(&idx) = vnode_to_partition.get(&vnode) else {
506                    // Sender mis-routed to a vnode we don't own — drop
507                    // (a future observability hook can count these).
508                    continue;
509                };
510                if partition_txs[idx].send(batch).await.is_err() {
511                    return; // downstream dropped
512                }
513            }
514            ShuffleMessage::Barrier(b) if peer_barrier_tx.send((from, b)).is_err() => return,
515            // Hello is consumed by per_peer_loop; Close closes the
516            // reader. Barriers that succeed fall through here.
517            _ => {}
518        }
519    }
520}