laminar_core/lookup/
align.rs1use std::sync::Arc;
12
13use arrow::row::{RowConverter, SortField};
14use arrow_array::{ArrayRef, RecordBatch};
15use rustc_hash::FxHashMap;
16
17use crate::lookup::source::LookupError;
18
19pub struct KeyAligner {
21 converter: RowConverter,
22 pk_columns: Vec<String>,
23}
24
25impl KeyAligner {
26 pub fn new(
34 pk_sort_fields: Vec<SortField>,
35 pk_columns: Vec<String>,
36 ) -> Result<Self, LookupError> {
37 if pk_columns.is_empty() {
38 return Err(LookupError::Internal(
39 "primary_key_columns must not be empty".into(),
40 ));
41 }
42 let converter = RowConverter::new(pk_sort_fields)
43 .map_err(|e| LookupError::Internal(format!("row converter: {e}")))?;
44 Ok(Self {
45 converter,
46 pk_columns,
47 })
48 }
49
50 #[must_use]
52 pub fn pk_columns(&self) -> &[String] {
53 &self.pk_columns
54 }
55
56 pub fn decode_keys(&self, keys: &[&[u8]]) -> Result<Vec<ArrayRef>, LookupError> {
64 let parser = self.converter.parser();
65 let parsed = keys.iter().map(|k| parser.parse(k));
66 self.converter
67 .convert_rows(parsed)
68 .map_err(|e| LookupError::Internal(format!("decode keys: {e}")))
69 }
70
71 pub fn align(
81 &self,
82 keys: &[&[u8]],
83 fetched: &[RecordBatch],
84 ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
85 let mut index: FxHashMap<Vec<u8>, (usize, usize)> = FxHashMap::default();
86 for (batch_idx, batch) in fetched.iter().enumerate() {
87 if batch.num_rows() == 0 {
88 continue;
89 }
90 let pk_cols = self
91 .pk_columns
92 .iter()
93 .map(|name| {
94 let idx = batch.schema().index_of(name).map_err(|_| {
95 LookupError::Internal(format!("pk column not found in result: {name}"))
96 })?;
97 Ok(Arc::clone(batch.column(idx)))
98 })
99 .collect::<Result<Vec<ArrayRef>, LookupError>>()?;
100 let rows = self
101 .converter
102 .convert_columns(&pk_cols)
103 .map_err(|e| LookupError::Internal(format!("encode result keys: {e}")))?;
104 for row in 0..batch.num_rows() {
105 index
106 .entry(rows.row(row).as_ref().to_vec())
107 .or_insert((batch_idx, row));
108 }
109 }
110 Ok(keys
111 .iter()
112 .map(|key| index.get(*key).map(|&(bi, row)| fetched[bi].slice(row, 1)))
113 .collect())
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    /// Aligner keyed on a single `Int64` column named `id`.
    fn aligner() -> KeyAligner {
        KeyAligner::new(vec![SortField::new(DataType::Int64)], vec!["id".into()]).unwrap()
    }

    /// Single-column batch whose non-null `Int64` column is called `name`.
    fn batch_named(name: &str, ids: &[i64]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int64, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(ids.to_vec()))]).unwrap()
    }

    /// Single-column batch holding `ids` in the `id` column.
    fn batch(ids: &[i64]) -> RecordBatch {
        batch_named("id", ids)
    }

    /// Row-encodes `ids` with the same sort field the aligner uses.
    fn encode(ids: &[i64]) -> Vec<Vec<u8>> {
        let conv = RowConverter::new(vec![SortField::new(DataType::Int64)]).unwrap();
        let rows = conv
            .convert_columns(&[Arc::new(Int64Array::from(ids.to_vec()))])
            .unwrap();
        (0..ids.len())
            .map(|i| rows.row(i).as_ref().to_vec())
            .collect()
    }

    /// Reads the `id` of a single-row aligned result; `None` for a miss.
    fn id(b: &Option<RecordBatch>) -> Option<i64> {
        b.as_ref().map(|b| {
            b.column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap()
                .value(0)
        })
    }

    #[test]
    fn aligns_out_of_order_with_misses_and_dups() {
        let aligner = aligner();
        let fetched = vec![batch(&[2, 5])];
        // Requested order differs from fetch order, 99 was never fetched,
        // and 2 is requested twice.
        let keys = encode(&[5, 2, 99, 2]);
        let key_refs: Vec<&[u8]> = keys.iter().map(Vec::as_slice).collect();

        let out = aligner.align(&key_refs, &fetched).unwrap();
        assert_eq!(out.len(), key_refs.len());
        assert_eq!(id(&out[0]), Some(5));
        assert_eq!(id(&out[1]), Some(2));
        assert_eq!(id(&out[2]), None);
        assert_eq!(id(&out[3]), Some(2));
        // Every hit is a single-row slice of the fetched batch.
        assert!(out.iter().flatten().all(|b| b.num_rows() == 1));
    }

    #[test]
    fn searches_every_batch_and_skips_empty_ones() {
        let aligner = aligner();
        let fetched = vec![batch(&[]), batch(&[1]), batch(&[3])];
        let keys = encode(&[3, 1]);
        let key_refs: Vec<&[u8]> = keys.iter().map(Vec::as_slice).collect();

        let out = aligner.align(&key_refs, &fetched).unwrap();
        assert_eq!(id(&out[0]), Some(3));
        assert_eq!(id(&out[1]), Some(1));
    }

    #[test]
    fn empty_key_list_yields_empty_output() {
        let out = aligner().align(&[], &[batch(&[1])]).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn rejects_empty_pk_columns() {
        let result = KeyAligner::new(vec![SortField::new(DataType::Int64)], Vec::new());
        assert!(matches!(result, Err(LookupError::Internal(_))));
    }

    #[test]
    fn errors_when_result_lacks_pk_column() {
        let aligner = aligner();
        let fetched = vec![batch_named("other", &[1])];
        let keys = encode(&[1]);
        let key_refs: Vec<&[u8]> = keys.iter().map(Vec::as_slice).collect();

        let result = aligner.align(&key_refs, &fetched);
        assert!(matches!(result, Err(LookupError::Internal(_))));
    }

    #[test]
    fn decode_round_trips_to_pk_columns() {
        let aligner = aligner();
        let keys = encode(&[7, 8]);
        let key_refs: Vec<&[u8]> = keys.iter().map(Vec::as_slice).collect();
        let cols = aligner.decode_keys(&key_refs).unwrap();
        assert_eq!(cols.len(), 1);
        let ids = cols[0].as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(ids.values(), &[7, 8]);
    }
}