1use std::future::Future;
15use std::time::Duration;
16
17use arrow::compute::filter_record_batch;
18use arrow_array::BooleanArray;
19use arrow_array::RecordBatch;
20use arrow_schema::SchemaRef;
21
22use crate::lookup::predicate::{split_predicates, Predicate, ScalarValue, SourceCapabilities};
23
/// Identifier of a column as used in projection lists passed to
/// [`LookupSource::query`].
pub type ColumnId = u32;
26
/// Errors surfaced by [`LookupSource`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum LookupError {
    /// Failed to establish or maintain a connection to the backing store.
    #[error("connection failed: {0}")]
    Connection(String),

    /// The lookup itself was executed but failed.
    #[error("query failed: {0}")]
    Query(String),

    /// The operation exceeded the given deadline.
    #[error("timeout after {0:?}")]
    Timeout(Duration),

    /// The source exists but cannot currently serve requests.
    #[error("not available: {0}")]
    NotAvailable(String),

    /// Invariant violation or other implementation-side failure.
    #[error("internal: {0}")]
    Internal(String),
}
50
/// Feature flags a [`LookupSource`] advertises about what it can do natively.
#[derive(Debug, Clone, Default)]
pub struct LookupSourceCapabilities {
    // Whether the source can apply (some) predicates itself.
    pub supports_predicate_pushdown: bool,
    // Whether the source can apply column projections itself.
    pub supports_projection_pushdown: bool,
    // Whether the source accepts multi-key batched lookups.
    pub supports_batch_lookup: bool,
    // Maximum number of keys per query. `Default` yields 0; the meaning of
    // 0 (presumably "no limit advertised") is not defined here — callers
    // such as the chunking test treat it as a chunk size when non-zero.
    pub max_batch_size: usize,
}
67
68impl LookupSourceCapabilities {
69 #[must_use]
71 pub fn none() -> Self {
72 Self::default()
73 }
74}
75
/// A keyed lookup source returning Arrow [`RecordBatch`]es.
///
/// Implementations may ignore `predicates` and `projection`;
/// [`Self::capabilities`] describes what the source actually supports.
/// [`PushdownAdapter`] layers local evaluation of unsupported predicates
/// on top of any implementor.
pub trait LookupSource: Send + Sync {
    /// Looks up `keys`, yielding one entry per key in the same order;
    /// `None` marks a key with no match.
    fn query(
        &self,
        keys: &[&[u8]],
        predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> impl Future<Output = Result<Vec<Option<RecordBatch>>, LookupError>> + Send;

    /// Feature flags describing what this source supports natively.
    fn capabilities(&self) -> LookupSourceCapabilities;

    /// Human-readable identifier for this source.
    fn source_name(&self) -> &str;

    /// Schema of the batches produced by [`Self::query`].
    fn schema(&self) -> SchemaRef;

    /// Approximate number of rows available, when known.
    /// Defaults to `None` (unknown).
    fn estimated_row_count(&self) -> Option<u64> {
        None
    }

    /// Liveness probe; the default implementation always succeeds.
    fn health_check(&self) -> impl Future<Output = Result<(), LookupError>> + Send {
        async { Ok(()) }
    }
}
120
/// Object-safe counterpart of [`LookupSource`], usable behind `dyn`.
///
/// A blanket impl below makes every [`LookupSource`] a `LookupSourceDyn`.
#[async_trait::async_trait]
pub trait LookupSourceDyn: Send + Sync {
    /// See [`LookupSource::query`]; named differently so it does not
    /// collide with that method in the blanket impl.
    async fn query_batch(
        &self,
        keys: &[&[u8]],
        predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> Result<Vec<Option<RecordBatch>>, LookupError>;

    /// Schema of the batches produced by [`Self::query_batch`].
    fn schema(&self) -> SchemaRef;
}
139
/// Blanket impl: any statically-typed [`LookupSource`] is automatically
/// usable as a [`LookupSourceDyn`] trait object.
#[async_trait::async_trait]
impl<T: LookupSource> LookupSourceDyn for T {
    async fn query_batch(
        &self,
        keys: &[&[u8]],
        predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
        self.query(keys, predicates, projection).await
    }

