Skip to main content

laminar_core/tpc/
partitioned_router.rs

//! Key-based row routing for per-core state partitioning.
//!
//! Splits a `RecordBatch` into per-core sub-batches based on a hash of
//! the specified key columns, ensuring that (for a given core count) rows
//! with the same key always route to the same core.
6
7use arrow::compute::take;
8use arrow_array::cast::AsArray;
9use arrow_array::types::{
10    Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
11    UInt64Type, UInt8Type,
12};
13use arrow_array::{Array, RecordBatch, UInt32Array};
14use std::hash::{Hash, Hasher};
15
16use super::router::{KeySpec, RouterError};
17
18/// Routes `RecordBatch` rows to cores based on key columns.
19///
20/// Uses `FxHasher` for fast, deterministic hashing. Same key always maps
21/// to the same core, guaranteeing state locality.
22///
23/// Pre-allocates per-core row buffers and reuses them across calls to avoid
24/// per-cycle heap allocations on the hot path.
25pub struct PartitionedRouter {
26    key_spec: KeySpec,
27    num_cores: usize,
28    /// Per-core row index buffers, cleared (not dropped) between calls.
29    core_rows: Vec<Vec<u32>>,
30    /// Reusable result buffer.
31    result_buf: Vec<(usize, RecordBatch)>,
32    /// Round-robin counter for `KeySpec::RoundRobin`.
33    rr_counter: usize,
34}
35
36impl PartitionedRouter {
37    /// Create a new router.
38    ///
39    /// # Panics
40    ///
41    /// Panics if `num_cores` is zero.
42    #[must_use]
43    pub fn new(key_spec: KeySpec, num_cores: usize) -> Self {
44        assert!(num_cores > 0, "num_cores must be > 0");
45        Self {
46            key_spec,
47            num_cores,
48            core_rows: (0..num_cores).map(|_| Vec::with_capacity(256)).collect(),
49            result_buf: Vec::with_capacity(num_cores),
50            rr_counter: 0,
51        }
52    }
53
54    /// Split a batch into per-core sub-batches.
55    ///
56    /// Returns a vec of `(core_id, RecordBatch)` — one entry per core that
57    /// has rows. Cores with zero rows are omitted.
58    ///
59    /// # Errors
60    ///
61    /// Returns an error if key columns are not found or have unsupported types.
62    pub fn route_batch(
63        &mut self,
64        batch: &RecordBatch,
65    ) -> Result<Vec<(usize, RecordBatch)>, RouterError> {
66        if batch.num_rows() == 0 {
67            return Err(RouterError::EmptyBatch);
68        }
69
70        match &self.key_spec {
71            KeySpec::RoundRobin => {
72                let core = self.rr_counter % self.num_cores;
73                self.rr_counter += 1;
74                Ok(vec![(core, batch.clone())])
75            }
76            KeySpec::Columns(names) => {
77                let indices: Vec<usize> = names
78                    .iter()
79                    .map(|name| {
80                        batch
81                            .schema()
82                            .index_of(name)
83                            .map_err(|_| RouterError::ColumnNotFoundByName)
84                    })
85                    .collect::<Result<_, _>>()?;
86                self.route_by_indices(batch, &indices)
87            }
88            KeySpec::ColumnIndices(indices) => {
89                let indices = indices.clone();
90                for &idx in &indices {
91                    if idx >= batch.num_columns() {
92                        return Err(RouterError::ColumnIndexOutOfRange);
93                    }
94                }
95                self.route_by_indices(batch, &indices)
96            }
97            KeySpec::AllColumns => {
98                let indices: Vec<usize> = (0..batch.num_columns()).collect();
99                self.route_by_indices(batch, &indices)
100            }
101        }
102    }
103
104    /// Route rows by hashing the values at the given column indices.
105    fn route_by_indices(
106        &mut self,
107        batch: &RecordBatch,
108        indices: &[usize],
109    ) -> Result<Vec<(usize, RecordBatch)>, RouterError> {
110        let num_rows = batch.num_rows();
111
112        // Clear scratch buffers (retains capacity)
113        self.core_rows.iter_mut().for_each(Vec::clear);
114        self.result_buf.clear();
115
116        for row in 0..num_rows {
117            let mut hasher = rustc_hash::FxHasher::default();
118            for &col_idx in indices {
119                hash_array_value(batch.column(col_idx).as_ref(), row, &mut hasher)?;
120            }
121            #[allow(clippy::cast_possible_truncation)]
122            let core_id = (hasher.finish() as usize) % self.num_cores;
123            #[allow(clippy::cast_possible_truncation)]
124            self.core_rows[core_id].push(row as u32);
125        }
126
127        for (core_id, rows) in self.core_rows.iter().enumerate() {
128            if !rows.is_empty() {
129                let take_indices = UInt32Array::from_iter_values(rows.iter().copied());
130                let columns: Vec<_> = batch
131                    .columns()
132                    .iter()
133                    .map(|col| take(col.as_ref(), &take_indices, None))
134                    .collect::<Result<_, _>>()
135                    .map_err(|_| RouterError::UnsupportedDataType)?;
136                let sub_batch = RecordBatch::try_new(batch.schema(), columns)
137                    .map_err(|_| RouterError::UnsupportedDataType)?;
138                self.result_buf.push((core_id, sub_batch));
139            }
140        }
141        Ok(std::mem::take(&mut self.result_buf))
142    }
143}
144
145/// Hash a single value from an Arrow array at the given row index.
146fn hash_array_value(
147    array: &dyn Array,
148    row: usize,
149    hasher: &mut impl Hasher,
150) -> Result<(), RouterError> {
151    if row >= array.len() {
152        return Err(RouterError::RowIndexOutOfRange);
153    }
154
155    if array.is_null(row) {
156        0u8.hash(hasher);
157        return Ok(());
158    }
159
160    // Try numeric types (most common for keys)
161    if let Some(a) = array.as_primitive_opt::<Int64Type>() {
162        a.value(row).hash(hasher);
163    } else if let Some(a) = array.as_primitive_opt::<Int32Type>() {
164        a.value(row).hash(hasher);
165    } else if let Some(a) = array.as_primitive_opt::<UInt64Type>() {
166        a.value(row).hash(hasher);
167    } else if let Some(a) = array.as_primitive_opt::<UInt32Type>() {
168        a.value(row).hash(hasher);
169    } else if let Some(a) = array.as_primitive_opt::<Int16Type>() {
170        a.value(row).hash(hasher);
171    } else if let Some(a) = array.as_primitive_opt::<Int8Type>() {
172        a.value(row).hash(hasher);
173    } else if let Some(a) = array.as_primitive_opt::<UInt16Type>() {
174        a.value(row).hash(hasher);
175    } else if let Some(a) = array.as_primitive_opt::<UInt8Type>() {
176        a.value(row).hash(hasher);
177    } else if let Some(a) = array.as_primitive_opt::<Float64Type>() {
178        a.value(row).to_bits().hash(hasher);
179    } else if let Some(a) = array.as_primitive_opt::<Float32Type>() {
180        a.value(row).to_bits().hash(hasher);
181    } else if let Some(a) = array.as_string_opt::<i32>() {
182        a.value(row).hash(hasher);
183    } else if let Some(a) = array.as_string_opt::<i64>() {
184        a.value(row).hash(hasher);
185    } else if let Some(a) = array.as_binary_opt::<i32>() {
186        a.value(row).hash(hasher);
187    } else if let Some(a) = array.as_binary_opt::<i64>() {
188        a.value(row).hash(hasher);
189    } else if let Some(a) = array.as_boolean_opt() {
190        a.value(row).hash(hasher);
191    } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::TimestampMillisecondType>()
192    {
193        a.value(row).hash(hasher);
194    } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::TimestampMicrosecondType>()
195    {
196        a.value(row).hash(hasher);
197    } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::Date32Type>() {
198        a.value(row).hash(hasher);
199    } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::Date64Type>() {
200        a.value(row).hash(hasher);
201    } else if let Some(a) = array.as_primitive_opt::<arrow_array::types::Decimal128Type>() {
202        a.value(row).hash(hasher);
203    } else {
204        return Err(RouterError::UnsupportedDataType);
205    }
206    Ok(())
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{ArrayRef, Int64Array, ListArray, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    fn make_batch(keys: Vec<i64>) -> RecordBatch {
        let schema = Schema::new(vec![Field::new("key", DataType::Int64, false)]);
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(Int64Array::from(keys))]).unwrap()
    }

    fn make_string_batch(keys: Vec<&str>) -> RecordBatch {
        let schema = Schema::new(vec![Field::new("name", DataType::Utf8, false)]);
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(StringArray::from(keys))]).unwrap()
    }

    /// Collect the `Int64` values of column 0 of a sub-batch.
    fn i64_keys(batch: &RecordBatch) -> Vec<i64> {
        batch.column(0).as_primitive::<Int64Type>().values().to_vec()
    }

    #[test]
    fn test_deterministic_routing() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 4);
        let batch = make_batch(vec![1, 2, 3, 4, 1, 2, 3, 4]);

        let result = router.route_batch(&batch).unwrap();

        // Same key should always map to the same core
        // Verify by routing the same batch again
        let result2 = router.route_batch(&batch).unwrap();
        assert_eq!(result.len(), result2.len());
        for (a, b) in result.iter().zip(result2.iter()) {
            assert_eq!(a.0, b.0);
            assert_eq!(a.1.num_rows(), b.1.num_rows());
        }
    }

    #[test]
    fn test_same_key_same_core() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 4);
        let batch = make_batch(vec![42, 42, 42]);

        let result = router.route_batch(&batch).unwrap();
        assert_eq!(result.len(), 1, "All same keys should go to one core");
        assert_eq!(result[0].1.num_rows(), 3);
    }

    #[test]
    fn test_round_robin_cycles_across_cores() {
        let mut router = PartitionedRouter::new(KeySpec::RoundRobin, 3);
        let batch = make_batch(vec![1, 2, 3]);

        let r0 = router.route_batch(&batch).unwrap();
        assert_eq!(r0[0].0, 0);

        let r1 = router.route_batch(&batch).unwrap();
        assert_eq!(r1[0].0, 1);

        let r2 = router.route_batch(&batch).unwrap();
        assert_eq!(r2[0].0, 2);

        // Wraps back to core 0
        let r3 = router.route_batch(&batch).unwrap();
        assert_eq!(r3[0].0, 0);
    }

    #[test]
    fn test_column_names() {
        let mut router = PartitionedRouter::new(KeySpec::Columns(vec!["key".to_string()]), 2);
        let batch = make_batch(vec![1, 2, 3, 4]);

        let result = router.route_batch(&batch).unwrap();
        let total_rows: usize = result.iter().map(|(_, b)| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    #[test]
    fn test_column_not_found() {
        let mut router = PartitionedRouter::new(KeySpec::Columns(vec!["missing".to_string()]), 2);
        let batch = make_batch(vec![1, 2]);

        let result = router.route_batch(&batch);
        assert!(matches!(result, Err(RouterError::ColumnNotFoundByName)));
    }

    #[test]
    fn test_column_index_out_of_range() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![5]), 2);
        let batch = make_batch(vec![1, 2]);

        let result = router.route_batch(&batch);
        assert!(matches!(result, Err(RouterError::ColumnIndexOutOfRange)));
    }

    #[test]
    fn test_empty_batch() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 2);
        let batch = make_batch(vec![]);

        let result = router.route_batch(&batch);
        assert!(matches!(result, Err(RouterError::EmptyBatch)));
    }

    #[test]
    fn test_string_keys() {
        let mut router = PartitionedRouter::new(KeySpec::Columns(vec!["name".to_string()]), 4);
        let batch = make_string_batch(vec!["alice", "bob", "alice", "charlie"]);

        let result = router.route_batch(&batch).unwrap();
        let total_rows: usize = result.iter().map(|(_, b)| b.num_rows()).sum();
        assert_eq!(total_rows, 4);

        // "alice" appears twice — both should be in the same sub-batch
        for (_, sub) in &result {
            let col = sub.column(0);
            let arr = col.as_string::<i32>();
            let values: Vec<&str> = (0..arr.len()).map(|i| arr.value(i)).collect();
            if values.contains(&"alice") {
                assert_eq!(
                    values.iter().filter(|&&v| v == "alice").count(),
                    2,
                    "Both 'alice' rows should be in the same sub-batch"
                );
            }
        }
    }

    #[test]
    fn test_all_columns() {
        let mut router = PartitionedRouter::new(KeySpec::AllColumns, 2);
        let batch = make_batch(vec![1, 2, 3, 4]);

        let result = router.route_batch(&batch).unwrap();
        let total_rows: usize = result.iter().map(|(_, b)| b.num_rows()).sum();
        assert_eq!(total_rows, 4);
    }

    #[test]
    fn test_preserves_schema() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 2);
        let batch = make_batch(vec![1, 2, 3]);

        let result = router.route_batch(&batch).unwrap();
        for (_, sub) in &result {
            assert_eq!(sub.schema(), batch.schema());
        }
    }

    #[test]
    fn test_single_core_receives_whole_batch() {
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 1);
        let batch = make_batch(vec![3, 1, 2]);

        let result = router.route_batch(&batch).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, 0);
        // Row order and values are preserved.
        assert_eq!(i64_keys(&result[0].1), vec![3, 1, 2]);
    }

    #[test]
    fn test_null_keys_route_together() {
        let schema = Schema::new(vec![Field::new("key", DataType::Int64, true)]);
        let keys = Int64Array::from(vec![None, Some(1), None, Some(2)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(keys) as ArrayRef])
            .unwrap();
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 4);

        let result = router.route_batch(&batch).unwrap();
        let null_counts: Vec<usize> = result
            .iter()
            .map(|(_, sub)| sub.column(0).null_count())
            .filter(|&n| n > 0)
            .collect();
        assert_eq!(null_counts, vec![2], "Both null keys should share a sub-batch");
    }

    #[test]
    fn test_multi_column_keys() {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 1, 1, 2])) as ArrayRef,
                Arc::new(StringArray::from(vec!["a", "b", "a", "a"])) as ArrayRef,
            ],
        )
        .unwrap();
        let mut router = PartitionedRouter::new(
            KeySpec::Columns(vec!["id".to_string(), "name".to_string()]),
            8,
        );

        let result = router.route_batch(&batch).unwrap();
        let total_rows: usize = result.iter().map(|(_, b)| b.num_rows()).sum();
        assert_eq!(total_rows, 4);

        // (1, "a") appears twice — both rows must land in the same sub-batch.
        for (_, sub) in &result {
            let ids = sub.column(0).as_primitive::<Int64Type>();
            let names = sub.column(1).as_string::<i32>();
            let hits = (0..sub.num_rows())
                .filter(|&i| ids.value(i) == 1 && names.value(i) == "a")
                .count();
            assert!(hits == 0 || hits == 2, "(1, \"a\") rows were split across cores");
        }
    }

    #[test]
    fn test_unsupported_key_type() {
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2)]),
            Some(vec![Some(3)]),
        ]);
        let batch =
            RecordBatch::try_from_iter(vec![("tags", Arc::new(list) as ArrayRef)]).unwrap();
        let mut router = PartitionedRouter::new(KeySpec::ColumnIndices(vec![0]), 2);

        let result = router.route_batch(&batch);
        assert!(matches!(result, Err(RouterError::UnsupportedDataType)));
    }
}