Skip to main content

laminar_core/shuffle/
routing.rs

1//! Row→vnode routing shared by the cluster shuffle paths (the aggregate
2//! row-shuffle, the lookup-enrich key-shuffle, and `ClusterRepartitionExec`).
3//! Hashing matches [`crate::state::VnodeRegistry::vnode_for_key`] so every stage agrees on a
4//! key's vnode.
5
6use std::sync::Arc;
7
8use arrow::compute::take;
9use arrow::row::{RowConverter, SortField};
10use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
11use rustc_hash::FxHashMap;
12
13use crate::state::{key_hash, NodeId, VnodeRegistry};
14
15/// Vnode for each row, hashing `columns` (by index) with the engine's
16/// `arrow-row` + xxh3 encoding. `columns` must be non-empty.
17///
18/// # Panics
19/// Panics if `columns` holds an out-of-range index or the columns cannot be
20/// row-encoded — both internal-invariant violations, not input errors.
21#[must_use]
22pub fn row_vnodes(batch: &RecordBatch, columns: &[usize], vnode_count: u32) -> Vec<u32> {
23    let cols: Vec<ArrayRef> = columns
24        .iter()
25        .map(|&i| Arc::clone(batch.column(i)))
26        .collect();
27    let fields: Vec<SortField> = cols
28        .iter()
29        .map(|c| SortField::new(c.data_type().clone()))
30        .collect();
31    let converter = RowConverter::new(fields).expect("row converter");
32    let rows = converter.convert_columns(&cols).expect("convert rows");
33    (0..batch.num_rows())
34        .map(|r| {
35            #[allow(clippy::cast_possible_truncation)]
36            let v = (key_hash(rows.row(r).as_ref()) % u64::from(vnode_count)) as u32;
37            v
38        })
39        .collect()
40}
41
42/// The sub-batch of `batch` whose `row_vnodes[i] == target`, or `None` if no
43/// row maps to `target`.
44///
45/// # Panics
46/// Panics if the `take` or rebuild fails — only on an internal-invariant
47/// violation (the indices are derived from `batch` itself).
48#[must_use]
49pub fn slice_batch_by_vnode(
50    batch: &RecordBatch,
51    row_vnodes: &[u32],
52    target: u32,
53) -> Option<RecordBatch> {
54    let indices: UInt32Array = row_vnodes
55        .iter()
56        .enumerate()
57        .filter_map(|(i, &v)| (v == target).then(|| u32::try_from(i).ok()).flatten())
58        .collect();
59    if indices.is_empty() {
60        return None;
61    }
62    let cols: Vec<ArrayRef> = batch
63        .columns()
64        .iter()
65        .map(|c| take(c, &indices, None).expect("take"))
66        .collect();
67    Some(RecordBatch::try_new(batch.schema(), cols).expect("rebuild"))
68}
69
70/// Slice batch by vnodes in a single pass, returning a vector of (vnode, `RecordBatch`).
71/// Avoids the $O(V \times R)$ loop and allocations of vnode-by-vnode slicing.
72///
73/// # Panics
74/// Panics if `take` or rebuild fails — only on internal-invariant violations.
75#[must_use]
76#[allow(clippy::cast_possible_truncation)]
77pub fn slice_batch_by_vnodes(batch: &RecordBatch, row_vnodes: &[u32]) -> Vec<(u32, RecordBatch)> {
78    if batch.num_rows() == 0 {
79        return Vec::new();
80    }
81    let mut groups: FxHashMap<u32, Vec<u32>> = FxHashMap::default();
82    for (i, &v) in row_vnodes.iter().enumerate() {
83        groups.entry(v).or_default().push(i as u32);
84    }
85    let mut slices = Vec::with_capacity(groups.len());
86    for (vnode, indices_vec) in groups {
87        let indices = UInt32Array::from(indices_vec);
88        let cols: Vec<ArrayRef> = batch
89            .columns()
90            .iter()
91            .map(|c| take(c, &indices, None).expect("take"))
92            .collect();
93        if let Ok(slice) = RecordBatch::try_new(batch.schema(), cols) {
94            slices.push((vnode, slice));
95        }
96    }
97    slices
98}
99
100/// Slices a `RecordBatch` by targets.
101/// Remote rows are grouped by `NodeId`, and the metadata column `__laminar_vnode` is appended to the batch.
102/// Local rows are grouped by vnode.
103///
104/// Returns (`local_slices`, `remote_slices`).
105///
106/// # Panics
107/// Panics if `take` or rebuild fails — only on internal-invariant violations.
108#[must_use]
109#[allow(clippy::cast_possible_truncation)]
110pub fn slice_batch_by_targets(
111    batch: &RecordBatch,
112    row_vnodes: &[u32],
113    registry: &VnodeRegistry,
114    self_id: NodeId,
115) -> (FxHashMap<u32, RecordBatch>, FxHashMap<NodeId, RecordBatch>) {
116    if batch.num_rows() == 0 {
117        return (FxHashMap::default(), FxHashMap::default());
118    }
119
120    let mut local_groups: FxHashMap<u32, Vec<u32>> = FxHashMap::default();
121    let mut remote_groups: FxHashMap<NodeId, (Vec<u32>, Vec<u32>)> = FxHashMap::default();
122
123    for (row_idx, &vnode) in row_vnodes.iter().enumerate() {
124        let owner = registry.owner(vnode);
125        if owner == self_id {
126            local_groups.entry(vnode).or_default().push(row_idx as u32);
127        } else if !owner.is_unassigned() {
128            let entry = remote_groups.entry(owner).or_default();
129            entry.0.push(row_idx as u32);
130            entry.1.push(vnode);
131        }
132    }
133
134    let mut local_slices = FxHashMap::default();
135    for (vnode, indices_vec) in local_groups {
136        let indices = UInt32Array::from(indices_vec);
137        let cols: Vec<ArrayRef> = batch
138            .columns()
139            .iter()
140            .map(|c| take(c, &indices, None).expect("take"))
141            .collect();
142        if let Ok(slice) = RecordBatch::try_new(batch.schema(), cols) {
143            local_slices.insert(vnode, slice);
144        }
145    }
146
147    let mut remote_slices = FxHashMap::default();
148    for (node_id, (indices_vec, vnodes_vec)) in remote_groups {
149        let indices = UInt32Array::from(indices_vec);
150        let mut cols: Vec<ArrayRef> = batch
151            .columns()
152            .iter()
153            .map(|c| take(c, &indices, None).expect("take"))
154            .collect();
155        let vnode_col = Arc::new(UInt32Array::from(vnodes_vec)) as ArrayRef;
156        cols.push(vnode_col);
157
158        let mut fields = batch.schema().fields().to_vec();
159        fields.push(Arc::new(arrow_schema::Field::new(
160            "__laminar_vnode",
161            arrow_schema::DataType::UInt32,
162            false,
163        )));
164        let extended_schema = Arc::new(arrow_schema::Schema::new_with_metadata(
165            fields,
166            batch.schema().metadata().clone(),
167        ));
168
169        if let Ok(slice) = RecordBatch::try_new(extended_schema, cols) {
170            remote_slices.insert(node_id, slice);
171        }
172    }
173
174    (local_slices, remote_slices)
175}