    fn schema(&self) -> SchemaRef {
        // Fully-qualified call: both traits define a `schema` method.
        LookupSource::schema(self)
    }
}
155
/// Wraps a [`LookupSource`], pushing down the predicates allowed by the
/// per-column [`SourceCapabilities`] and evaluating the remainder locally
/// on the batches the source returns.
pub struct PushdownAdapter<S> {
    // The wrapped source that receives the pushable predicates.
    inner: S,
    // Per-column capabilities consulted by `split_predicates`.
    column_capabilities: SourceCapabilities,
}
165
166impl<S: LookupSource> PushdownAdapter<S> {
167 pub fn new(inner: S, column_capabilities: SourceCapabilities) -> Self {
173 Self {
174 inner,
175 column_capabilities,
176 }
177 }
178
179 fn split(&self, predicates: &[Predicate]) -> (Vec<Predicate>, Vec<Predicate>) {
181 let split = split_predicates(predicates.to_vec(), &self.column_capabilities);
182 (split.pushable, split.local)
183 }
184}
185
/// Compares `column` of `batch` against a scalar `value` with the supplied
/// arrow comparison kernel, producing a per-row boolean mask.
///
/// Returns `None` when the column does not exist, the kernel rejects the
/// operand types, or the scalar variant is unsupported — callers treat
/// `None` as "cannot evaluate locally".
fn compare_column_scalar(
    batch: &RecordBatch,
    column: &str,
    value: &ScalarValue,
    cmp_fn: fn(
        &dyn arrow_array::Datum,
        &dyn arrow_array::Datum,
    ) -> Result<BooleanArray, arrow::error::ArrowError>,
) -> Option<BooleanArray> {
    use arrow_array::types::{TimestampMicrosecondType, TimestampMillisecondType};
    use arrow_array::{Float64Array, Int64Array, PrimitiveArray, Scalar, StringArray};

    let idx = batch.schema().index_of(column).ok()?;
    let col = batch.column(idx);
    match value {
        // Single-element arrays wrapped in `Scalar` broadcast against the column.
        ScalarValue::Int64(v) => cmp_fn(col, &Scalar::new(Int64Array::from(vec![*v]))).ok(),
        ScalarValue::Float64(v) => cmp_fn(col, &Scalar::new(Float64Array::from(vec![*v]))).ok(),
        ScalarValue::Utf8(v) => cmp_fn(col, &Scalar::new(StringArray::from(vec![v.as_str()]))).ok(),
        ScalarValue::Bool(v) => cmp_fn(col, &Scalar::new(BooleanArray::from(vec![*v]))).ok(),
        // Timestamp scalars carry microseconds; convert to the column's
        // physical unit before comparing.
        ScalarValue::Timestamp(us) => {
            if col
                .as_any()
                .is::<PrimitiveArray<TimestampMicrosecondType>>()
            {
                let scalar = PrimitiveArray::<TimestampMicrosecondType>::from(vec![*us]);
                cmp_fn(col, &Scalar::new(scalar)).ok()
            } else if col
                .as_any()
                .is::<PrimitiveArray<TimestampMillisecondType>>()
            {
                // NOTE(review): integer division truncates sub-millisecond
                // precision, so e.g. Eq(1_500us) compares as 1ms against a
                // millisecond column — confirm this is intended.
                let scalar = PrimitiveArray::<TimestampMillisecondType>::from(vec![*us / 1000]);
                cmp_fn(col, &Scalar::new(scalar)).ok()
            } else {
                // Other timestamp units (seconds, nanoseconds) are not handled.
                None
            }
        }
        // Any other scalar variant is not evaluable locally.
        _ => None,
    }
}
228
229fn evaluate_predicate(batch: &RecordBatch, predicate: &Predicate) -> Option<BooleanArray> {
231 use arrow::compute::kernels::cmp;
232
233 match predicate {
234 Predicate::Eq { column, value } => compare_column_scalar(batch, column, value, cmp::eq),
235 Predicate::NotEq { column, value } => compare_column_scalar(batch, column, value, cmp::neq),
236 Predicate::Lt { column, value } => compare_column_scalar(batch, column, value, cmp::lt),
237 Predicate::LtEq { column, value } => {
238 compare_column_scalar(batch, column, value, cmp::lt_eq)
239 }
240 Predicate::Gt { column, value } => compare_column_scalar(batch, column, value, cmp::gt),
241 Predicate::GtEq { column, value } => {
242 compare_column_scalar(batch, column, value, cmp::gt_eq)
243 }
244 Predicate::IsNull { column } => {
245 let idx = batch.schema().index_of(column).ok()?;
246 let col = batch.column(idx);
247 Some(arrow::compute::is_null(col).ok()?)
248 }
249 Predicate::IsNotNull { column } => {
250 let idx = batch.schema().index_of(column).ok()?;
251 let col = batch.column(idx);
252 Some(arrow::compute::is_not_null(col).ok()?)
253 }
254 Predicate::In { column, values } => {
255 let idx = batch.schema().index_of(column).ok()?;
256 let col = batch.column(idx);
257 let mut mask: Option<BooleanArray> = None;
258 for v in values {
259 let eq_mask = evaluate_predicate(
260 batch,
261 &Predicate::Eq {
262 column: column.clone(),
263 value: v.clone(),
264 },
265 )?;
266 mask = Some(match mask {
267 Some(existing) => arrow::compute::or(&existing, &eq_mask).ok()?,
268 None => eq_mask,
269 });
270 }
271 mask.or_else(|| Some(BooleanArray::from(vec![false; col.len()])))
272 }
273 }
274}
275
276fn apply_local_predicates(batch: &RecordBatch, predicates: &[Predicate]) -> Option<RecordBatch> {
278 if predicates.is_empty() {
279 return Some(batch.clone());
280 }
281 let mut combined: Option<BooleanArray> = None;
282 for pred in predicates {
283 let mask = evaluate_predicate(batch, pred)?;
284 combined = Some(match combined {
285 Some(existing) => arrow::compute::and(&existing, &mask).ok()?,
286 None => mask,
287 });
288 }
289 match combined {
290 Some(mask) => filter_record_batch(batch, &mask).ok(),
291 None => Some(batch.clone()),
292 }
293}
294
impl<S: LookupSource> LookupSource for PushdownAdapter<S> {
    /// Queries the inner source with only the pushable predicates, then
    /// filters each returned batch with the remaining local predicates.
    ///
    /// A batch whose rows are all filtered out locally is reported as
    /// `None`, i.e. it becomes indistinguishable from a missing key.
    async fn query(
        &self,
        keys: &[&[u8]],
        predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
        let (pushable, local) = self.split(predicates);
        let results = self.inner.query(keys, &pushable, projection).await?;

        // Fast path: everything was pushed down, nothing to do locally.
        if local.is_empty() {
            return Ok(results);
        }

        Ok(results
            .into_iter()
            .map(|opt| {
                opt.and_then(|batch| {
                    // `None` here means a predicate could not be evaluated
                    // locally; the entry is dropped rather than erroring.
                    let filtered = apply_local_predicates(&batch, &local)?;
                    if filtered.num_rows() == 0 {
                        None
                    } else {
                        Some(filtered)
                    }
                })
            })
            .collect())
    }

