1use std::fmt;
16
/// A literal value that a [`Predicate`] compares a column against.
///
/// The [`fmt::Display`] implementation renders the value as a SQL literal.
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// SQL `NULL`.
    Null,
    /// Boolean, rendered as `true` or `false`.
    Bool(bool),
    /// Signed 64-bit integer.
    Int64(i64),
    /// 64-bit float. `NaN` and the infinities are rendered as casts from
    /// string literals, because they have no SQL numeric-literal spelling.
    Float64(f64),
    /// UTF-8 text, rendered as a single-quoted literal with `'` doubled.
    Utf8(String),
    /// Raw bytes, rendered as a hex blob literal `X'..'` (lowercase digits).
    Binary(Vec<u8>),
    /// Microseconds since 1970-01-01 00:00:00 UTC, rendered as
    /// `TIMESTAMP 'YYYY-MM-DD HH:MM:SS.ffffff'`.
    Timestamp(i64),
}

impl fmt::Display for ScalarValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Null => write!(f, "NULL"),
            Self::Bool(v) => write!(f, "{v}"),
            Self::Int64(v) => write!(f, "{v}"),
            // Rust prints these as `NaN` / `inf` / `-inf`, which a SQL parser
            // would read as identifiers, so spell them as typed string casts.
            Self::Float64(v) if v.is_nan() => f.write_str("CAST('NaN' AS DOUBLE PRECISION)"),
            Self::Float64(v) if v.is_infinite() => {
                let sign = if v.is_sign_negative() { "-" } else { "" };
                write!(f, "CAST('{sign}Infinity' AS DOUBLE PRECISION)")
            }
            Self::Float64(v) => write!(f, "{v}"),
            // Standard SQL escapes a quote inside a string literal by doubling it.
            Self::Utf8(v) => write!(f, "'{}'", v.replace('\'', "''")),
            Self::Binary(v) => write!(f, "X'{}'", hex_encode(v)),
            // A TIMESTAMP literal needs a date/time string, not a raw epoch
            // offset, so render the microsecond count as a calendar instant.
            Self::Timestamp(us) => write!(f, "TIMESTAMP '{}'", format_timestamp_micros(*us)),
        }
    }
}

/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut s, b| {
            // Writing into a `String` cannot fail, so the result is ignored.
            let _ = write!(s, "{b:02x}");
            s
        })
}

