Skip to main content

laminar_core/lookup/
predicate.rs

//! Predicate types for lookup source pushdown.
2
3use std::fmt;
4
/// A scalar value used in predicate evaluation.
///
/// This enum covers the value types supported by lookup table predicates.
/// It is intentionally kept small — only types that can be pushed down to
/// external sources are included.
///
/// The [`fmt::Display`] impl renders each value as a SQL literal that can be
/// inlined into a query fragment.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// SQL NULL
    Null,
    /// Boolean
    Bool(bool),
    /// 64-bit signed integer (covers i8/i16/i32/i64)
    Int64(i64),
    /// 64-bit float (covers f32/f64)
    Float64(f64),
    /// UTF-8 string
    Utf8(String),
    /// Raw binary data
    Binary(Vec<u8>),
    /// Timestamp as microseconds since Unix epoch (rendered as UTC)
    Timestamp(i64),
}

impl fmt::Display for ScalarValue {
    /// Render the value as a SQL literal.
    ///
    /// - `Null` → `NULL`; `Bool` / `Int64` → Rust's own textual form.
    /// - `Float64` → Rust's decimal form when finite; `'NaN'`, `'Infinity'`
    ///   or `'-Infinity'` (quoted, PostgreSQL spellings) otherwise, because
    ///   Rust would print the bare words `NaN` / `inf` / `-inf`, which SQL
    ///   parses as identifiers.
    /// - `Utf8` → single-quoted, with each embedded `'` doubled.
    /// - `Binary` → `X'<lowercase hex>'`.
    /// - `Timestamp` → `TIMESTAMP 'YYYY-MM-DD HH:MM:SS[.ffffff]'` in UTC.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Null => write!(f, "NULL"),
            Self::Bool(v) => write!(f, "{v}"),
            Self::Int64(v) => write!(f, "{v}"),
            Self::Float64(v) if v.is_nan() => f.write_str("'NaN'"),
            Self::Float64(v) if v.is_infinite() => {
                f.write_str(if *v > 0.0 { "'Infinity'" } else { "'-Infinity'" })
            }
            Self::Float64(v) => write!(f, "{v}"),
            Self::Utf8(v) => {
                // Escape single quotes to prevent SQL injection.
                // NOTE(review): backslashes are passed through unchanged. That is
                // correct for standard-conforming SQL (PostgreSQL's default), but
                // MySQL treats `\` as an escape character in string literals
                // unless NO_BACKSLASH_ESCAPES is set, so a value ending in `\`
                // can break out of the literal there. Confirm the MySQL session
                // mode before relying on this escaping.
                write!(f, "'{}'", v.replace('\'', "''"))
            }
            Self::Binary(v) => write!(f, "X'{}'", hex_encode(v)),
            Self::Timestamp(us) => write_timestamp_literal(f, *us),
        }
    }
}

/// Write `micros` (microseconds since the Unix epoch) as a SQL
/// `TIMESTAMP 'YYYY-MM-DD HH:MM:SS[.ffffff]'` literal, in UTC.
///
/// The fractional-seconds part is emitted only when non-zero. Pre-epoch
/// instants use floor division, so the time of day is never negative.
/// NOTE(review): years outside 0000..=9999 are still formatted, but most
/// SQL dialects reject them in timestamp literals.
fn write_timestamp_literal(f: &mut fmt::Formatter<'_>, micros: i64) -> fmt::Result {
    const MICROS_PER_SECOND: i64 = 1_000_000;
    const MICROS_PER_DAY: i64 = 86_400 * MICROS_PER_SECOND;

    let days = micros.div_euclid(MICROS_PER_DAY);
    let micros_of_day = micros.rem_euclid(MICROS_PER_DAY);
    let (year, month, day) = civil_from_days(days);

    let secs_of_day = micros_of_day / MICROS_PER_SECOND;
    let fraction = micros_of_day % MICROS_PER_SECOND;
    let hour = secs_of_day / 3_600;
    let minute = secs_of_day / 60 % 60;
    let second = secs_of_day % 60;

    write!(
        f,
        "TIMESTAMP '{year:04}-{month:02}-{day:02} {hour:02}:{minute:02}:{second:02}"
    )?;
    if fraction != 0 {
        write!(f, ".{fraction:06}")?;
    }
    f.write_str("'")
}

/// Convert a day count relative to 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)`, with month in 1..=12 and day in 1..=31.
///
/// This is Howard Hinnant's `civil_from_days` algorithm. It works on
/// 400-year eras of 146 097 days, with years starting on March 1 so that
/// the leap day falls at the end of the year.
fn civil_from_days(days: i64) -> (i64, i64, i64) {
    let z = days + 719_468; // shift the epoch to 0000-03-01
    let era = z.div_euclid(146_097);
    let day_of_era = z.rem_euclid(146_097); // [0, 146096]
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365; // [0, 399]
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // [0, 365]
    let month_from_march = (5 * day_of_year + 2) / 153; // [0, 11]
    let day = day_of_year - (153 * month_from_march + 2) / 5 + 1; // [1, 31]
    let month = if month_from_march < 10 {
        month_from_march + 3
    } else {
        month_from_march - 9
    }; // [1, 12]
    let year = year_of_era + era * 400 + i64::from(month <= 2);
    (year, month, day)
}