    // The remaining methods simply delegate to the wrapped source.

    fn capabilities(&self) -> LookupSourceCapabilities {
        self.inner.capabilities()
    }

    fn source_name(&self) -> &str {
        self.inner.source_name()
    }

    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    fn estimated_row_count(&self) -> Option<u64> {
        self.inner.estimated_row_count()
    }

    fn health_check(&self) -> impl Future<Output = Result<(), LookupError>> + Send {
        self.inner.health_check()
    }
}
344
#[cfg(test)]
// Tests use std's HashMap directly; the workspace-wide `disallowed_types`
// lint (presumably steering production code toward an alternative map) is
// suppressed because test code has no such requirement.
#[allow(clippy::disallowed_types)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Schema shared by all tests: `id: Int64`, `name: Utf8`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]))
    }

    /// Builds a one-row batch conforming to `test_schema`.
    fn make_batch(id: i64, name: &str) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(Int64Array::from(vec![id])),
                Arc::new(StringArray::from(vec![name])),
            ],
        )
        .unwrap()
    }

    /// Minimal `LookupSource` backed by an in-memory map.
    struct InMemoryLookupSource {
        data: std::collections::HashMap<Vec<u8>, RecordBatch>,
        capabilities: LookupSourceCapabilities,
        source_schema: SchemaRef,
    }

    impl InMemoryLookupSource {
        fn new() -> Self {
            Self {
                data: std::collections::HashMap::new(),
                capabilities: LookupSourceCapabilities::default(),
                source_schema: test_schema(),
            }
        }

        fn insert(&mut self, key: Vec<u8>, value: RecordBatch) {
            self.data.insert(key, value);
        }

        fn with_capabilities(mut self, caps: LookupSourceCapabilities) -> Self {
            self.capabilities = caps;
            self
        }
    }

    impl LookupSource for InMemoryLookupSource {
        fn query(
            &self,
            keys: &[&[u8]],
            _predicates: &[Predicate],
            _projection: &[ColumnId],
        ) -> impl Future<Output = Result<Vec<Option<RecordBatch>>, LookupError>> + Send {
            // Resolve all keys synchronously, then return a ready future.
            let results: Vec<Option<RecordBatch>> =
                keys.iter().map(|k| self.data.get(*k).cloned()).collect();
            async move { Ok(results) }
        }

        fn capabilities(&self) -> LookupSourceCapabilities {
            self.capabilities.clone()
        }

        fn source_name(&self) -> &str {
            "in_memory_test"
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.source_schema)
        }

        fn estimated_row_count(&self) -> Option<u64> {
            Some(self.data.len() as u64)
        }
    }

    #[tokio::test]
    async fn test_query_result_aligned_with_keys() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        source.insert(b"k3".to_vec(), make_batch(3, "Carol"));

        let keys: Vec<&[u8]> = vec![b"k1", b"k2", b"k3"];
        let results = source.query(&keys, &[], &[]).await.unwrap();

        // Results must stay positionally aligned with the requested keys.
        assert_eq!(results.len(), keys.len());
        assert!(results[0].is_some());
        assert!(results[1].is_none());
        assert!(results[2].is_some());
    }

    #[tokio::test]
    async fn test_pushdown_adapter_splits_predicates() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));

        // Only equality on `id` is pushable; everything else stays local.
        let caps = SourceCapabilities {
            eq_columns: vec!["id".into()],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };

        let adapter = PushdownAdapter::new(
            source.with_capabilities(LookupSourceCapabilities {
                supports_predicate_pushdown: true,
                ..Default::default()
            }),
            caps,
        );

        let predicates = vec![
            Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::NotEq {
                column: "id".into(),
                value: ScalarValue::Int64(2),
            },
        ];

        let (pushable, local) = adapter.split(&predicates);
        assert_eq!(pushable.len(), 1);
        assert_eq!(local.len(), 1);

        let keys: Vec<&[u8]> = vec![b"k1"];
        let results = adapter.query(&keys, &predicates, &[]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].is_some());
    }

    #[tokio::test]
    async fn test_pushdown_adapter_local_predicate_filters() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        source.insert(b"k2".to_vec(), make_batch(2, "Bob"));

        // No pushable columns: every predicate must be evaluated locally.
        let caps = SourceCapabilities {
            eq_columns: vec![],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };

        let adapter = PushdownAdapter::new(source, caps);

        let predicates = vec![Predicate::Gt {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        }];

        let keys: Vec<&[u8]> = vec![b"k1", b"k2"];
        let results = adapter.query(&keys, &predicates, &[]).await.unwrap();
        assert_eq!(results.len(), 2);
        // id=1 fails `id > 1` and its emptied batch collapses to None.
        assert!(results[0].is_none());
        assert!(results[1].is_some());
    }

    #[tokio::test]
    async fn test_pushdown_adapter_not_eq_local_evaluation() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        source.insert(b"k2".to_vec(), make_batch(2, "Bob"));

        // `eq_columns` covers Eq only, so NotEq must fall back to local.
        let caps = SourceCapabilities {
            eq_columns: vec!["id".into()],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };

        let adapter = PushdownAdapter::new(
            source.with_capabilities(LookupSourceCapabilities {
                supports_predicate_pushdown: true,
                ..Default::default()
            }),
            caps,
        );

        let predicates = vec![Predicate::NotEq {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        }];

        let keys: Vec<&[u8]> = vec![b"k1", b"k2"];
        let results = adapter.query(&keys, &predicates, &[]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].is_none());
        assert!(results[1].is_some());
    }

    #[tokio::test]
    async fn test_mock_source_batch_chunking() {
        let mut source = InMemoryLookupSource::new();
        for i in 0..10u8 {
            source.insert(vec![i], make_batch(i64::from(i), &format!("name_{i}")));
        }

        let caps = LookupSourceCapabilities {
            max_batch_size: 3,
            supports_batch_lookup: true,
            ..Default::default()
        };
        let source = source.with_capabilities(caps);

        let keys: Vec<Vec<u8>> = (0..10u8).map(|i| vec![i]).collect();
        let key_refs: Vec<&[u8]> = keys.iter().map(Vec::as_slice).collect();

        // Issue the lookup in max_batch_size-sized chunks and re-assemble.
        let max = source.capabilities().max_batch_size;
        let mut all_results = Vec::new();
        for chunk in key_refs.chunks(max) {
            let chunk_results = source.query(chunk, &[], &[]).await.unwrap();
            all_results.extend(chunk_results);
        }

        assert_eq!(all_results.len(), 10);
        for result in &all_results {
            assert!(result.is_some());
        }
    }

    #[tokio::test]
    async fn test_health_check_default() {
        let source = InMemoryLookupSource::new();
        assert!(source.health_check().await.is_ok());
    }

    #[test]
    fn test_estimated_row_count() {
        let mut source = InMemoryLookupSource::new();
        assert_eq!(source.estimated_row_count(), Some(0));
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        assert_eq!(source.estimated_row_count(), Some(1));
    }

    #[test]
    fn test_capabilities_default() {
        let caps = LookupSourceCapabilities::default();
        assert!(!caps.supports_predicate_pushdown);
        assert!(!caps.supports_projection_pushdown);
        assert!(!caps.supports_batch_lookup);
        assert_eq!(caps.max_batch_size, 0);
    }

    #[test]
    fn test_schema_propagation() {
        let source = InMemoryLookupSource::new();
        let schema = LookupSource::schema(&source);
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "name");
    }

    #[test]
    fn test_pushdown_adapter_schema_propagation() {
        let source = InMemoryLookupSource::new();
        let caps = SourceCapabilities {
            eq_columns: vec![],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        };
        let adapter = PushdownAdapter::new(source, caps);
        let schema = LookupSource::schema(&adapter);
        assert_eq!(schema.fields().len(), 2);
    }

    #[test]
    fn test_evaluate_predicate_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, true)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from(vec![Some(1), None, Some(3)]))],
        )
        .unwrap();

        let pred = Predicate::IsNull {
            column: "id".into(),
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert!(!mask.value(0));
        assert!(mask.value(1));
        assert!(!mask.value(2));
    }

    #[test]
    fn test_evaluate_predicate_in_list() {
        let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"]))],
        )
        .unwrap();

        let pred = Predicate::In {
            column: "name".into(),
            values: vec![
                ScalarValue::Utf8("Alice".into()),
                ScalarValue::Utf8("Carol".into()),
            ],
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));
    }

    #[test]
    fn test_evaluate_predicate_timestamp_microsecond() {
        use arrow_array::types::TimestampMicrosecondType;
        use arrow_array::PrimitiveArray;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None),
            false,
        )]));
        let ts_arr: PrimitiveArray<TimestampMicrosecondType> =
            vec![1_000_000i64, 2_000_000, 3_000_000].into();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts_arr)]).unwrap();

        let pred = Predicate::Eq {
            column: "ts".into(),
            value: ScalarValue::Timestamp(2_000_000),
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert!(!mask.value(0));
        assert!(mask.value(1));
        assert!(!mask.value(2));
    }

    #[test]
    fn test_evaluate_predicate_timestamp_millisecond() {
        use arrow_array::types::TimestampMillisecondType;
        use arrow_array::PrimitiveArray;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
            false,
        )]));
        let ts_arr: PrimitiveArray<TimestampMillisecondType> = vec![1_000i64, 2_000, 3_000].into();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts_arr)]).unwrap();

        // Scalar is microseconds; the evaluator converts it to milliseconds.
        let pred = Predicate::Gt {
            column: "ts".into(),
            value: ScalarValue::Timestamp(2_000_000),
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert!(!mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));
    }
}