// laminar_core/lookup/predicate.rs

//! Predicate types for lookup table query pushdown.
//!
//! These types represent filter predicates that can be pushed down to
//! [`LookupSource`](crate::lookup) implementations. Sources that support
//! predicate pushdown can filter data at the source, reducing network
//! and deserialization costs.
//!
//! ## Pushdown Flow
//!
//! 1. The SQL planner extracts WHERE predicates from a lookup join
//! 2. [`split_predicates`] classifies each predicate as pushable or local
//! 3. Pushable predicates are sent to the source via `LookupSource::query()`
//! 4. Local predicates are applied after fetching results
14
15use std::fmt;
16
/// A scalar value used in predicate evaluation.
///
/// This enum covers the value types supported by lookup table predicates.
/// It is intentionally kept small — only types that can be pushed down to
/// external sources are included.
///
/// The [`fmt::Display`] impl renders a value as SQL literal text, which is
/// what [`predicate_to_sql`] embeds into the generated WHERE fragment.
///
/// Only `PartialEq` is derived (not `Eq` or `Hash`): the `Float64` variant
/// holds an `f64`, which implements neither.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// SQL NULL
    Null,
    /// Boolean
    Bool(bool),
    /// 64-bit signed integer (covers i8/i16/i32/i64)
    Int64(i64),
    /// 64-bit float (covers f32/f64)
    Float64(f64),
    /// UTF-8 string
    Utf8(String),
    /// Raw binary data
    Binary(Vec<u8>),
    /// Timestamp as microseconds since Unix epoch
    Timestamp(i64),
}
39
40impl fmt::Display for ScalarValue {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            Self::Null => write!(f, "NULL"),
44            Self::Bool(v) => write!(f, "{v}"),
45            Self::Int64(v) => write!(f, "{v}"),
46            Self::Float64(v) => write!(f, "{v}"),
47            Self::Utf8(v) => {
48                // Escape single quotes to prevent SQL injection
49                write!(f, "'{}'", v.replace('\'', "''"))
50            }
51            Self::Binary(v) => write!(f, "X'{}'", hex_encode(v)),
52            Self::Timestamp(us) => write!(f, "TIMESTAMP '{us}'"),
53        }
54    }
55}
56
/// Returns the lowercase hexadecimal spelling of `bytes`, two digits per
/// byte and no separators. An empty slice yields an empty string.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
67
/// A filter predicate for lookup table queries.
///
/// Each variant maps directly to a SQL comparison operator. The `column`
/// field names the lookup table column, and the value(s) come from the
/// stream side of the join.
///
/// Every variant carries exactly one `column`; [`Predicate::column`] returns
/// it regardless of the operator. [`split_predicates`] decides per variant
/// whether a predicate may be pushed to the source, and
/// [`predicate_to_sql`] renders one as a SQL fragment.
#[derive(Debug, Clone, PartialEq)]
pub enum Predicate {
    /// `column = value`
    Eq {
        /// Column name
        column: String,
        /// Value to compare against
        value: ScalarValue,
    },
    /// `column != value`
    NotEq {
        /// Column name
        column: String,
        /// Value to compare against
        value: ScalarValue,
    },
    /// `column < value`
    Lt {
        /// Column name
        column: String,
        /// Value to compare against
        value: ScalarValue,
    },
    /// `column <= value`
    LtEq {
        /// Column name
        column: String,
        /// Value to compare against
        value: ScalarValue,
    },
    /// `column > value`
    Gt {
        /// Column name
        column: String,
        /// Value to compare against
        value: ScalarValue,
    },
    /// `column >= value`
    GtEq {
        /// Column name
        column: String,
        /// Value to compare against
        value: ScalarValue,
    },
    /// `column IN (values...)`
    In {
        /// Column name
        column: String,
        /// Set of values to match against
        values: Vec<ScalarValue>,
    },
    /// `column IS NULL`
    IsNull {
        /// Column name
        column: String,
    },
    /// `column IS NOT NULL`
    IsNotNull {
        /// Column name
        column: String,
    },
}
135
136impl Predicate {
137    /// Returns the column name this predicate references.
138    #[must_use]
139    pub fn column(&self) -> &str {
140        match self {
141            Self::Eq { column, .. }
142            | Self::NotEq { column, .. }
143            | Self::Lt { column, .. }
144            | Self::LtEq { column, .. }
145            | Self::Gt { column, .. }
146            | Self::GtEq { column, .. }
147            | Self::In { column, .. }
148            | Self::IsNull { column }
149            | Self::IsNotNull { column } => column,
150        }
151    }
152}
153
/// Capabilities that a lookup source declares for predicate pushdown.
///
/// Used by [`split_predicates`] to decide which predicates can be
/// pushed to the source vs. evaluated locally.
///
/// Column names are matched by exact string equality against the
/// predicate's column. The `Default` value (all lists empty,
/// `supports_null_check == false`) declares no pushdown support, so every
/// predicate is classified as local.
#[derive(Debug, Clone, Default)]
pub struct SourceCapabilities {
    /// Columns that support equality pushdown.
    ///
    /// Only `Eq` consults this list; `NotEq` is never pushed down.
    pub eq_columns: Vec<String>,
    /// Columns that support range pushdown (`Lt`, `LtEq`, `Gt`, `GtEq`).
    pub range_columns: Vec<String>,
    /// Columns that support IN-list pushdown.
    pub in_columns: Vec<String>,
    /// Whether the source supports IS NULL / IS NOT NULL pushdown.
    ///
    /// This is a single source-wide flag: it is not checked per column.
    pub supports_null_check: bool,
}
169
170/// Result of splitting predicates into pushable and local sets.
171#[derive(Debug, Clone)]
172pub struct SplitPredicates {
173    /// Predicates that can be pushed down to the source.
174    pub pushable: Vec<Predicate>,
175    /// Predicates that must be evaluated locally after fetching.
176    pub local: Vec<Predicate>,
177}
178
179/// Classify predicates as pushable or local based on source capabilities.
180///
181/// # Arguments
182///
183/// * `predicates` - The full set of predicates from the query plan
184/// * `capabilities` - What the lookup source supports
185///
186/// # Returns
187///
188/// A [`SplitPredicates`] with predicates partitioned into pushable and local.
189#[must_use]
190pub fn split_predicates(
191    predicates: Vec<Predicate>,
192    capabilities: &SourceCapabilities,
193) -> SplitPredicates {
194    let mut pushable = Vec::new();
195    let mut local = Vec::new();
196
197    for pred in predicates {
198        let can_push = match &pred {
199            Predicate::Eq { column, .. } => capabilities.eq_columns.iter().any(|c| c == column),
200            // NotEq cannot use equality indexes — a != b requires a full
201            // scan in most databases. Always evaluate locally.
202            Predicate::NotEq { .. } => false,
203            Predicate::Lt { column, .. }
204            | Predicate::LtEq { column, .. }
205            | Predicate::Gt { column, .. }
206            | Predicate::GtEq { column, .. } => {
207                capabilities.range_columns.iter().any(|c| c == column)
208            }
209            Predicate::In { column, .. } => capabilities.in_columns.iter().any(|c| c == column),
210            Predicate::IsNull { .. } | Predicate::IsNotNull { .. } => {
211                capabilities.supports_null_check
212            }
213        };
214
215        if can_push {
216            pushable.push(pred);
217        } else {
218            local.push(pred);
219        }
220    }
221
222    SplitPredicates { pushable, local }
223}
224
225/// Convert a predicate to a SQL WHERE clause fragment.
226///
227/// This is used by SQL-based lookup sources (Postgres/MySQL) to
228/// construct parameterized queries.
229///
230/// # Returns
231///
232/// A SQL string fragment like `"column = 42"` or `"column IN (1, 2, 3)"`.
233#[must_use]
234pub fn predicate_to_sql(predicate: &Predicate) -> String {
235    let q = |col: &str| col.replace('"', "\"\"");
236    match predicate {
237        Predicate::Eq { column, value } => format!("\"{}\" = {value}", q(column)),
238        Predicate::NotEq { column, value } => format!("\"{}\" != {value}", q(column)),
239        Predicate::Lt { column, value } => format!("\"{}\" < {value}", q(column)),
240        Predicate::LtEq { column, value } => format!("\"{}\" <= {value}", q(column)),
241        Predicate::Gt { column, value } => format!("\"{}\" > {value}", q(column)),
242        Predicate::GtEq { column, value } => format!("\"{}\" >= {value}", q(column)),
243        Predicate::In { column, values } => {
244            let vals: Vec<String> = values.iter().map(ToString::to_string).collect();
245            format!("\"{}\" IN ({})", q(column), vals.join(", "))
246        }
247        Predicate::IsNull { column } => format!("\"{}\" IS NULL", q(column)),
248        Predicate::IsNotNull { column } => format!("\"{}\" IS NOT NULL", q(column)),
249    }
250}
251
// Unit tests for literal rendering, predicate classification, and SQL
// fragment generation.
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` output for the null, bool, integer, float, string, and
    /// binary variants.
    #[test]
    fn test_scalar_value_display() {
        assert_eq!(ScalarValue::Null.to_string(), "NULL");
        assert_eq!(ScalarValue::Bool(true).to_string(), "true");
        assert_eq!(ScalarValue::Int64(42).to_string(), "42");
        assert_eq!(ScalarValue::Float64(1.23).to_string(), "1.23");
        assert_eq!(ScalarValue::Utf8("hello".into()).to_string(), "'hello'");
        assert_eq!(ScalarValue::Binary(vec![0xDE, 0xAD]).to_string(), "X'dead'");
    }

    /// `Predicate::column` returns the column for a variant with a value
    /// and for one without.
    #[test]
    fn test_predicate_column() {
        let pred = Predicate::Eq {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        };
        assert_eq!(pred.column(), "id");

        let pred = Predicate::IsNull {
            column: "name".into(),
        };
        assert_eq!(pred.column(), "name");
    }

    /// SQL fragments for `Eq`, `In`, `Gt` on a reserved-word column, and
    /// `IsNull`; the column is always double-quoted.
    #[test]
    fn test_predicate_to_sql() {
        assert_eq!(
            predicate_to_sql(&Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(42),
            }),
            "\"id\" = 42"
        );

        assert_eq!(
            predicate_to_sql(&Predicate::In {
                column: "status".into(),
                values: vec![
                    ScalarValue::Utf8("active".into()),
                    ScalarValue::Utf8("pending".into()),
                ],
            }),
            "\"status\" IN ('active', 'pending')"
        );

        // Reserved word column name
        assert_eq!(
            predicate_to_sql(&Predicate::Gt {
                column: "order".into(),
                value: ScalarValue::Int64(10),
            }),
            "\"order\" > 10"
        );

        assert_eq!(
            predicate_to_sql(&Predicate::IsNull {
                column: "deleted_at".into(),
            }),
            "\"deleted_at\" IS NULL"
        );
    }

    /// Mixed predicate list against a source with equality, range, and
    /// IN-list columns but no null-check support.
    #[test]
    fn test_split_predicates() {
        let capabilities = SourceCapabilities {
            eq_columns: vec!["id".into(), "name".into()],
            range_columns: vec!["created_at".into()],
            in_columns: vec!["status".into()],
            supports_null_check: false,
        };

        let predicates = vec![
            Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::Gt {
                column: "created_at".into(),
                value: ScalarValue::Timestamp(1_000_000),
            },
            Predicate::IsNull {
                column: "deleted_at".into(),
            },
            Predicate::In {
                column: "status".into(),
                values: vec![ScalarValue::Utf8("active".into())],
            },
            // This Eq is on a non-pushable column
            Predicate::Eq {
                column: "region".into(),
                value: ScalarValue::Utf8("us-east".into()),
            },
        ];

        let split = split_predicates(predicates, &capabilities);
        assert_eq!(split.pushable.len(), 3); // id=, created_at>, status IN
        assert_eq!(split.local.len(), 2); // IS NULL (no null support), region=
    }

    /// String literals double embedded single quotes and leave double
    /// quotes untouched.
    #[test]
    fn test_scalar_value_display_escapes_single_quotes() {
        // SQL injection vector: O'Brien must become O''Brien
        assert_eq!(
            ScalarValue::Utf8("O'Brien".into()).to_string(),
            "'O''Brien'"
        );
        // Double quotes are not special in SQL string literals
        assert_eq!(
            ScalarValue::Utf8(r#"say "hello""#.into()).to_string(),
            r#"'say "hello"'"#
        );
        // Multiple consecutive single quotes
        assert_eq!(ScalarValue::Utf8("it''s".into()).to_string(), "'it''''s'");
        // Empty string
        assert_eq!(ScalarValue::Utf8(String::new()).to_string(), "''");
    }

    /// `NotEq` stays local even when its column is listed in `eq_columns`.
    #[test]
    fn test_not_eq_never_pushed_down() {
        let capabilities = SourceCapabilities {
            eq_columns: vec!["id".into()],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };

        let predicates = vec![
            Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::NotEq {
                column: "id".into(),
                value: ScalarValue::Int64(2),
            },
        ];

        let split = split_predicates(predicates, &capabilities);
        // Eq should be pushed, NotEq should stay local
        assert_eq!(split.pushable.len(), 1);
        assert!(matches!(&split.pushable[0], Predicate::Eq { .. }));
        assert_eq!(split.local.len(), 1);
        assert!(matches!(&split.local[0], Predicate::NotEq { .. }));
    }

    /// Default capabilities declare nothing, so every predicate is local.
    #[test]
    fn test_split_predicates_empty_capabilities() {
        let capabilities = SourceCapabilities::default();
        let predicates = vec![Predicate::Eq {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        }];

        let split = split_predicates(predicates, &capabilities);
        assert!(split.pushable.is_empty());
        assert_eq!(split.local.len(), 1);
    }
}