/// Encode bytes as lowercase hex string.
fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut s, b| {
            // Writing into a `String` cannot fail, so the result is discarded.
            let _ = write!(s, "{b:02x}");
            s
        })
}
55
56/// A filter predicate for lookup table queries.
57///
58/// Each variant maps directly to a SQL comparison operator. The `column`
59/// field names the lookup table column, and the value(s) come from the
60/// stream side of the join.
61#[derive(Debug, Clone, PartialEq)]
62pub enum Predicate {
63    /// `column = value`
64    Eq {
65        /// Column name
66        column: String,
67        /// Value to compare against
68        value: ScalarValue,
69    },
70    /// `column != value`
71    NotEq {
72        /// Column name
73        column: String,
74        /// Value to compare against
75        value: ScalarValue,
76    },
77    /// `column < value`
78    Lt {
79        /// Column name
80        column: String,
81        /// Value to compare against
82        value: ScalarValue,
83    },
84    /// `column <= value`
85    LtEq {
86        /// Column name
87        column: String,
88        /// Value to compare against
89        value: ScalarValue,
90    },
91    /// `column > value`
92    Gt {
93        /// Column name
94        column: String,
95        /// Value to compare against
96        value: ScalarValue,
97    },
98    /// `column >= value`
99    GtEq {
100        /// Column name
101        column: String,
102        /// Value to compare against
103        value: ScalarValue,
104    },
105    /// `column IN (values...)`
106    In {
107        /// Column name
108        column: String,
109        /// Set of values to match against
110        values: Vec<ScalarValue>,
111    },
112    /// `column IS NULL`
113    IsNull {
114        /// Column name
115        column: String,
116    },
117    /// `column IS NOT NULL`
118    IsNotNull {
119        /// Column name
120        column: String,
121    },
122}
123
124impl Predicate {
125    /// Returns the column name this predicate references.
126    #[must_use]
127    pub fn column(&self) -> &str {
128        match self {
129            Self::Eq { column, .. }
130            | Self::NotEq { column, .. }
131            | Self::Lt { column, .. }
132            | Self::LtEq { column, .. }
133            | Self::Gt { column, .. }
134            | Self::GtEq { column, .. }
135            | Self::In { column, .. }
136            | Self::IsNull { column }
137            | Self::IsNotNull { column } => column,
138        }
139    }
140}
141
/// Predicate-pushdown abilities advertised by a lookup source.
///
/// [`split_predicates`] checks each predicate against these fields. It uses
/// them to decide whether the predicate is sent to the source or kept for
/// local evaluation. The `Default` value advertises nothing (empty column
/// lists, no null checks), so every predicate stays local.
#[derive(Clone, Debug, Default)]
pub struct SourceCapabilities {
    /// Columns on which `column = value` can be pushed down.
    pub eq_columns: Vec<String>,
    /// Columns on which range comparisons (`Lt`, `LtEq`, `Gt`, `GtEq`) can be pushed down.
    pub range_columns: Vec<String>,
    /// Columns on which `column IN (...)` can be pushed down.
    pub in_columns: Vec<String>,
    /// `true` when the source can evaluate `IS NULL` / `IS NOT NULL` itself.
    pub supports_null_check: bool,
}
157
158/// Result of splitting predicates into pushable and local sets.
159#[derive(Debug, Clone)]
160pub struct SplitPredicates {
161    /// Predicates that can be pushed down to the source.
162    pub pushable: Vec<Predicate>,
163    /// Predicates that must be evaluated locally after fetching.
164    pub local: Vec<Predicate>,
165}
166
167/// Classify predicates as pushable or local based on source capabilities.
168///
169/// # Arguments
170///
171/// * `predicates` - The full set of predicates from the query plan
172/// * `capabilities` - What the lookup source supports
173///
174/// # Returns
175///
176/// A [`SplitPredicates`] with predicates partitioned into pushable and local.
177#[must_use]
178pub fn split_predicates(
179    predicates: Vec<Predicate>,
180    capabilities: &SourceCapabilities,
181) -> SplitPredicates {
182    let mut pushable = Vec::new();
183    let mut local = Vec::new();
184
185    for pred in predicates {
186        let can_push = match &pred {
187            Predicate::Eq { column, .. } => capabilities.eq_columns.iter().any(|c| c == column),
188            // NotEq cannot use equality indexes — a != b requires a full
189            // scan in most databases. Always evaluate locally.
190            Predicate::NotEq { .. } => false,
191            Predicate::Lt { column, .. }
192            | Predicate::LtEq { column, .. }
193            | Predicate::Gt { column, .. }
194            | Predicate::GtEq { column, .. } => {
195                capabilities.range_columns.iter().any(|c| c == column)
196            }
197            Predicate::In { column, .. } => capabilities.in_columns.iter().any(|c| c == column),
198            Predicate::IsNull { .. } | Predicate::IsNotNull { .. } => {
199                capabilities.supports_null_check
200            }
201        };
202
203        if can_push {
204            pushable.push(pred);
205        } else {
206            local.push(pred);
207        }
208    }
209
210    SplitPredicates { pushable, local }
211}
212
213/// Convert a predicate to a SQL WHERE clause fragment.
214///
215/// This is used by SQL-based lookup sources (Postgres/MySQL) to
216/// construct parameterized queries.
217///
218/// # Returns
219///
220/// A SQL string fragment like `"column = 42"` or `"column IN (1, 2, 3)"`.
221#[must_use]
222pub fn predicate_to_sql(predicate: &Predicate) -> String {
223    let q = |col: &str| col.replace('"', "\"\"");
224    match predicate {
225        Predicate::Eq { column, value } => format!("\"{}\" = {value}", q(column)),
226        Predicate::NotEq { column, value } => format!("\"{}\" != {value}", q(column)),
227        Predicate::Lt { column, value } => format!("\"{}\" < {value}", q(column)),
228        Predicate::LtEq { column, value } => format!("\"{}\" <= {value}", q(column)),
229        Predicate::Gt { column, value } => format!("\"{}\" > {value}", q(column)),
230        Predicate::GtEq { column, value } => format!("\"{}\" >= {value}", q(column)),
231        Predicate::In { column, values } => {
232            let vals: Vec<String> = values.iter().map(ToString::to_string).collect();
233            format!("\"{}\" IN ({})", q(column), vals.join(", "))
234        }
235        Predicate::IsNull { column } => format!("\"{}\" IS NULL", q(column)),
236        Predicate::IsNotNull { column } => format!("\"{}\" IS NOT NULL", q(column)),
237    }
238}
239
#[cfg(test)]
mod tests {
    // Unit tests for value rendering, predicate SQL generation, and pushdown splitting.
    use super::*;

