1use std::future::Future;
4use std::time::Duration;
5
6use arrow::compute::filter_record_batch;
7use arrow_array::BooleanArray;
8use arrow_array::RecordBatch;
9use arrow_schema::SchemaRef;
10
11use crate::lookup::predicate::{split_predicates, Predicate, ScalarValue, SourceCapabilities};
12
/// Index of a column in a lookup source's schema, used for projection pushdown.
pub type ColumnId = u32;
15
/// Errors surfaced by [`LookupSource`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum LookupError {
    /// Could not establish or maintain a connection to the backing store.
    #[error("connection failed: {0}")]
    Connection(String),

    /// The lookup query failed on the source side.
    #[error("query failed: {0}")]
    Query(String),

    /// The operation did not complete within the given duration.
    #[error("timeout after {0:?}")]
    Timeout(Duration),

    /// The source is not available to serve lookups right now.
    #[error("not available: {0}")]
    NotAvailable(String),

    /// Unexpected internal failure (invariant violation, bug, etc.).
    #[error("internal: {0}")]
    Internal(String),
}
39
/// Feature flags advertised by a [`LookupSource`].
///
/// The derived `Default` marks every feature unsupported (`false` / `0`).
#[derive(Debug, Clone, Default)]
pub struct LookupSourceCapabilities {
    /// Source can evaluate (some) predicates itself rather than returning raw rows.
    pub supports_predicate_pushdown: bool,
    /// Source can restrict returned columns to a requested projection.
    pub supports_projection_pushdown: bool,
    /// Source supports multi-key batched lookups.
    pub supports_batch_lookup: bool,
    /// Maximum keys per lookup call; callers chunk larger key sets.
    /// Presumably meaningful only when `supports_batch_lookup` is set — TODO confirm.
    pub max_batch_size: usize,
}
56
57impl LookupSourceCapabilities {
58 #[must_use]
60 pub fn none() -> Self {
61 Self::default()
62 }
63}
64
/// An asynchronous, keyed lookup backend (e.g. a KV store or database).
///
/// Implementations resolve a batch of opaque byte keys to optional
/// [`RecordBatch`]es, optionally applying predicates and a column projection
/// on the source side (see [`Self::capabilities`]).
pub trait LookupSource: Send + Sync {
    /// Looks up `keys`, returning one entry per key, in the same order.
    ///
    /// `None` in the result means the key was not found. `predicates` are
    /// filters for the source to apply; callers such as [`PushdownAdapter`]
    /// pre-split out predicates the source cannot handle. `projection` lists
    /// the columns to return when projection pushdown is supported.
    fn query(
        &self,
        keys: &[&[u8]],
        predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> impl Future<Output = Result<Vec<Option<RecordBatch>>, LookupError>> + Send;

    /// Feature flags describing what this source supports.
    fn capabilities(&self) -> LookupSourceCapabilities;

    /// Human-readable name of this source, for diagnostics/logging.
    fn source_name(&self) -> &str;

    /// Schema of the record batches returned by [`Self::query`].
    fn schema(&self) -> SchemaRef;

    /// Approximate number of rows in the source, if known.
    ///
    /// Defaults to `None` (unknown).
    fn estimated_row_count(&self) -> Option<u64> {
        None
    }

    /// Liveness probe; the default implementation always reports healthy.
    fn health_check(&self) -> impl Future<Output = Result<(), LookupError>> + Send {
        async { Ok(()) }
    }
}
109
/// Object-safe mirror of [`LookupSource`].
///
/// [`LookupSource::query`] returns `impl Future`, which is not dyn-compatible;
/// this trait boxes the future via `async_trait` so sources can be stored and
/// called as `dyn LookupSourceDyn`.
#[async_trait::async_trait]
pub trait LookupSourceDyn: Send + Sync {
    /// See [`LookupSource::query`]; results are aligned with `keys`.
    async fn query_batch(
        &self,
        keys: &[&[u8]],
        predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> Result<Vec<Option<RecordBatch>>, LookupError>;

    /// See [`LookupSource::schema`].
    fn schema(&self) -> SchemaRef;
}
128
/// Blanket bridge: every [`LookupSource`] is automatically usable as a
/// [`LookupSourceDyn`] trait object.
#[async_trait::async_trait]
impl<T: LookupSource> LookupSourceDyn for T {
    async fn query_batch(
        &self,
        keys: &[&[u8]],
        predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
        self.query(keys, predicates, projection).await
    }

