Skip to main content

laminar_core/lookup/
align.rs

1//! Shared key decoding + result realignment for on-demand lookup sources.
2//!
3//! Every [`LookupSource`](crate::lookup::source::LookupSource) has the same
4//! shape around its backend-specific fetch: the operator hands keys as
5//! `RowConverter`-encoded bytes, the source fetches matching rows in arbitrary
6//! order, and each fetched row is matched back to its key by re-encoding its
7//! primary-key columns with the *same* converter. [`KeyAligner`] owns that
8//! converter and performs the decode + realign so each source only writes its
9//! fetch.
10
11use std::sync::Arc;
12
13use arrow::row::{RowConverter, SortField};
14use arrow_array::{ArrayRef, RecordBatch};
15use rustc_hash::FxHashMap;
16
17use crate::lookup::source::LookupError;
18
19/// Decodes opaque lookup keys and realigns fetched rows to the input key order.
20pub struct KeyAligner {
21    converter: RowConverter,
22    pk_columns: Vec<String>,
23}
24
25impl KeyAligner {
26    /// Build an aligner from the primary-key sort fields and column names
27    /// (same order).
28    ///
29    /// # Errors
30    ///
31    /// Returns `LookupError::Internal` if the key is empty or the converter
32    /// cannot be built for the given sort fields.
33    pub fn new(
34        pk_sort_fields: Vec<SortField>,
35        pk_columns: Vec<String>,
36    ) -> Result<Self, LookupError> {
37        if pk_columns.is_empty() {
38            return Err(LookupError::Internal(
39                "primary_key_columns must not be empty".into(),
40            ));
41        }
42        let converter = RowConverter::new(pk_sort_fields)
43            .map_err(|e| LookupError::Internal(format!("row converter: {e}")))?;
44        Ok(Self {
45            converter,
46            pk_columns,
47        })
48    }
49
50    /// The primary-key column names, in key order.
51    #[must_use]
52    pub fn pk_columns(&self) -> &[String] {
53        &self.pk_columns
54    }
55
56    /// Decode opaque key bytes into the primary-key columns (column-major: one
57    /// array per PK column, each `keys.len()` long) for building a source
58    /// filter.
59    ///
60    /// # Errors
61    ///
62    /// Returns `LookupError::Internal` if the bytes cannot be decoded.
63    pub fn decode_keys(&self, keys: &[&[u8]]) -> Result<Vec<ArrayRef>, LookupError> {
64        let parser = self.converter.parser();
65        let parsed = keys.iter().map(|k| parser.parse(k));
66        self.converter
67            .convert_rows(parsed)
68            .map_err(|e| LookupError::Internal(format!("decode keys: {e}")))
69    }
70
71    /// Realign fetched rows to the input key order. Each fetched row is matched
72    /// to its key by re-encoding its PK columns; the first row wins per key,
73    /// duplicate input keys each resolve to their own single-row slice, and
74    /// misses are `None`.
75    ///
76    /// # Errors
77    ///
78    /// Returns `LookupError::Internal` if a PK column is absent from a fetched
79    /// batch or cannot be re-encoded.
80    pub fn align(
81        &self,
82        keys: &[&[u8]],
83        fetched: &[RecordBatch],
84    ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
85        let mut index: FxHashMap<Vec<u8>, (usize, usize)> = FxHashMap::default();
86        for (batch_idx, batch) in fetched.iter().enumerate() {
87            if batch.num_rows() == 0 {
88                continue;
89            }
90            let pk_cols = self
91                .pk_columns
92                .iter()
93                .map(|name| {
94                    let idx = batch.schema().index_of(name).map_err(|_| {
95                        LookupError::Internal(format!("pk column not found in result: {name}"))
96                    })?;
97                    Ok(Arc::clone(batch.column(idx)))
98                })
99                .collect::<Result<Vec<ArrayRef>, LookupError>>()?;
100            let rows = self
101                .converter
102                .convert_columns(&pk_cols)
103                .map_err(|e| LookupError::Internal(format!("encode result keys: {e}")))?;
104            for row in 0..batch.num_rows() {
105                index
106                    .entry(rows.row(row).as_ref().to_vec())
107                    .or_insert((batch_idx, row));
108            }
109        }
110        Ok(keys
111            .iter()
112            .map(|key| index.get(*key).map(|&(bi, row)| fetched[bi].slice(row, 1)))
113            .collect())
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    /// Sort fields for the single `Int64` primary key used by every test.
    fn int64_key() -> Vec<SortField> {
        vec![SortField::new(DataType::Int64)]
    }

    /// Aligner over a single `Int64` primary-key column named `id`.
    fn aligner() -> KeyAligner {
        KeyAligner::new(int64_key(), vec!["id".into()]).unwrap()
    }

    /// One-column, non-nullable `id: Int64` batch holding `ids`.
    fn batch(ids: &[i64]) -> RecordBatch {
        let id_field = Field::new("id", DataType::Int64, false);
        let column: ArrayRef = Arc::new(Int64Array::from(ids.to_vec()));
        RecordBatch::try_new(Arc::new(Schema::new(vec![id_field])), vec![column]).unwrap()
    }

    /// Row-encode each id with an independent converter built from the same
    /// sort fields, the way an operator would hand keys to the source.
    fn encode(ids: &[i64]) -> Vec<Vec<u8>> {
        let converter = RowConverter::new(int64_key()).unwrap();
        let column: ArrayRef = Arc::new(Int64Array::from(ids.to_vec()));
        let encoded = converter.convert_columns(&[column]).unwrap();
        let mut out = Vec::with_capacity(ids.len());
        for i in 0..ids.len() {
            out.push(encoded.row(i).as_ref().to_vec());
        }
        out
    }

    /// The `id` value of an aligned slot, or `None` for a miss.
    fn first_id(slot: &Option<RecordBatch>) -> Option<i64> {
        let hit = slot.as_ref()?;
        let ids = hit
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        Some(ids.value(0))
    }

    #[test]
    fn aligns_out_of_order_with_misses_and_dups() {
        let aligner = aligner();
        // The source returns rows 2 then 5, but the keys ask for 5 first, then
        // 2, then 99 (never fetched), then 2 a second time.
        let fetched = vec![batch(&[2, 5])];
        let keys = encode(&[5, 2, 99, 2]);
        let key_refs: Vec<&[u8]> = keys.iter().map(Vec::as_slice).collect();

        let aligned = aligner.align(&key_refs, &fetched).unwrap();
        let got: Vec<Option<i64>> = aligned.iter().map(first_id).collect();

        // Positions: hit, hit, miss, repeated key resolves again.
        assert_eq!(got, vec![Some(5), Some(2), None, Some(2)]);
    }

    #[test]
    fn decode_round_trips_to_pk_columns() {
        let keys = encode(&[7, 8]);
        let key_refs: Vec<&[u8]> = keys.iter().map(Vec::as_slice).collect();

        let decoded = aligner().decode_keys(&key_refs).unwrap();

        let ids = decoded[0].as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(ids.values(), &[7, 8]);
    }
}