    #[test]
    fn test_scalar_value_display() {
        assert_eq!(ScalarValue::Null.to_string(), "NULL");
        assert_eq!(ScalarValue::Bool(true).to_string(), "true");
        assert_eq!(ScalarValue::Int64(42).to_string(), "42");
        assert_eq!(ScalarValue::Float64(1.23).to_string(), "1.23");
        assert_eq!(ScalarValue::Utf8("hello".into()).to_string(), "'hello'");
        assert_eq!(ScalarValue::Binary(vec![0xDE, 0xAD]).to_string(), "X'dead'");
    }

    #[test]
    fn test_predicate_column() {
        let pred = Predicate::Eq {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        };
        assert_eq!(pred.column(), "id");

        let pred = Predicate::IsNull {
            column: "name".into(),
        };
        assert_eq!(pred.column(), "name");

        let pred = Predicate::In {
            column: "status".into(),
            values: vec![],
        };
        assert_eq!(pred.column(), "status");
    }

    #[test]
    fn test_predicate_to_sql() {
        assert_eq!(
            predicate_to_sql(&Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(42),
            }),
            "\"id\" = 42"
        );

        assert_eq!(
            predicate_to_sql(&Predicate::In {
                column: "status".into(),
                values: vec![
                    ScalarValue::Utf8("active".into()),
                    ScalarValue::Utf8("pending".into()),
                ],
            }),
            "\"status\" IN ('active', 'pending')"
        );

        // Reserved word column name
        assert_eq!(
            predicate_to_sql(&Predicate::Gt {
                column: "order".into(),
                value: ScalarValue::Int64(10),
            }),
            "\"order\" > 10"
        );

