1use arrow::compute::take;
8use arrow_array::cast::AsArray;
9use arrow_array::types::{
10 Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
11 UInt64Type, UInt8Type,
12};
13use arrow_array::{Array, RecordBatch, UInt32Array};
14use std::hash::{Hash, Hasher};
15
16use super::router::{KeySpec, RouterError};
17
18pub struct PartitionedRouter {
26 key_spec: KeySpec,
27 num_cores: usize,
28 core_rows: Vec<Vec<u32>>,
30 result_buf: Vec<(usize, RecordBatch)>,
32 rr_counter: usize,
34}
35
36impl PartitionedRouter {
37 #[must_use]
43 pub fn new(key_spec: KeySpec, num_cores: usize) -> Self {
44 assert!(num_cores > 0, "num_cores must be > 0");
45 Self {
46 key_spec,
47 num_cores,
48 core_rows: (0..num_cores).map(|_| Vec::with_capacity(256)).collect(),
49 result_buf: Vec::with_capacity(num_cores),
50 rr_counter: 0,
51 }
52 }
53
54 pub fn route_batch(
63 &mut self,
64 batch: &RecordBatch,
65 ) -> Result<Vec<(usize, RecordBatch)>, RouterError> {
66 if batch.num_rows() == 0 {
67 return Err(RouterError::EmptyBatch);
68 }
69
70 match &self.key_spec {
71 KeySpec::RoundRobin => {
72 let core = self.rr_counter % self.num_cores;
73 self.rr_counter += 1;
74 Ok(vec![(core, batch.clone())])
75 }
76 KeySpec::Columns(names) => {
77 let indices: Vec<usize> = names
78 .iter()
79 .map(|name| {
80 batch
81 .schema()
82 .index_of(name)
83 .map_err(|_| RouterError::ColumnNotFoundByName)
84 })
85 .collect::<Result<_, _>>()?;
86 self.route_by_indices(batch, &indices)
87 }
88 KeySpec::ColumnIndices(indices) => {
89 let indices = indices.clone();
90 for &idx in &indices {
91 if idx >= batch.num_columns() {
92 return Err(RouterError::ColumnIndexOutOfRange);
93 }
94 }
95 self.route_by_indices(batch, &indices)
96 }
97 KeySpec::AllColumns => {
98 let indices: Vec<usize> = (0..batch.num_columns()).collect();
99 self.route_by_indices(batch, &indices)
100 }
101 }
102 }
103
104 fn route_by_indices(
106 &mut self,
107 batch: &RecordBatch,
108 indices: &[usize],
109 ) -> Result<Vec<(usize, RecordBatch)>, RouterError> {
110 let num_rows = batch.num_rows();
111
112 self.core_rows.iter_mut().for_each(Vec::clear);
114 self.result_buf.clear();
115
116 for row in 0..num_rows {
117 let mut hasher = rustc_hash::FxHasher::default();
118 for &col_idx in indices {
119 hash_array_value(batch.column(col_idx).as_ref(), row, &mut hasher)?;
120 }
121 #[allow(clippy::cast_possible_truncation)]
122 let core_id = (hasher.finish() as usize) % self.num_cores;
123 #[allow(clippy::cast_possible_truncation)]
124 self.core_rows[core_id].push(row as u32);
125 }
126
127 for (core_id, rows) in self.core_rows.iter().enumerate() {
128 if !rows.is_empty() {
129 let take_indices = UInt32Array::from_iter_values(rows.iter().copied());
130 let columns: Vec<_> = batch
131 .columns()
132 .iter()
133 .map(|col| take(col.as_ref(), &take_indices, None))
134 .collect::<Result<_, _>>()
135 .map_err(|_| RouterError::UnsupportedDataType)?;
136 let sub_batch = RecordBatch::try_new(batch.schema(), columns)
137 .map_err(|_| RouterError::UnsupportedDataType)?;
138 self.result_buf.push((core_id, sub_batch));
139 }
140 }
141 Ok(std::mem::take(&mut self.result_buf))
142 }
143}
144
145fn hash_array_value(
147 array: &dyn Array,
148 row: usize,
149 hasher: &mut impl Hasher,
150) -> Result<(), RouterError> {
151 if row >= array.len() {
152 return Err(RouterError::RowIndexOutOfRange);
153 }
154
155 if array.is_null(row) {
156 0u8.hash(hasher);
157 return Ok(());
158 }
159
160 if let Some(a) = array.as_primitive_opt::<Int64Type>() {
162 a.value(row).hash(hasher);
163 } else if let Some(a) = array.as_primitive_opt::<Int32Type>() {
164 a.value(row).hash(hasher);
165 } else if let Some(a) = array.as_primitive_opt::<UInt64Type>() {
166 a.value(row).hash(hasher);
167 } else if let Some(a) = array.as_primitive_opt::<UInt32Type>() {
168 a.value(row).hash(hasher);
169 } else if let Some(a) = array.as_primitive_opt::<Int16Type>() {
170 a.value(row).hash(hasher);
171 } else if let Some(a) = array.as_primitive_opt::<Int8Type>() {
172 a.value(row).hash(hasher);
173 } else if let Some(a) = array.as_primitive_opt::<UInt16Type>() {
174 a.value(row).hash(hasher);
175 } else if let Some(a) = array.as_primitive_opt::<UInt8Type>() {
176 a.value(row).hash(hasher);
177 } else if let Some(a) = array.as_primitive_opt::<Float64Type>() {
178 a.value(row).to_bits().hash(hasher);
179 } else if let Some(a) = array.as_primitive_opt::<Float32Type>() {
180 a.value(row).to_bits().hash(hasher);
181 } else if let Some(a) = array.as_string_opt::<i32>() {
182 a.value(row).hash(hasher);
183 } else if let Some(a) = array.as_string_opt::<i64>() {
184 a.value(row).hash(hasher);
185 } else if let Some(a) = array.as_binary_opt::<i32>() {
186 a.value(row).hash(hasher);
187 } else if let Some(a) = array.as_binary_opt::<i64>() {
188 a.value(row).hash(hasher);
189 } else if let Some(a) = array.as_boolean_opt() {
190 a.value(row).hash(hasher);
191 } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::TimestampMillisecondType>()
192 {
193 a.value(row).hash(hasher);
194 } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::TimestampMicrosecondType>()
195 {
196 a.value(row).hash(hasher);
197 } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::Date32Type>() {
198 a.value(row).hash(hasher);
199 } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::Date64Type>() {
200 a.value(row).hash(hasher);
201 } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::Decimal128Type>() {
202 a.value(row).hash(hasher);
203 } else {
204 return Err(RouterError::UnsupportedDataType);
205 }
206 Ok(())
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::collections::HashSet;
    use std::sync::Arc;

    fn make_batch(keys: Vec<i64>) -> RecordBatch {
        let schema = Schema::new(vec![Field::new("key", DataType::Int64, false)]);
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(Int64Array::from(keys))]).unwrap()
    }

    fn make_string_batch(keys: Vec<&str>) -> RecordBatch {
        let schema = Schema::new(vec![Field::new("name", DataType::Utf8, false)]);
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(StringArray::from(keys))]).unwrap()
    }

    /// Non-null Int64 values of column 0 of `batch`, in row order.
    fn int_keys(batch: &RecordBatch) -> Vec<i64> {
        batch
            .column(0)
            .as_primitive::<Int64Type>()
            .iter()
            .flatten()
            .collect()
    }

    #[test]
    fn test_deterministic_routing() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 4);
        let batch = make_batch(vec![1, 2, 3, 4, 1, 2, 3, 4]);

        let result = router.route_batch(&batch).unwrap();
        let result2 = router.route_batch(&batch).unwrap();

        // Routing the same batch twice must produce the same cores with the same rows.
        assert_eq!(result.len(), result2.len());
        for (a, b) in result.iter().zip(result2.iter()) {
            assert_eq!(a.0, b.0);
            assert_eq!(int_keys(&a.1), int_keys(&b.1));
        }
    }

    #[test]
    fn test_same_key_same_core() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 4);
        let batch = make_batch(vec![42, 42, 42]);

        let result = router.route_batch(&batch).unwrap();
        assert_eq!(result.len(), 1, "All same keys should go to one core");
        assert_eq!(result[0].1.num_rows(), 3);
    }

    #[test]
    fn test_each_key_maps_to_exactly_one_core() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 4);
        let keys: Vec<i64> = (0..64).chain(0..64).collect();
        let batch = make_batch(keys);

        let result = router.route_batch(&batch).unwrap();
        assert!(
            result.windows(2).all(|w| w[0].0 < w[1].0),
            "sub-batches should be ordered by ascending core id"
        );

        let mut seen = HashSet::new();
        let mut total = 0;
        for (core, sub) in &result {
            assert!(*core < 4, "core id {core} out of range");
            let sub_keys = int_keys(sub);
            total += sub_keys.len();
            // A key may repeat within its own sub-batch but must not appear in another one.
            let distinct: HashSet<i64> = sub_keys.into_iter().collect();
            for key in distinct {
                assert!(seen.insert(key), "key {key} was routed to more than one core");
            }
        }
        assert_eq!(total, 128);
        assert_eq!(seen.len(), 64);
    }

    #[test]
    fn test_null_keys_share_a_core() {
        let schema = Schema::new(vec![Field::new("key", DataType::Int64, true)]);
        let column = Int64Array::from(vec![None, Some(7), None, Some(9), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column)]).unwrap();
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 4);

        let result = router.route_batch(&batch).unwrap();
        let null_counts: Vec<usize> = result
            .iter()
            .map(|(_, sub)| sub.column(0).null_count())
            .filter(|&n| n > 0)
            .collect();
        assert_eq!(null_counts, vec![3], "all null keys should land on one core");
    }

    #[test]
    fn test_multi_column_key_groups_identical_tuples() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![1, 1, 2, 1])),
            Arc::new(StringArray::from(vec!["x", "y", "x", "x"])),
        ];
        let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap();
        let mut router = PartitionedRouter::new(
            KeySpec::Columns(vec!["a".to_string(), "b".to_string()]),
            8,
        );

        let result = router.route_batch(&batch).unwrap();
        let total_rows: usize = result.iter().map(|(_, b)| b.num_rows()).sum();
        assert_eq!(total_rows, 4);

        // Rows 0 and 3 share the key (1, "x") and must travel together.
        let holders = result
            .iter()
            .filter(|(_, sub)| {
                let a = sub.column(0).as_primitive::<Int64Type>();
                let b = sub.column(1).as_string::<i32>();
                (0..sub.num_rows())
                    .filter(|&i| a.value(i) == 1 && b.value(i) == "x")
                    .count()
                    == 2
            })
            .count();
        assert_eq!(holders, 1);
    }

    #[test]
    fn test_round_robin_cycles_across_cores() {
        let mut router = PartitionedRouter::new(KeySpec::RoundRobin, 3);
        let batch = make_batch(vec![1, 2, 3]);

        let r0 = router.route_batch(&batch).unwrap();
        assert_eq!(r0[0].0, 0);

        let r1 = router.route_batch(&batch).unwrap();
        assert_eq!(r1[0].0, 1);

        let r2 = router.route_batch(&batch).unwrap();
        assert_eq!(r2[0].0, 2);

        let r3 = router.route_batch(&batch).unwrap();
        assert_eq!(r3[0].0, 0);
    }

    #[test]
    fn test_round_robin_keeps_cycling_with_whole_batches() {
        let mut router = PartitionedRouter::new(KeySpec::RoundRobin, 3);
        let batch = make_batch(vec![5, 6]);

        for call in 0..10 {
            let routed = router.route_batch(&batch).unwrap();
            assert_eq!(routed.len(), 1);
            assert_eq!(routed[0].0, call % 3);
            assert_eq!(routed[0].1.num_rows(), 2);
        }
    }

    #[test]
    fn test_column_names() {
        let mut router = PartitionedRouter::new(KeySpec::Columns(vec!["key".to_string()]), 2);
        let batch = make_batch(vec![1, 2, 3, 4]);

        let result = router.route_batch(&batch).unwrap();
        let total_rows: usize = result.iter().map(|(_, b)| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    #[test]
    fn test_column_not_found() {
        let mut router = PartitionedRouter::new(KeySpec::Columns(vec!["missing".to_string()]), 2);
        let batch = make_batch(vec![1, 2]);

        let result = router.route_batch(&batch);
        assert!(matches!(result, Err(RouterError::ColumnNotFoundByName)));
    }

    #[test]
    fn test_column_index_out_of_range() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![5]), 2);
        let batch = make_batch(vec![1, 2]);

        let result = router.route_batch(&batch);
        assert!(matches!(result, Err(RouterError::ColumnIndexOutOfRange)));
    }

    #[test]
    fn test_unsupported_key_type() {
        let list = arrow_array::ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1)]),
            Some(vec![Some(2)]),
        ]);
        let batch =
            RecordBatch::try_from_iter(vec![("l", Arc::new(list) as Arc<dyn Array>)]).unwrap();
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 2);

        let result = router.route_batch(&batch);
        assert!(matches!(result, Err(RouterError::UnsupportedDataType)));
    }

    #[test]
    fn test_empty_batch() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 2);
        let batch = make_batch(vec![]);

        let result = router.route_batch(&batch);
        assert!(matches!(result, Err(RouterError::EmptyBatch)));
    }

    #[test]
    fn test_string_keys() {
        let mut router = PartitionedRouter::new(KeySpec::Columns(vec!["name".to_string()]), 4);
        let batch = make_string_batch(vec!["alice", "bob", "alice", "charlie"]);

        let result = router.route_batch(&batch).unwrap();
        let total_rows: usize = result.iter().map(|(_, b)| b.num_rows()).sum();
        assert_eq!(total_rows, 4);

        for (_, sub) in &result {
            let col = sub.column(0);
            let arr = col.as_string::<i32>();
            let values: Vec<&str> = (0..arr.len()).map(|i| arr.value(i)).collect();
            if values.contains(&"alice") {
                assert_eq!(
                    values.iter().filter(|&&v| v == "alice").count(),
                    2,
                    "Both 'alice' rows should be in the same sub-batch"
                );
            }
        }
    }

    #[test]
    fn test_all_columns() {
        let mut router = PartitionedRouter::new(KeySpec::AllColumns, 2);
        let batch = make_batch(vec![1, 2, 3, 4]);

        let result = router.route_batch(&batch).unwrap();
        let total_rows: usize = result.iter().map(|(_, b)| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    #[test]
    fn test_preserves_schema() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 2);
        let batch = make_batch(vec![1, 2, 3]);

        let result = router.route_batch(&batch).unwrap();
        for (_, sub) in &result {
            assert_eq!(sub.schema(), batch.schema());
        }
    }
}