/// Formats `micros` (microseconds since 1970-01-01 00:00:00 UTC) as
/// `YYYY-MM-DD HH:MM:SS.ffffff` in the proleptic Gregorian calendar.
///
/// Years outside `0..=9999` are printed with as many digits (and sign) as they
/// need, so they do not fit the four-digit ISO form.
fn format_timestamp_micros(micros: i64) -> String {
    /// Converts days since 1970-01-01 into `(year, month, day)` using Howard
    /// Hinnant's `civil_from_days`. It cannot overflow for any day count
    /// derived from an `i64` microsecond value.
    fn civil_from_days(days: i64) -> (i64, i64, i64) {
        // Re-base on 0000-03-01 so the leap day is the last day of each
        // March-based year.
        let z = days + 719_468;
        let era = z.div_euclid(146_097);
        let doe = z.rem_euclid(146_097); // day of the 400-year era: [0, 146096]
        let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // [0, 399]
        let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // [0, 365]
        let mp = (5 * doy + 2) / 153; // month index counted from March: [0, 11]
        let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
        let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
        let year = era * 400 + yoe + i64::from(month <= 2);
        (year, month, day)
    }

    const MICROS_PER_SECOND: i64 = 1_000_000;
    const SECONDS_PER_DAY: i64 = 86_400;

    // Euclidean division rounds toward negative infinity, so instants before
    // the epoch still get a non-negative time of day and fraction.
    let seconds = micros.div_euclid(MICROS_PER_SECOND);
    let fraction = micros.rem_euclid(MICROS_PER_SECOND);
    let days = seconds.div_euclid(SECONDS_PER_DAY);
    let second_of_day = seconds.rem_euclid(SECONDS_PER_DAY);
    let (year, month, day) = civil_from_days(days);

    format!(
        "{year:04}-{month:02}-{day:02} {:02}:{:02}:{:02}.{fraction:06}",
        second_of_day / 3_600,
        second_of_day % 3_600 / 60,
        second_of_day % 60,
    )
}
67
68#[derive(Debug, Clone, PartialEq)]
74pub enum Predicate {
75 Eq {
77 column: String,
79 value: ScalarValue,
81 },
82 NotEq {
84 column: String,
86 value: ScalarValue,
88 },
89 Lt {
91 column: String,
93 value: ScalarValue,
95 },
96 LtEq {
98 column: String,
100 value: ScalarValue,
102 },
103 Gt {
105 column: String,
107 value: ScalarValue,
109 },
110 GtEq {
112 column: String,
114 value: ScalarValue,
116 },
117 In {
119 column: String,
121 values: Vec<ScalarValue>,
123 },
124 IsNull {
126 column: String,
128 },
129 IsNotNull {
131 column: String,
133 },
134}
135
136impl Predicate {
137 #[must_use]
139 pub fn column(&self) -> &str {
140 match self {
141 Self::Eq { column, .. }
142 | Self::NotEq { column, .. }
143 | Self::Lt { column, .. }
144 | Self::LtEq { column, .. }
145 | Self::Gt { column, .. }
146 | Self::GtEq { column, .. }
147 | Self::In { column, .. }
148 | Self::IsNull { column }
149 | Self::IsNotNull { column } => column,
150 }
151 }
152}
153
154#[derive(Debug, Clone, Default)]
159pub struct SourceCapabilities {
160 pub eq_columns: Vec<String>,
162 pub range_columns: Vec<String>,
164 pub in_columns: Vec<String>,
166 pub supports_null_check: bool,
168}
169
170#[derive(Debug, Clone)]
172pub struct SplitPredicates {
173 pub pushable: Vec<Predicate>,
175 pub local: Vec<Predicate>,
177}
178
179#[must_use]
190pub fn split_predicates(
191 predicates: Vec<Predicate>,
192 capabilities: &SourceCapabilities,
193) -> SplitPredicates {
194 let mut pushable = Vec::new();
195 let mut local = Vec::new();
196
197 for pred in predicates {
198 let can_push = match &pred {
199 Predicate::Eq { column, .. } => capabilities.eq_columns.iter().any(|c| c == column),
200 Predicate::NotEq { .. } => false,
203 Predicate::Lt { column, .. }
204 | Predicate::LtEq { column, .. }
205 | Predicate::Gt { column, .. }
206 | Predicate::GtEq { column, .. } => {
207 capabilities.range_columns.iter().any(|c| c == column)
208 }
209 Predicate::In { column, .. } => capabilities.in_columns.iter().any(|c| c == column),
210 Predicate::IsNull { .. } | Predicate::IsNotNull { .. } => {
211 capabilities.supports_null_check
212 }
213 };
214
215 if can_push {
216 pushable.push(pred);
217 } else {
218 local.push(pred);
219 }
220 }
221
222 SplitPredicates { pushable, local }
223}
224
225#[must_use]
234pub fn predicate_to_sql(predicate: &Predicate) -> String {
235 let q = |col: &str| col.replace('"', "\"\"");
236 match predicate {
237 Predicate::Eq { column, value } => format!("\"{}\" = {value}", q(column)),
238 Predicate::NotEq { column, value } => format!("\"{}\" != {value}", q(column)),
239 Predicate::Lt { column, value } => format!("\"{}\" < {value}", q(column)),
240 Predicate::LtEq { column, value } => format!("\"{}\" <= {value}", q(column)),
241 Predicate::Gt { column, value } => format!("\"{}\" > {value}", q(column)),
242 Predicate::GtEq { column, value } => format!("\"{}\" >= {value}", q(column)),
243 Predicate::In { column, values } => {
244 let vals: Vec<String> = values.iter().map(ToString::to_string).collect();
245 format!("\"{}\" IN ({})", q(column), vals.join(", "))
246 }
247 Predicate::IsNull { column } => format!("\"{}\" IS NULL", q(column)),
248 Predicate::IsNotNull { column } => format!("\"{}\" IS NOT NULL", q(column)),
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Column names of `predicates`, in order; used to check how a split
    /// distributed its input.
    fn columns(predicates: &[Predicate]) -> Vec<&str> {
        predicates.iter().map(Predicate::column).collect()
    }

    #[test]
    fn test_scalar_value_display() {
        assert_eq!(ScalarValue::Null.to_string(), "NULL");
        assert_eq!(ScalarValue::Bool(true).to_string(), "true");
        assert_eq!(ScalarValue::Bool(false).to_string(), "false");
        assert_eq!(ScalarValue::Int64(42).to_string(), "42");
        assert_eq!(ScalarValue::Int64(-7).to_string(), "-7");
        assert_eq!(ScalarValue::Float64(1.23).to_string(), "1.23");
        assert_eq!(ScalarValue::Utf8("hello".into()).to_string(), "'hello'");
        assert_eq!(ScalarValue::Binary(vec![0xDE, 0xAD]).to_string(), "X'dead'");
        assert_eq!(ScalarValue::Binary(vec![]).to_string(), "X''");
    }

    #[test]
    fn test_predicate_column() {
        let pred = Predicate::Eq {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        };
        assert_eq!(pred.column(), "id");

        let pred = Predicate::In {
            column: "status".into(),
            values: vec![ScalarValue::Int64(1)],
        };
        assert_eq!(pred.column(), "status");

        let pred = Predicate::IsNull {
            column: "name".into(),
        };
        assert_eq!(pred.column(), "name");
    }

    #[test]
    fn test_predicate_to_sql() {
        assert_eq!(
            predicate_to_sql(&Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(42),
            }),
            "\"id\" = 42"
        );

        assert_eq!(
            predicate_to_sql(&Predicate::In {
                column: "status".into(),
                values: vec![
                    ScalarValue::Utf8("active".into()),
                    ScalarValue::Utf8("pending".into()),
                ],
            }),
            "\"status\" IN ('active', 'pending')"
        );

        // `order` is a reserved word; quoting keeps it usable as a column name.
        assert_eq!(
            predicate_to_sql(&Predicate::Gt {
                column: "order".into(),
                value: ScalarValue::Int64(10),
            }),
            "\"order\" > 10"
        );

        assert_eq!(
            predicate_to_sql(&Predicate::IsNull {
                column: "deleted_at".into(),
            }),
            "\"deleted_at\" IS NULL"
        );
    }

    #[test]
    fn test_predicate_to_sql_operators() {
        let cases = [
            (
                Predicate::NotEq {
                    column: "a".into(),
                    value: ScalarValue::Int64(1),
                },
                "\"a\" != 1",
            ),
            (
                Predicate::Lt {
                    column: "a".into(),
                    value: ScalarValue::Int64(1),
                },
                "\"a\" < 1",
            ),
            (
                Predicate::LtEq {
                    column: "a".into(),
                    value: ScalarValue::Int64(1),
                },
                "\"a\" <= 1",
            ),
            (
                Predicate::GtEq {
                    column: "a".into(),
                    value: ScalarValue::Int64(1),
                },
                "\"a\" >= 1",
            ),
            (
                Predicate::IsNotNull { column: "a".into() },
                "\"a\" IS NOT NULL",
            ),
        ];

        for (pred, expected) in cases {
            assert_eq!(predicate_to_sql(&pred), expected, "{pred:?}");
        }
    }

    #[test]
    fn test_predicate_to_sql_escapes_identifier_quotes() {
        assert_eq!(
            predicate_to_sql(&Predicate::IsNull {
                column: "we\"ird".into(),
            }),
            "\"we\"\"ird\" IS NULL"
        );
    }

    #[test]
    fn test_split_predicates() {
        let capabilities = SourceCapabilities {
            eq_columns: vec!["id".into(), "name".into()],
            range_columns: vec!["created_at".into()],
            in_columns: vec!["status".into()],
            supports_null_check: false,
        };

        let predicates = vec![
            Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::Gt {
                column: "created_at".into(),
                value: ScalarValue::Timestamp(1_000_000),
            },
            Predicate::IsNull {
                column: "deleted_at".into(),
            },
            Predicate::In {
                column: "status".into(),
                values: vec![ScalarValue::Utf8("active".into())],
            },
            // `region` has no advertised capability, so it must stay local.
            Predicate::Eq {
                column: "region".into(),
                value: ScalarValue::Utf8("us-east".into()),
            },
        ];

        let split = split_predicates(predicates, &capabilities);
        assert_eq!(columns(&split.pushable), ["id", "created_at", "status"]);
        assert_eq!(columns(&split.local), ["deleted_at", "region"]);
    }

    #[test]
    fn test_null_checks_follow_capability_flag() {
        let capabilities = SourceCapabilities {
            supports_null_check: true,
            ..SourceCapabilities::default()
        };

        let predicates = vec![
            Predicate::IsNull { column: "a".into() },
            Predicate::IsNotNull { column: "b".into() },
            Predicate::Lt {
                column: "c".into(),
                value: ScalarValue::Int64(0),
            },
        ];

        let split = split_predicates(predicates, &capabilities);
        assert_eq!(columns(&split.pushable), ["a", "b"]);
        assert_eq!(columns(&split.local), ["c"]);
    }

    #[test]
    fn test_scalar_value_display_escapes_single_quotes() {
        assert_eq!(
            ScalarValue::Utf8("O'Brien".into()).to_string(),
            "'O''Brien'"
        );
        // Double quotes need no escaping inside a single-quoted literal.
        assert_eq!(
            ScalarValue::Utf8(r#"say "hello""#.into()).to_string(),
            r#"'say "hello"'"#
        );
        assert_eq!(ScalarValue::Utf8("it''s".into()).to_string(), "'it''''s'");
        assert_eq!(ScalarValue::Utf8(String::new()).to_string(), "''");
    }

    #[test]
    fn test_not_eq_never_pushed_down() {
        let capabilities = SourceCapabilities {
            eq_columns: vec!["id".into()],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };

        let predicates = vec![
            Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::NotEq {
                column: "id".into(),
                value: ScalarValue::Int64(2),
            },
        ];

        let split = split_predicates(predicates, &capabilities);
        assert_eq!(split.pushable.len(), 1);
        assert!(matches!(&split.pushable[0], Predicate::Eq { .. }));
        assert_eq!(split.local.len(), 1);
        assert!(matches!(&split.local[0], Predicate::NotEq { .. }));
    }

    #[test]
    fn test_split_predicates_empty_capabilities() {
        let capabilities = SourceCapabilities::default();
        let predicates = vec![Predicate::Eq {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        }];

        let split = split_predicates(predicates, &capabilities);
        assert!(split.pushable.is_empty());
        assert_eq!(split.local.len(), 1);
    }
}