        assert_eq!(
            predicate_to_sql(&Predicate::IsNull {
                column: "deleted_at".into(),
            }),
            "\"deleted_at\" IS NULL"
        );
    }

    #[test]
    fn test_predicate_to_sql_remaining_operators() {
        let one = || ScalarValue::Int64(1);
        assert_eq!(
            predicate_to_sql(&Predicate::NotEq {
                column: "v".into(),
                value: one(),
            }),
            "\"v\" != 1"
        );
        assert_eq!(
            predicate_to_sql(&Predicate::Lt {
                column: "v".into(),
                value: one(),
            }),
            "\"v\" < 1"
        );
        assert_eq!(
            predicate_to_sql(&Predicate::LtEq {
                column: "v".into(),
                value: one(),
            }),
            "\"v\" <= 1"
        );
        assert_eq!(
            predicate_to_sql(&Predicate::GtEq {
                column: "v".into(),
                value: one(),
            }),
            "\"v\" >= 1"
        );
        assert_eq!(
            predicate_to_sql(&Predicate::IsNotNull { column: "v".into() }),
            "\"v\" IS NOT NULL"
        );
    }

    #[test]
    fn test_predicate_to_sql_escapes_identifiers_and_values() {
        // An embedded double quote in a column name is doubled inside the quoted identifier.
        assert_eq!(
            predicate_to_sql(&Predicate::Eq {
                column: r#"we"ird"#.into(),
                value: ScalarValue::Int64(1),
            }),
            r#""we""ird" = 1"#
        );
        // String operands go through ScalarValue's single-quote escaping.
        assert_eq!(
            predicate_to_sql(&Predicate::Eq {
                column: "name".into(),
                value: ScalarValue::Utf8("x' OR '1'='1".into()),
            }),
            r#""name" = 'x'' OR ''1''=''1'"#
        );
    }

    #[test]
    fn test_split_predicates() {
        let capabilities = SourceCapabilities {
            eq_columns: vec!["id".into(), "name".into()],
            range_columns: vec!["created_at".into()],
            in_columns: vec!["status".into()],
            supports_null_check: false,
        };

        let predicates = vec![
            Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::Gt {
                column: "created_at".into(),
                value: ScalarValue::Timestamp(1_000_000),
            },
            Predicate::IsNull {
                column: "deleted_at".into(),
            },
            Predicate::In {
                column: "status".into(),
                values: vec![ScalarValue::Utf8("active".into())],
            },
            // This Eq is on a non-pushable column
            Predicate::Eq {
                column: "region".into(),
                value: ScalarValue::Utf8("us-east".into()),
            },
        ];

        let split = split_predicates(predicates, &capabilities);
        assert_eq!(split.pushable.len(), 3); // id=, created_at>, status IN
        assert_eq!(split.local.len(), 2); // IS NULL (no null support), region=
    }

    #[test]
    fn test_split_predicates_null_checks_and_operator_specific_columns() {
        let capabilities = SourceCapabilities {
            eq_columns: vec!["id".into()],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: true,
        };

        let predicates = vec![
            Predicate::IsNull {
                column: "deleted_at".into(),
            },
            Predicate::IsNotNull {
                column: "name".into(),
            },
            // `id` supports equality only, so range and IN on it stay local.
            Predicate::Lt {
                column: "id".into(),
                value: ScalarValue::Int64(5),
            },
            Predicate::In {
                column: "id".into(),
                values: vec![ScalarValue::Int64(1)],
            },
        ];

        let split = split_predicates(predicates, &capabilities);
        assert_eq!(
            split.pushable,
            vec![
                Predicate::IsNull {
                    column: "deleted_at".into(),
                },
                Predicate::IsNotNull {
                    column: "name".into(),
                },
            ]
        );
        assert_eq!(split.local.len(), 2);
        assert!(matches!(&split.local[0], Predicate::Lt { .. }));
        assert!(matches!(&split.local[1], Predicate::In { .. }));
    }

    #[test]
    fn test_scalar_value_display_escapes_single_quotes() {
        // SQL injection vector: O'Brien must become O''Brien
        assert_eq!(
            ScalarValue::Utf8("O'Brien".into()).to_string(),
            "'O''Brien'"
        );
        // Double quotes are not special in SQL string literals
        assert_eq!(
            ScalarValue::Utf8(r#"say "hello""#.into()).to_string(),
            r#"'say "hello"'"#
        );
        // Multiple consecutive single quotes
        assert_eq!(ScalarValue::Utf8("it''s".into()).to_string(), "'it''''s'");
        // Empty string
        assert_eq!(ScalarValue::Utf8(String::new()).to_string(), "''");
    }

    #[test]
    fn test_not_eq_never_pushed_down() {
        let capabilities = SourceCapabilities {
            eq_columns: vec!["id".into()],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };

        let predicates = vec![
            Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::NotEq {
                column: "id".into(),
                value: ScalarValue::Int64(2),
            },
        ];

        let split = split_predicates(predicates, &capabilities);
        // Eq should be pushed, NotEq should stay local
        assert_eq!(split.pushable.len(), 1);
        assert!(matches!(&split.pushable[0], Predicate::Eq { .. }));
        assert_eq!(split.local.len(), 1);
        assert!(matches!(&split.local[0], Predicate::NotEq { .. }));
    }

    #[test]
    fn test_split_predicates_empty_capabilities() {
        let capabilities = SourceCapabilities::default();
        let predicates = vec![Predicate::Eq {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        }];

        let split = split_predicates(predicates, &capabilities);
        assert!(split.pushable.is_empty());
        assert_eq!(split.local.len(), 1);
    }
}