1use std::fmt;
4
/// A literal value that can appear on the right-hand side of a [`Predicate`].
///
/// The `Display` implementation for this type renders each variant as SQL
/// literal text. Only `PartialEq` (not `Eq`) is derived because `Float64`
/// wraps an `f64`, which has no total equality.
#[derive(Clone, Debug, PartialEq)]
pub enum ScalarValue {
    /// The absence of a value; displayed as `NULL`.
    Null,
    /// A boolean.
    Bool(bool),
    /// A signed 64-bit integer.
    Int64(i64),
    /// A 64-bit floating-point number.
    Float64(f64),
    /// An owned UTF-8 string.
    Utf8(String),
    /// An arbitrary sequence of bytes.
    Binary(Vec<u8>),
    /// A timestamp stored as a raw `i64`.
    ///
    /// NOTE(review): the `Display` impl binds this value as `us`, so the unit
    /// is presumably microseconds since some epoch — confirm with callers.
    Timestamp(i64),
}
27
28impl fmt::Display for ScalarValue {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 match self {
31 Self::Null => write!(f, "NULL"),
32 Self::Bool(v) => write!(f, "{v}"),
33 Self::Int64(v) => write!(f, "{v}"),
34 Self::Float64(v) => write!(f, "{v}"),
35 Self::Utf8(v) => {
36 write!(f, "'{}'", v.replace('\'', "''"))
38 }
39 Self::Binary(v) => write!(f, "X'{}'", hex_encode(v)),
40 Self::Timestamp(us) => write!(f, "TIMESTAMP '{us}'"),
41 }
42 }
43}
44
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
///
/// An empty slice yields an empty string.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
55
56#[derive(Debug, Clone, PartialEq)]
62pub enum Predicate {
63 Eq {
65 column: String,
67 value: ScalarValue,
69 },
70 NotEq {
72 column: String,
74 value: ScalarValue,
76 },
77 Lt {
79 column: String,
81 value: ScalarValue,
83 },
84 LtEq {
86 column: String,
88 value: ScalarValue,
90 },
91 Gt {
93 column: String,
95 value: ScalarValue,
97 },
98 GtEq {
100 column: String,
102 value: ScalarValue,
104 },
105 In {
107 column: String,
109 values: Vec<ScalarValue>,
111 },
112 IsNull {
114 column: String,
116 },
117 IsNotNull {
119 column: String,
121 },
122}
123
124impl Predicate {
125 #[must_use]
127 pub fn column(&self) -> &str {
128 match self {
129 Self::Eq { column, .. }
130 | Self::NotEq { column, .. }
131 | Self::Lt { column, .. }
132 | Self::LtEq { column, .. }
133 | Self::Gt { column, .. }
134 | Self::GtEq { column, .. }
135 | Self::In { column, .. }
136 | Self::IsNull { column }
137 | Self::IsNotNull { column } => column,
138 }
139 }
140}
141
/// Describes which predicates a data source can evaluate itself.
///
/// [`split_predicates`] consults these fields to decide whether a predicate is
/// pushed down or kept for local evaluation. The `Default` value has empty
/// column lists and `supports_null_check == false`, so nothing is pushable.
///
/// `PartialEq`/`Eq` are derived so capability sets can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SourceCapabilities {
    /// Columns on which `Eq` predicates may be pushed down.
    pub eq_columns: Vec<String>,
    /// Columns on which `Lt`, `LtEq`, `Gt` and `GtEq` predicates may be pushed down.
    pub range_columns: Vec<String>,
    /// Columns on which `In` predicates may be pushed down.
    pub in_columns: Vec<String>,
    /// Whether `IsNull` / `IsNotNull` predicates may be pushed down (for any column).
    pub supports_null_check: bool,
}
157
158#[derive(Debug, Clone)]
160pub struct SplitPredicates {
161 pub pushable: Vec<Predicate>,
163 pub local: Vec<Predicate>,
165}
166
167#[must_use]
178pub fn split_predicates(
179 predicates: Vec<Predicate>,
180 capabilities: &SourceCapabilities,
181) -> SplitPredicates {
182 let mut pushable = Vec::new();
183 let mut local = Vec::new();
184
185 for pred in predicates {
186 let can_push = match &pred {
187 Predicate::Eq { column, .. } => capabilities.eq_columns.iter().any(|c| c == column),
188 Predicate::NotEq { .. } => false,
191 Predicate::Lt { column, .. }
192 | Predicate::LtEq { column, .. }
193 | Predicate::Gt { column, .. }
194 | Predicate::GtEq { column, .. } => {
195 capabilities.range_columns.iter().any(|c| c == column)
196 }
197 Predicate::In { column, .. } => capabilities.in_columns.iter().any(|c| c == column),
198 Predicate::IsNull { .. } | Predicate::IsNotNull { .. } => {
199 capabilities.supports_null_check
200 }
201 };
202
203 if can_push {
204 pushable.push(pred);
205 } else {
206 local.push(pred);
207 }
208 }
209
210 SplitPredicates { pushable, local }
211}
212
213#[must_use]
222pub fn predicate_to_sql(predicate: &Predicate) -> String {
223 let q = |col: &str| col.replace('"', "\"\"");
224 match predicate {
225 Predicate::Eq { column, value } => format!("\"{}\" = {value}", q(column)),
226 Predicate::NotEq { column, value } => format!("\"{}\" != {value}", q(column)),
227 Predicate::Lt { column, value } => format!("\"{}\" < {value}", q(column)),
228 Predicate::LtEq { column, value } => format!("\"{}\" <= {value}", q(column)),
229 Predicate::Gt { column, value } => format!("\"{}\" > {value}", q(column)),
230 Predicate::GtEq { column, value } => format!("\"{}\" >= {value}", q(column)),
231 Predicate::In { column, values } => {
232 let vals: Vec<String> = values.iter().map(ToString::to_string).collect();
233 format!("\"{}\" IN ({})", q(column), vals.join(", "))
234 }
235 Predicate::IsNull { column } => format!("\"{}\" IS NULL", q(column)),
236 Predicate::IsNotNull { column } => format!("\"{}\" IS NOT NULL", q(column)),
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned UTF-8 scalar from a string slice.
    fn utf8(text: &str) -> ScalarValue {
        ScalarValue::Utf8(text.to_owned())
    }

    /// Builds an `Eq` predicate on `column`.
    fn eq_pred(column: &str, value: ScalarValue) -> Predicate {
        Predicate::Eq {
            column: column.to_owned(),
            value,
        }
    }

    /// Builds an `IsNull` predicate on `column`.
    fn is_null(column: &str) -> Predicate {
        Predicate::IsNull {
            column: column.to_owned(),
        }
    }

    #[test]
    fn test_scalar_value_display() {
        let cases = [
            (ScalarValue::Null, "NULL"),
            (ScalarValue::Bool(true), "true"),
            (ScalarValue::Int64(42), "42"),
            (ScalarValue::Float64(1.23), "1.23"),
            (utf8("hello"), "'hello'"),
            (ScalarValue::Binary(vec![0xDE, 0xAD]), "X'dead'"),
        ];
        for (value, expected) in &cases {
            assert_eq!(value.to_string(), *expected);
        }
    }

    #[test]
    fn test_predicate_column() {
        assert_eq!(eq_pred("id", ScalarValue::Int64(1)).column(), "id");
        assert_eq!(is_null("name").column(), "name");
    }

    #[test]
    fn test_predicate_to_sql() {
        let cases = [
            (eq_pred("id", ScalarValue::Int64(42)), "\"id\" = 42"),
            (
                Predicate::In {
                    column: "status".to_owned(),
                    values: vec![utf8("active"), utf8("pending")],
                },
                "\"status\" IN ('active', 'pending')",
            ),
            // `order` is a SQL keyword; the identifier is emitted double-quoted.
            (
                Predicate::Gt {
                    column: "order".to_owned(),
                    value: ScalarValue::Int64(10),
                },
                "\"order\" > 10",
            ),
            (is_null("deleted_at"), "\"deleted_at\" IS NULL"),
        ];
        for (predicate, expected) in &cases {
            assert_eq!(predicate_to_sql(predicate), *expected);
        }
    }

    #[test]
    fn test_split_predicates() {
        let capabilities = SourceCapabilities {
            eq_columns: vec!["id".to_owned(), "name".to_owned()],
            range_columns: vec!["created_at".to_owned()],
            in_columns: vec!["status".to_owned()],
            supports_null_check: false,
        };

        let predicates = vec![
            eq_pred("id", ScalarValue::Int64(1)),
            Predicate::Gt {
                column: "created_at".to_owned(),
                value: ScalarValue::Timestamp(1_000_000),
            },
            is_null("deleted_at"),
            Predicate::In {
                column: "status".to_owned(),
                values: vec![utf8("active")],
            },
            // `region` is not listed in `eq_columns`.
            eq_pred("region", utf8("us-east")),
        ];

        let SplitPredicates { pushable, local } = split_predicates(predicates, &capabilities);
        assert_eq!(pushable.len(), 3);
        assert_eq!(local.len(), 2);
    }

    #[test]
    fn test_scalar_value_display_escapes_single_quotes() {
        let cases = [
            // An embedded single quote is doubled.
            ("O'Brien", "'O''Brien'"),
            // Double quotes pass through untouched.
            (r#"say "hello""#, r#"'say "hello"'"#),
            // Already-doubled quotes are doubled again.
            ("it''s", "'it''''s'"),
            // The empty string is just a pair of quotes.
            ("", "''"),
        ];
        for (input, expected) in &cases {
            assert_eq!(utf8(input).to_string(), *expected);
        }
    }

    #[test]
    fn test_not_eq_never_pushed_down() {
        let capabilities = SourceCapabilities {
            eq_columns: vec!["id".to_owned()],
            range_columns: Vec::new(),
            in_columns: Vec::new(),
            supports_null_check: false,
        };

        let predicates = vec![
            eq_pred("id", ScalarValue::Int64(1)),
            Predicate::NotEq {
                column: "id".to_owned(),
                value: ScalarValue::Int64(2),
            },
        ];

        let split = split_predicates(predicates, &capabilities);
        assert_eq!(split.pushable.len(), 1);
        assert!(matches!(&split.pushable[0], Predicate::Eq { .. }));
        assert_eq!(split.local.len(), 1);
        assert!(matches!(&split.local[0], Predicate::NotEq { .. }));
    }

    #[test]
    fn test_split_predicates_empty_capabilities() {
        let split = split_predicates(
            vec![eq_pred("id", ScalarValue::Int64(1))],
            &SourceCapabilities::default(),
        );
        assert!(split.pushable.is_empty());
        assert_eq!(split.local.len(), 1);
    }
}