    fn schema(&self) -> SchemaRef {
        // Fully-qualified call: both traits declare a `schema` method.
        LookupSource::schema(self)
    }
}
144
/// Wraps a [`LookupSource`], pushing the predicates it can handle down to it
/// and evaluating the remainder locally on the returned batches.
pub struct PushdownAdapter<S> {
    // The wrapped source that performs the actual lookups.
    inner: S,
    // Per-column predicate support, used to split pushable vs. local predicates.
    column_capabilities: SourceCapabilities,
}
154
155impl<S: LookupSource> PushdownAdapter<S> {
156 pub fn new(inner: S, column_capabilities: SourceCapabilities) -> Self {
162 Self {
163 inner,
164 column_capabilities,
165 }
166 }
167
168 fn split(&self, predicates: &[Predicate]) -> (Vec<Predicate>, Vec<Predicate>) {
170 let split = split_predicates(predicates.to_vec(), &self.column_capabilities);
171 (split.pushable, split.local)
172 }
173}
174
/// Compares column `column` of `batch` against scalar `value` with `cmp_fn`
/// (one of the `arrow` `cmp` kernels), producing a per-row boolean mask.
///
/// Returns `None` when the column is missing, the scalar variant is not
/// handled here, the timestamp unit is unsupported, or the kernel fails
/// (e.g. column/scalar type mismatch). Callers treat `None` as "predicate
/// cannot be evaluated locally".
fn compare_column_scalar(
    batch: &RecordBatch,
    column: &str,
    value: &ScalarValue,
    cmp_fn: fn(
        &dyn arrow_array::Datum,
        &dyn arrow_array::Datum,
    ) -> Result<BooleanArray, arrow::error::ArrowError>,
) -> Option<BooleanArray> {
    use arrow_array::types::{TimestampMicrosecondType, TimestampMillisecondType};
    use arrow_array::{Float64Array, Int64Array, PrimitiveArray, Scalar, StringArray};

    let idx = batch.schema().index_of(column).ok()?;
    let col = batch.column(idx);
    match value {
        ScalarValue::Int64(v) => cmp_fn(col, &Scalar::new(Int64Array::from(vec![*v]))).ok(),
        ScalarValue::Float64(v) => cmp_fn(col, &Scalar::new(Float64Array::from(vec![*v]))).ok(),
        ScalarValue::Utf8(v) => cmp_fn(col, &Scalar::new(StringArray::from(vec![v.as_str()]))).ok(),
        ScalarValue::Bool(v) => cmp_fn(col, &Scalar::new(BooleanArray::from(vec![*v]))).ok(),
        // Timestamp scalars carry microseconds; build a scalar in the
        // column's physical unit before comparing.
        ScalarValue::Timestamp(us) => {
            if col
                .as_any()
                .is::<PrimitiveArray<TimestampMicrosecondType>>()
            {
                let scalar = PrimitiveArray::<TimestampMicrosecondType>::from(vec![*us]);
                cmp_fn(col, &Scalar::new(scalar)).ok()
            } else if col
                .as_any()
                .is::<PrimitiveArray<TimestampMillisecondType>>()
            {
                // NOTE(review): `us / 1000` truncates toward zero, so a scalar
                // that is not a whole millisecond (or a negative pre-epoch
                // value) compares against a rounded bound — e.g.
                // `Lt 2_000_500us` becomes `Lt 2_000ms` and wrongly excludes
                // rows at exactly 2_000ms. A correct fix needs to know the
                // comparison operator; confirm whether sub-millisecond
                // scalars can reach this path.
                let scalar = PrimitiveArray::<TimestampMillisecondType>::from(vec![*us / 1000]);
                cmp_fn(col, &Scalar::new(scalar)).ok()
            } else {
                // Other timestamp units are not supported for local evaluation.
                None
            }
        }
        // Remaining scalar variants are not locally comparable.
        _ => None,
    }
}
217
218fn evaluate_predicate(batch: &RecordBatch, predicate: &Predicate) -> Option<BooleanArray> {
220 use arrow::compute::kernels::cmp;
221
222 match predicate {
223 Predicate::Eq { column, value } => compare_column_scalar(batch, column, value, cmp::eq),
224 Predicate::NotEq { column, value } => compare_column_scalar(batch, column, value, cmp::neq),
225 Predicate::Lt { column, value } => compare_column_scalar(batch, column, value, cmp::lt),
226 Predicate::LtEq { column, value } => {
227 compare_column_scalar(batch, column, value, cmp::lt_eq)
228 }
229 Predicate::Gt { column, value } => compare_column_scalar(batch, column, value, cmp::gt),
230 Predicate::GtEq { column, value } => {
231 compare_column_scalar(batch, column, value, cmp::gt_eq)
232 }
233 Predicate::IsNull { column } => {
234 let idx = batch.schema().index_of(column).ok()?;
235 let col = batch.column(idx);
236 Some(arrow::compute::is_null(col).ok()?)
237 }
238 Predicate::IsNotNull { column } => {
239 let idx = batch.schema().index_of(column).ok()?;
240 let col = batch.column(idx);
241 Some(arrow::compute::is_not_null(col).ok()?)
242 }
243 Predicate::In { column, values } => {
244 let idx = batch.schema().index_of(column).ok()?;
245 let col = batch.column(idx);
246 let mut mask: Option<BooleanArray> = None;
247 for v in values {
248 let eq_mask = evaluate_predicate(
249 batch,
250 &Predicate::Eq {
251 column: column.clone(),
252 value: v.clone(),
253 },
254 )?;
255 mask = Some(match mask {
256 Some(existing) => arrow::compute::or(&existing, &eq_mask).ok()?,
257 None => eq_mask,
258 });
259 }
260 mask.or_else(|| Some(BooleanArray::from(vec![false; col.len()])))
261 }
262 }
263}
264
265fn apply_local_predicates(batch: &RecordBatch, predicates: &[Predicate]) -> Option<RecordBatch> {
267 if predicates.is_empty() {
268 return Some(batch.clone());
269 }
270 let mut combined: Option<BooleanArray> = None;
271 for pred in predicates {
272 let mask = evaluate_predicate(batch, pred)?;
273 combined = Some(match combined {
274 Some(existing) => arrow::compute::and(&existing, &mask).ok()?,
275 None => mask,
276 });
277 }
278 match combined {
279 Some(mask) => filter_record_batch(batch, &mask).ok(),
280 None => Some(batch.clone()),
281 }
282}
283
284impl<S: LookupSource> LookupSource for PushdownAdapter<S> {
285 async fn query(
286 &self,
287 keys: &[&[u8]],
288 predicates: &[Predicate],
289 projection: &[ColumnId],
290 ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
291 let (pushable, local) = self.split(predicates);
292 let results = self.inner.query(keys, &pushable, projection).await?;
293
294 if local.is_empty() {
295 return Ok(results);
296 }
297
298 Ok(results
299 .into_iter()
300 .map(|opt| {
301 opt.and_then(|batch| {
302 let filtered = apply_local_predicates(&batch, &local)?;
303 if filtered.num_rows() == 0 {
304 None
305 } else {
306 Some(filtered)
307 }
308 })
309 })
310 .collect())
311 }
312
313 fn capabilities(&self) -> LookupSourceCapabilities {
314 self.inner.capabilities()
315 }
316
317 fn source_name(&self) -> &str {
318 self.inner.source_name()
319 }
320
321 fn schema(&self) -> SchemaRef {
322 self.inner.schema()
323 }
324
325 fn estimated_row_count(&self) -> Option<u64> {
326 self.inner.estimated_row_count()
327 }
328
329 fn health_check(&self) -> impl Future<Output = Result<(), LookupError>> + Send {
330 self.inner.health_check()
331 }
332}
333
#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Two-column (`id: Int64`, `name: Utf8`) schema shared by these tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]))
    }

