laminar_core/shuffle/
routing.rs1use std::sync::Arc;
7
8use arrow::compute::take;
9use arrow::row::{RowConverter, SortField};
10use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
11use rustc_hash::FxHashMap;
12
13use crate::state::{key_hash, NodeId, VnodeRegistry};
14
15#[must_use]
22pub fn row_vnodes(batch: &RecordBatch, columns: &[usize], vnode_count: u32) -> Vec<u32> {
23 let cols: Vec<ArrayRef> = columns
24 .iter()
25 .map(|&i| Arc::clone(batch.column(i)))
26 .collect();
27 let fields: Vec<SortField> = cols
28 .iter()
29 .map(|c| SortField::new(c.data_type().clone()))
30 .collect();
31 let converter = RowConverter::new(fields).expect("row converter");
32 let rows = converter.convert_columns(&cols).expect("convert rows");
33 (0..batch.num_rows())
34 .map(|r| {
35 #[allow(clippy::cast_possible_truncation)]
36 let v = (key_hash(rows.row(r).as_ref()) % u64::from(vnode_count)) as u32;
37 v
38 })
39 .collect()
40}
41
42#[must_use]
49pub fn slice_batch_by_vnode(
50 batch: &RecordBatch,
51 row_vnodes: &[u32],
52 target: u32,
53) -> Option<RecordBatch> {
54 let indices: UInt32Array = row_vnodes
55 .iter()
56 .enumerate()
57 .filter_map(|(i, &v)| (v == target).then(|| u32::try_from(i).ok()).flatten())
58 .collect();
59 if indices.is_empty() {
60 return None;
61 }
62 let cols: Vec<ArrayRef> = batch
63 .columns()
64 .iter()
65 .map(|c| take(c, &indices, None).expect("take"))
66 .collect();
67 Some(RecordBatch::try_new(batch.schema(), cols).expect("rebuild"))
68}
69
70#[must_use]
76#[allow(clippy::cast_possible_truncation)]
77pub fn slice_batch_by_vnodes(batch: &RecordBatch, row_vnodes: &[u32]) -> Vec<(u32, RecordBatch)> {
78 if batch.num_rows() == 0 {
79 return Vec::new();
80 }
81 let mut groups: FxHashMap<u32, Vec<u32>> = FxHashMap::default();
82 for (i, &v) in row_vnodes.iter().enumerate() {
83 groups.entry(v).or_default().push(i as u32);
84 }
85 let mut slices = Vec::with_capacity(groups.len());
86 for (vnode, indices_vec) in groups {
87 let indices = UInt32Array::from(indices_vec);
88 let cols: Vec<ArrayRef> = batch
89 .columns()
90 .iter()
91 .map(|c| take(c, &indices, None).expect("take"))
92 .collect();
93 if let Ok(slice) = RecordBatch::try_new(batch.schema(), cols) {
94 slices.push((vnode, slice));
95 }
96 }
97 slices
98}
99
100#[must_use]
109#[allow(clippy::cast_possible_truncation)]
110pub fn slice_batch_by_targets(
111 batch: &RecordBatch,
112 row_vnodes: &[u32],
113 registry: &VnodeRegistry,
114 self_id: NodeId,
115) -> (FxHashMap<u32, RecordBatch>, FxHashMap<NodeId, RecordBatch>) {
116 if batch.num_rows() == 0 {
117 return (FxHashMap::default(), FxHashMap::default());
118 }
119
120 let mut local_groups: FxHashMap<u32, Vec<u32>> = FxHashMap::default();
121 let mut remote_groups: FxHashMap<NodeId, (Vec<u32>, Vec<u32>)> = FxHashMap::default();
122
123 for (row_idx, &vnode) in row_vnodes.iter().enumerate() {
124 let owner = registry.owner(vnode);
125 if owner == self_id {
126 local_groups.entry(vnode).or_default().push(row_idx as u32);
127 } else if !owner.is_unassigned() {
128 let entry = remote_groups.entry(owner).or_default();
129 entry.0.push(row_idx as u32);
130 entry.1.push(vnode);
131 }
132 }
133
134 let mut local_slices = FxHashMap::default();
135 for (vnode, indices_vec) in local_groups {
136 let indices = UInt32Array::from(indices_vec);
137 let cols: Vec<ArrayRef> = batch
138 .columns()
139 .iter()
140 .map(|c| take(c, &indices, None).expect("take"))
141 .collect();
142 if let Ok(slice) = RecordBatch::try_new(batch.schema(), cols) {
143 local_slices.insert(vnode, slice);
144 }
145 }
146
147 let mut remote_slices = FxHashMap::default();
148 for (node_id, (indices_vec, vnodes_vec)) in remote_groups {
149 let indices = UInt32Array::from(indices_vec);
150 let mut cols: Vec<ArrayRef> = batch
151 .columns()
152 .iter()
153 .map(|c| take(c, &indices, None).expect("take"))
154 .collect();
155 let vnode_col = Arc::new(UInt32Array::from(vnodes_vec)) as ArrayRef;
156 cols.push(vnode_col);
157
158 let mut fields = batch.schema().fields().to_vec();
159 fields.push(Arc::new(arrow_schema::Field::new(
160 "__laminar_vnode",
161 arrow_schema::DataType::UInt32,
162 false,
163 )));
164 let extended_schema = Arc::new(arrow_schema::Schema::new_with_metadata(
165 fields,
166 batch.schema().metadata().clone(),
167 ));
168
169 if let Ok(slice) = RecordBatch::try_new(extended_schema, cols) {
170 remote_slices.insert(node_id, slice);
171 }
172 }
173
174 (local_slices, remote_slices)
175}