    /// Builds a single-row batch matching [`test_schema`].
    fn make_batch(id: i64, name: &str) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(Int64Array::from(vec![id])),
                Arc::new(StringArray::from(vec![name])),
            ],
        )
        .unwrap()
    }

    /// Simple map-backed [`LookupSource`] used as a test double.
    struct InMemoryLookupSource {
        data: std::collections::HashMap<Vec<u8>, RecordBatch>,
        capabilities: LookupSourceCapabilities,
        source_schema: SchemaRef,
    }

    impl InMemoryLookupSource {
        fn new() -> Self {
            Self {
                data: std::collections::HashMap::new(),
                capabilities: LookupSourceCapabilities::default(),
                source_schema: test_schema(),
            }
        }

        fn insert(&mut self, key: Vec<u8>, value: RecordBatch) {
            self.data.insert(key, value);
        }

        fn with_capabilities(mut self, caps: LookupSourceCapabilities) -> Self {
            self.capabilities = caps;
            self
        }
    }

    impl LookupSource for InMemoryLookupSource {
        fn query(
            &self,
            keys: &[&[u8]],
            _predicates: &[Predicate],
            _projection: &[ColumnId],
        ) -> impl Future<Output = Result<Vec<Option<RecordBatch>>, LookupError>> + Send {
            // `Vec<u8>: Borrow<[u8]>`, so a `&[u8]` key looks up directly.
            let results: Vec<Option<RecordBatch>> =
                keys.iter().map(|k| self.data.get(*k).cloned()).collect();
            async move { Ok(results) }
        }

        fn capabilities(&self) -> LookupSourceCapabilities {
            self.capabilities.clone()
        }

        fn source_name(&self) -> &str {
            "in_memory_test"
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.source_schema)
        }

        fn estimated_row_count(&self) -> Option<u64> {
            Some(self.data.len() as u64)
        }
    }

    #[tokio::test]
    async fn test_query_result_aligned_with_keys() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        source.insert(b"k3".to_vec(), make_batch(3, "Carol"));

        let keys: Vec<&[u8]> = vec![b"k1", b"k2", b"k3"];
        let results = source.query(&keys, &[], &[]).await.unwrap();

        // One entry per key, in key order; misses are `None`.
        assert_eq!(results.len(), keys.len());
        assert!(results[0].is_some());
        assert!(results[1].is_none());
        assert!(results[2].is_some());
    }

    #[tokio::test]
    async fn test_pushdown_adapter_splits_predicates() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));

        let caps = SourceCapabilities {
            eq_columns: vec!["id".into()],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };

        let adapter = PushdownAdapter::new(
            source.with_capabilities(LookupSourceCapabilities {
                supports_predicate_pushdown: true,
                ..Default::default()
            }),
            caps,
        );

        let predicates = vec![
            Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::NotEq {
                column: "id".into(),
                value: ScalarValue::Int64(2),
            },
        ];

        // `Eq` on `id` is advertised as pushable; `NotEq` must stay local.
        let (pushable, local) = adapter.split(&predicates);
        assert_eq!(pushable.len(), 1);
        assert_eq!(local.len(), 1);

        let keys: Vec<&[u8]> = vec![b"k1"];
        let results = adapter.query(&keys, &predicates, &[]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].is_some());
    }

    #[tokio::test]
    async fn test_pushdown_adapter_local_predicate_filters() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        source.insert(b"k2".to_vec(), make_batch(2, "Bob"));

        // No pushdown support at all: everything evaluates locally.
        let caps = SourceCapabilities {
            eq_columns: vec![],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };

        let adapter = PushdownAdapter::new(source, caps);

        let predicates = vec![Predicate::Gt {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        }];

        let keys: Vec<&[u8]> = vec![b"k1", b"k2"];
        let results = adapter.query(&keys, &predicates, &[]).await.unwrap();
        assert_eq!(results.len(), 2);
        // id=1 fails `id > 1` and is fully filtered out -> reported as None.
        assert!(results[0].is_none());
        assert!(results[1].is_some());
    }

    #[tokio::test]
    async fn test_pushdown_adapter_not_eq_local_evaluation() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        source.insert(b"k2".to_vec(), make_batch(2, "Bob"));

        // Source only supports Eq pushdown, so NotEq runs locally.
        let caps = SourceCapabilities {
            eq_columns: vec!["id".into()],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };

        let adapter = PushdownAdapter::new(
            source.with_capabilities(LookupSourceCapabilities {
                supports_predicate_pushdown: true,
                ..Default::default()
            }),
            caps,
        );

        let predicates = vec![Predicate::NotEq {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        }];

        let keys: Vec<&[u8]> = vec![b"k1", b"k2"];
        let results = adapter.query(&keys, &predicates, &[]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].is_none());
        assert!(results[1].is_some());
    }

    #[tokio::test]
    async fn test_mock_source_batch_chunking() {
        let mut source = InMemoryLookupSource::new();
        for i in 0..10u8 {
            source.insert(vec![i], make_batch(i64::from(i), &format!("name_{i}")));
        }

        let caps = LookupSourceCapabilities {
            max_batch_size: 3,
            supports_batch_lookup: true,
            ..Default::default()
        };
        let source = source.with_capabilities(caps);

        let keys: Vec<Vec<u8>> = (0..10u8).map(|i| vec![i]).collect();
        let key_refs: Vec<&[u8]> = keys.iter().map(Vec::as_slice).collect();

        // Chunk by the advertised max batch size and stitch results together.
        let max = source.capabilities().max_batch_size;
        let mut all_results = Vec::new();
        for chunk in key_refs.chunks(max) {
            let chunk_results = source.query(chunk, &[], &[]).await.unwrap();
            all_results.extend(chunk_results);
        }

        assert_eq!(all_results.len(), 10);
        for result in &all_results {
            assert!(result.is_some());
        }
    }

    #[tokio::test]
    async fn test_health_check_default() {
        let source = InMemoryLookupSource::new();
        assert!(source.health_check().await.is_ok());
    }

    #[test]
    fn test_estimated_row_count() {
        let mut source = InMemoryLookupSource::new();
        assert_eq!(source.estimated_row_count(), Some(0));
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        assert_eq!(source.estimated_row_count(), Some(1));
    }

    #[test]
    fn test_capabilities_default() {
        let caps = LookupSourceCapabilities::default();
        assert!(!caps.supports_predicate_pushdown);
        assert!(!caps.supports_projection_pushdown);
        assert!(!caps.supports_batch_lookup);
        assert_eq!(caps.max_batch_size, 0);
    }

    #[test]
    fn test_schema_propagation() {
        let source = InMemoryLookupSource::new();
        let schema = LookupSource::schema(&source);
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "name");
    }

    #[test]
    fn test_pushdown_adapter_schema_propagation() {
        let source = InMemoryLookupSource::new();
        let caps = SourceCapabilities {
            eq_columns: vec![],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };
        let adapter = PushdownAdapter::new(source, caps);
        let schema = LookupSource::schema(&adapter);
        assert_eq!(schema.fields().len(), 2);
    }

    #[test]
    fn test_evaluate_predicate_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, true)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from(vec![Some(1), None, Some(3)]))],
        )
        .unwrap();

        let pred = Predicate::IsNull {
            column: "id".into(),
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert!(!mask.value(0));
        assert!(mask.value(1));
        assert!(!mask.value(2));
    }

    #[test]
    fn test_evaluate_predicate_in_list() {
        let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"]))],
        )
        .unwrap();

        let pred = Predicate::In {
            column: "name".into(),
            values: vec![
                ScalarValue::Utf8("Alice".into()),
                ScalarValue::Utf8("Carol".into()),
            ],
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));
    }

    #[test]
    fn test_evaluate_predicate_timestamp_microsecond() {
        use arrow_array::types::TimestampMicrosecondType;
        use arrow_array::PrimitiveArray;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None),
            false,
        )]));
        let ts_arr: PrimitiveArray<TimestampMicrosecondType> =
            vec![1_000_000i64, 2_000_000, 3_000_000].into();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts_arr)]).unwrap();

        let pred = Predicate::Eq {
            column: "ts".into(),
            value: ScalarValue::Timestamp(2_000_000),
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert!(!mask.value(0));
        assert!(mask.value(1));
        assert!(!mask.value(2));
    }

    #[test]
    fn test_evaluate_predicate_timestamp_millisecond() {
        use arrow_array::types::TimestampMillisecondType;
        use arrow_array::PrimitiveArray;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
            false,
        )]));
        let ts_arr: PrimitiveArray<TimestampMillisecondType> = vec![1_000i64, 2_000, 3_000].into();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts_arr)]).unwrap();

        // 2_000_000 us is converted to 2_000 ms for the comparison.
        let pred = Predicate::Gt {
            column: "ts".into(),
            value: ScalarValue::Timestamp(2_000_000),
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert!(!mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));
    }
}