1use std::future::Future;
4use std::time::Duration;
5
6use arrow::compute::filter_record_batch;
7use arrow_array::BooleanArray;
8use arrow_array::RecordBatch;
9use arrow_schema::SchemaRef;
10
11use crate::lookup::predicate::{split_predicates, Predicate, ScalarValue, SourceCapabilities};
12
/// Zero-based position of a column in a source schema.
///
/// [`projection_names`] resolves these indices against `schema.fields()`.
pub type ColumnId = u32;
15
/// Errors reported by lookup sources and by the helpers in this module.
///
/// The `Display` text of every variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum LookupError {
    /// Connection failure; the payload is a free-form description.
    #[error("connection failed: {0}")]
    Connection(String),

    /// Query failure; the payload is a free-form description.
    #[error("query failed: {0}")]
    Query(String),

    /// Timeout; the payload is rendered with `{:?}`.
    // NOTE(review): presumably the payload is the limit that was exceeded — confirm with callers.
    #[error("timeout after {0:?}")]
    Timeout(Duration),

    /// Something is not available; the payload is a free-form description.
    #[error("not available: {0}")]
    NotAvailable(String),

    /// Internal error, e.g. an out-of-range index passed to [`projection_names`].
    #[error("internal: {0}")]
    Internal(String),
}
39
/// Feature flags a lookup source advertises through `LookupSource::capabilities`.
///
/// The derived `Default` sets every flag to `false` and `max_batch_size` to `0`.
#[derive(Debug, Clone, Default)]
pub struct LookupSourceCapabilities {
    /// Whether the source advertises predicate pushdown.
    pub supports_predicate_pushdown: bool,
    /// Whether the source advertises projection pushdown.
    pub supports_projection_pushdown: bool,
    /// Whether the source advertises batch lookups.
    pub supports_batch_lookup: bool,
    /// Largest number of keys per batch.
    // NOTE(review): the meaning of the default `0` (unlimited vs. unsupported) is not
    // established anywhere in this file — confirm before relying on it.
    pub max_batch_size: usize,
}
56
57impl LookupSourceCapabilities {
58 #[must_use]
60 pub fn none() -> Self {
61 Self::default()
62 }
63}
64
65pub fn projection_names(
73 schema: &SchemaRef,
74 projection: &[ColumnId],
75) -> Result<Vec<String>, LookupError> {
76 if projection.is_empty() {
77 return Ok(schema.fields().iter().map(|f| f.name().clone()).collect());
78 }
79 projection
80 .iter()
81 .map(|&c| {
82 schema
83 .fields()
84 .get(c as usize)
85 .map(|f| f.name().clone())
86 .ok_or_else(|| LookupError::Internal(format!("projection column {c} out of range")))
87 })
88 .collect()
89}
90
/// Asynchronous interface to an external lookup table.
///
/// Implementors must be `Send + Sync`, and the futures returned by
/// [`query`](Self::query) and [`health_check`](Self::health_check) must be `Send`.
pub trait LookupSource: Send + Sync {
    /// Looks up `keys`, passing `predicates` and `projection` along to the source.
    ///
    /// The result is a list of optional batches; `None` stands for a key without a result.
    // NOTE(review): the in-file test source returns exactly one slot per key, in key order —
    // presumably that alignment is part of the contract; confirm.
    // NOTE(review): `projection` presumably follows the `projection_names` convention
    // (empty slice = all columns) — confirm.
    fn query(
        &self,
        keys: &[&[u8]],
        predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> impl Future<Output = Result<Vec<Option<RecordBatch>>, LookupError>> + Send;

    /// Reports the features this source advertises.
    fn capabilities(&self) -> LookupSourceCapabilities;

    /// Returns a name identifying this source.
    fn source_name(&self) -> &str;

    /// Returns the Arrow schema of this source.
    fn schema(&self) -> SchemaRef;

    /// Returns an estimate of the number of rows, if one is available.
    ///
    /// The default implementation reports no estimate (`None`).
    fn estimated_row_count(&self) -> Option<u64> {
        None
    }

    /// Checks whether the source is usable.
    ///
    /// The default implementation always succeeds.
    fn health_check(&self) -> impl Future<Output = Result<(), LookupError>> + Send {
        async { Ok(()) }
    }
}
135
/// `async_trait`-based counterpart of [`LookupSource`].
///
/// Every [`LookupSource`] gets this trait through the blanket implementation in this file.
// NOTE(review): presumably this exists so sources can be used as `dyn LookupSourceDyn`
// (the `impl Future` return types on `LookupSource` rule out trait objects) — confirm intent.
#[async_trait::async_trait]
pub trait LookupSourceDyn: Send + Sync {
    /// Looks up `keys`; the blanket implementation forwards to `LookupSource::query`.
    async fn query_batch(
        &self,
        keys: &[&[u8]],
        predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> Result<Vec<Option<RecordBatch>>, LookupError>;

    /// Returns the Arrow schema of this source.
    fn schema(&self) -> SchemaRef;
}
154
155#[async_trait::async_trait]
156impl<T: LookupSource> LookupSourceDyn for T {
157 async fn query_batch(
158 &self,
159 keys: &[&[u8]],
160 predicates: &[Predicate],
161 projection: &[ColumnId],
162 ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
163 self.query(keys, predicates, projection).await
164 }
165
166 fn schema(&self) -> SchemaRef {
167 LookupSource::schema(self)
168 }
169}
170
/// Wrapper around a lookup source that splits predicates with `split_predicates`.
///
/// The `LookupSource` implementation in this file forwards the pushable part to
/// `inner` and evaluates the remaining predicates on the returned batches.
pub struct PushdownAdapter<S> {
    /// The wrapped source that receives the pushable predicates.
    inner: S,
    /// Capabilities handed to `split_predicates` to classify each predicate.
    column_capabilities: SourceCapabilities,
}
180
181impl<S: LookupSource> PushdownAdapter<S> {
182 pub fn new(inner: S, column_capabilities: SourceCapabilities) -> Self {
188 Self {
189 inner,
190 column_capabilities,
191 }
192 }
193
194 fn split(&self, predicates: &[Predicate]) -> (Vec<Predicate>, Vec<Predicate>) {
196 let split = split_predicates(predicates.to_vec(), &self.column_capabilities);
197 (split.pushable, split.local)
198 }
199}
200
201fn compare_column_scalar(
205 batch: &RecordBatch,
206 column: &str,
207 value: &ScalarValue,
208 cmp_fn: fn(
209 &dyn arrow_array::Datum,
210 &dyn arrow_array::Datum,
211 ) -> Result<BooleanArray, arrow::error::ArrowError>,
212) -> Option<BooleanArray> {
213 use arrow_array::types::{TimestampMicrosecondType, TimestampMillisecondType};
214 use arrow_array::{Float64Array, Int64Array, PrimitiveArray, Scalar, StringArray};
215
216 let idx = batch.schema().index_of(column).ok()?;
217 let col = batch.column(idx);
218 match value {
219 ScalarValue::Int64(v) => cmp_fn(col, &Scalar::new(Int64Array::from(vec![*v]))).ok(),
220 ScalarValue::Float64(v) => cmp_fn(col, &Scalar::new(Float64Array::from(vec![*v]))).ok(),
221 ScalarValue::Utf8(v) => cmp_fn(col, &Scalar::new(StringArray::from(vec![v.as_str()]))).ok(),
222 ScalarValue::Bool(v) => cmp_fn(col, &Scalar::new(BooleanArray::from(vec![*v]))).ok(),
223 ScalarValue::Timestamp(us) => {
224 if col
225 .as_any()
226 .is::<PrimitiveArray<TimestampMicrosecondType>>()
227 {
228 let scalar = PrimitiveArray::<TimestampMicrosecondType>::from(vec![*us]);
229 cmp_fn(col, &Scalar::new(scalar)).ok()
230 } else if col
231 .as_any()
232 .is::<PrimitiveArray<TimestampMillisecondType>>()
233 {
234 let scalar = PrimitiveArray::<TimestampMillisecondType>::from(vec![*us / 1000]);
235 cmp_fn(col, &Scalar::new(scalar)).ok()
236 } else {
237 None
238 }
239 }
240 _ => None,
241 }
242}
243
244fn evaluate_predicate(batch: &RecordBatch, predicate: &Predicate) -> Option<BooleanArray> {
246 use arrow::compute::kernels::cmp;
247
248 match predicate {
249 Predicate::Eq { column, value } => compare_column_scalar(batch, column, value, cmp::eq),
250 Predicate::NotEq { column, value } => compare_column_scalar(batch, column, value, cmp::neq),
251 Predicate::Lt { column, value } => compare_column_scalar(batch, column, value, cmp::lt),
252 Predicate::LtEq { column, value } => {
253 compare_column_scalar(batch, column, value, cmp::lt_eq)
254 }
255 Predicate::Gt { column, value } => compare_column_scalar(batch, column, value, cmp::gt),
256 Predicate::GtEq { column, value } => {
257 compare_column_scalar(batch, column, value, cmp::gt_eq)
258 }
259 Predicate::IsNull { column } => {
260 let idx = batch.schema().index_of(column).ok()?;
261 let col = batch.column(idx);
262 Some(arrow::compute::is_null(col).ok()?)
263 }
264 Predicate::IsNotNull { column } => {
265 let idx = batch.schema().index_of(column).ok()?;
266 let col = batch.column(idx);
267 Some(arrow::compute::is_not_null(col).ok()?)
268 }
269 Predicate::In { column, values } => {
270 let idx = batch.schema().index_of(column).ok()?;
271 let col = batch.column(idx);
272 let mut mask: Option<BooleanArray> = None;
273 for v in values {
274 let eq_mask = evaluate_predicate(
275 batch,
276 &Predicate::Eq {
277 column: column.clone(),
278 value: v.clone(),
279 },
280 )?;
281 mask = Some(match mask {
282 Some(existing) => arrow::compute::or(&existing, &eq_mask).ok()?,
283 None => eq_mask,
284 });
285 }
286 mask.or_else(|| Some(BooleanArray::from(vec![false; col.len()])))
287 }
288 }
289}
290
291fn apply_local_predicates(batch: &RecordBatch, predicates: &[Predicate]) -> Option<RecordBatch> {
293 if predicates.is_empty() {
294 return Some(batch.clone());
295 }
296 let mut combined: Option<BooleanArray> = None;
297 for pred in predicates {
298 let mask = evaluate_predicate(batch, pred)?;
299 combined = Some(match combined {
300 Some(existing) => arrow::compute::and(&existing, &mask).ok()?,
301 None => mask,
302 });
303 }
304 match combined {
305 Some(mask) => filter_record_batch(batch, &mask).ok(),
306 None => Some(batch.clone()),
307 }
308}
309
310impl<S: LookupSource> LookupSource for PushdownAdapter<S> {
311 async fn query(
312 &self,
313 keys: &[&[u8]],
314 predicates: &[Predicate],
315 projection: &[ColumnId],
316 ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
317 let (pushable, local) = self.split(predicates);
318 let results = self.inner.query(keys, &pushable, projection).await?;
319
320 if local.is_empty() {
321 return Ok(results);
322 }
323
324 Ok(results
325 .into_iter()
326 .map(|opt| {
327 opt.and_then(|batch| {
328 let filtered = apply_local_predicates(&batch, &local)?;
329 if filtered.num_rows() == 0 {
330 None
331 } else {
332 Some(filtered)
333 }
334 })
335 })
336 .collect())
337 }
338
339 fn capabilities(&self) -> LookupSourceCapabilities {
340 self.inner.capabilities()
341 }
342
343 fn source_name(&self) -> &str {
344 self.inner.source_name()
345 }
346
347 fn schema(&self) -> SchemaRef {
348 self.inner.schema()
349 }
350
351 fn estimated_row_count(&self) -> Option<u64> {
352 self.inner.estimated_row_count()
353 }
354
355 fn health_check(&self) -> impl Future<Output = Result<(), LookupError>> + Send {
356 self.inner.health_check()
357 }
358}
359
#[cfg(test)]
// NOTE(review): presumably needed because the in-memory test source uses
// `std::collections::HashMap`, which the project's clippy config disallows — confirm.
#[allow(clippy::disallowed_types)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Two-column schema (`id: Int64`, `name: Utf8`) shared by most tests.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a single-row batch matching [`test_schema`].
    fn make_batch(id: i64, name: &str) -> RecordBatch {
        let ids = Int64Array::from(vec![id]);
        let names = StringArray::from(vec![name]);
        RecordBatch::try_new(test_schema(), vec![Arc::new(ids), Arc::new(names)]).unwrap()
    }

    /// Column capabilities under which no predicate is pushable.
    fn no_column_caps() -> SourceCapabilities {
        SourceCapabilities {
            eq_columns: vec![],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        }
    }

    /// Column capabilities that only allow equality pushdown on `id`.
    fn eq_on_id_caps() -> SourceCapabilities {
        SourceCapabilities {
            eq_columns: vec!["id".into()],
            range_columns: vec![],
            in_columns: vec![],
            supports_null_check: false,
        }
    }

    /// Source-level capabilities with only predicate pushdown enabled.
    fn predicate_pushdown_only() -> LookupSourceCapabilities {
        LookupSourceCapabilities {
            supports_predicate_pushdown: true,
            ..Default::default()
        }
    }

    /// Reads the first three values of a mask.
    fn first_three(mask: &BooleanArray) -> [bool; 3] {
        [mask.value(0), mask.value(1), mask.value(2)]
    }

    /// Map-backed source that ignores predicates and projection.
    struct InMemoryLookupSource {
        data: HashMap<Vec<u8>, RecordBatch>,
        capabilities: LookupSourceCapabilities,
        source_schema: SchemaRef,
    }

    impl InMemoryLookupSource {
        fn new() -> Self {
            Self {
                data: HashMap::new(),
                capabilities: LookupSourceCapabilities::default(),
                source_schema: test_schema(),
            }
        }

        fn insert(&mut self, key: Vec<u8>, value: RecordBatch) {
            self.data.insert(key, value);
        }

        fn with_capabilities(mut self, caps: LookupSourceCapabilities) -> Self {
            self.capabilities = caps;
            self
        }
    }

    impl LookupSource for InMemoryLookupSource {
        fn query(
            &self,
            keys: &[&[u8]],
            _predicates: &[Predicate],
            _projection: &[ColumnId],
        ) -> impl Future<Output = Result<Vec<Option<RecordBatch>>, LookupError>> + Send {
            // Resolve eagerly so the returned future owns its data.
            let mut found = Vec::with_capacity(keys.len());
            for key in keys {
                found.push(self.data.get(*key).cloned());
            }
            async move { Ok(found) }
        }

        fn capabilities(&self) -> LookupSourceCapabilities {
            self.capabilities.clone()
        }

        fn source_name(&self) -> &'static str {
            "in_memory_test"
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.source_schema)
        }

        fn estimated_row_count(&self) -> Option<u64> {
            Some(self.data.len() as u64)
        }
    }

    #[tokio::test]
    async fn test_query_result_aligned_with_keys() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        source.insert(b"k3".to_vec(), make_batch(3, "Carol"));

        let keys: Vec<&[u8]> = vec![b"k1", b"k2", b"k3"];
        let results = source.query(&keys, &[], &[]).await.unwrap();

        assert_eq!(results.len(), keys.len());
        assert!(results[0].is_some());
        assert!(results[1].is_none());
        assert!(results[2].is_some());
    }

    #[tokio::test]
    async fn test_pushdown_adapter_splits_predicates() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));

        let adapter = PushdownAdapter::new(
            source.with_capabilities(predicate_pushdown_only()),
            eq_on_id_caps(),
        );

        let predicates = vec![
            Predicate::Eq {
                column: "id".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::NotEq {
                column: "id".into(),
                value: ScalarValue::Int64(2),
            },
        ];

        // `Eq` on `id` is pushable, `NotEq` stays local.
        let (pushable, local) = adapter.split(&predicates);
        assert_eq!(pushable.len(), 1);
        assert_eq!(local.len(), 1);

        let keys: Vec<&[u8]> = vec![b"k1"];
        let results = adapter.query(&keys, &predicates, &[]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].is_some());
    }

    #[tokio::test]
    async fn test_pushdown_adapter_local_predicate_filters() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        source.insert(b"k2".to_vec(), make_batch(2, "Bob"));

        let adapter = PushdownAdapter::new(source, no_column_caps());

        let predicates = vec![Predicate::Gt {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        }];

        let keys: Vec<&[u8]> = vec![b"k1", b"k2"];
        let results = adapter.query(&keys, &predicates, &[]).await.unwrap();
        assert_eq!(results.len(), 2);
        // id = 1 fails `id > 1`; id = 2 passes.
        assert!(results[0].is_none());
        assert!(results[1].is_some());
    }

    #[tokio::test]
    async fn test_pushdown_adapter_not_eq_local_evaluation() {
        let mut source = InMemoryLookupSource::new();
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        source.insert(b"k2".to_vec(), make_batch(2, "Bob"));

        let adapter = PushdownAdapter::new(
            source.with_capabilities(predicate_pushdown_only()),
            eq_on_id_caps(),
        );

        let predicates = vec![Predicate::NotEq {
            column: "id".into(),
            value: ScalarValue::Int64(1),
        }];

        let keys: Vec<&[u8]> = vec![b"k1", b"k2"];
        let results = adapter.query(&keys, &predicates, &[]).await.unwrap();
        assert_eq!(results.len(), 2);
        // id = 1 fails `id != 1`; id = 2 passes.
        assert!(results[0].is_none());
        assert!(results[1].is_some());
    }

    #[tokio::test]
    async fn test_mock_source_batch_chunking() {
        let mut source = InMemoryLookupSource::new();
        for i in 0..10u8 {
            source.insert(vec![i], make_batch(i64::from(i), &format!("name_{i}")));
        }
        let source = source.with_capabilities(LookupSourceCapabilities {
            max_batch_size: 3,
            supports_batch_lookup: true,
            ..Default::default()
        });

        let keys: Vec<Vec<u8>> = (0..10u8).map(|i| vec![i]).collect();
        let key_refs: Vec<&[u8]> = keys.iter().map(Vec::as_slice).collect();

        let chunk_len = source.capabilities().max_batch_size;
        let mut all_results = Vec::new();
        for chunk in key_refs.chunks(chunk_len) {
            all_results.extend(source.query(chunk, &[], &[]).await.unwrap());
        }

        assert_eq!(all_results.len(), 10);
        assert!(all_results.iter().all(Option::is_some));
    }

    #[tokio::test]
    async fn test_health_check_default() {
        let source = InMemoryLookupSource::new();
        assert!(source.health_check().await.is_ok());
    }

    #[test]
    fn test_estimated_row_count() {
        let mut source = InMemoryLookupSource::new();
        assert_eq!(source.estimated_row_count(), Some(0));
        source.insert(b"k1".to_vec(), make_batch(1, "Alice"));
        assert_eq!(source.estimated_row_count(), Some(1));
    }

    #[test]
    fn test_capabilities_default() {
        let caps = LookupSourceCapabilities::default();
        assert!(!caps.supports_predicate_pushdown);
        assert!(!caps.supports_projection_pushdown);
        assert!(!caps.supports_batch_lookup);
        assert_eq!(caps.max_batch_size, 0);
    }

    #[test]
    fn test_schema_propagation() {
        let source = InMemoryLookupSource::new();
        let schema = LookupSource::schema(&source);
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "name");
    }

    #[test]
    fn test_pushdown_adapter_schema_propagation() {
        let adapter = PushdownAdapter::new(InMemoryLookupSource::new(), no_column_caps());
        let schema = LookupSource::schema(&adapter);
        assert_eq!(schema.fields().len(), 2);
    }

    #[test]
    fn test_evaluate_predicate_is_null() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, true)]));
        let ids = Int64Array::from(vec![Some(1), None, Some(3)]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ids)]).unwrap();

        let pred = Predicate::IsNull {
            column: "id".into(),
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert_eq!(first_three(&mask), [false, true, false]);
    }

    #[test]
    fn test_evaluate_predicate_in_list() {
        let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)]));
        let names = StringArray::from(vec!["Alice", "Bob", "Carol"]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(names)]).unwrap();

        let pred = Predicate::In {
            column: "name".into(),
            values: vec![
                ScalarValue::Utf8("Alice".into()),
                ScalarValue::Utf8("Carol".into()),
            ],
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert_eq!(first_three(&mask), [true, false, true]);
    }

    #[test]
    fn test_evaluate_predicate_timestamp_microsecond() {
        use arrow_array::types::TimestampMicrosecondType;
        use arrow_array::PrimitiveArray;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None),
            false,
        )]));
        let ts_arr = PrimitiveArray::<TimestampMicrosecondType>::from(vec![
            1_000_000i64,
            2_000_000,
            3_000_000,
        ]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts_arr)]).unwrap();

        let pred = Predicate::Eq {
            column: "ts".into(),
            value: ScalarValue::Timestamp(2_000_000),
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert_eq!(first_three(&mask), [false, true, false]);
    }

    #[test]
    fn test_evaluate_predicate_timestamp_millisecond() {
        use arrow_array::types::TimestampMillisecondType;
        use arrow_array::PrimitiveArray;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
            false,
        )]));
        let ts_arr =
            PrimitiveArray::<TimestampMillisecondType>::from(vec![1_000i64, 2_000, 3_000]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ts_arr)]).unwrap();

        // The scalar is in microseconds: 2_000_000 µs equals 2_000 ms.
        let pred = Predicate::Gt {
            column: "ts".into(),
            value: ScalarValue::Timestamp(2_000_000),
        };
        let mask = evaluate_predicate(&batch, &pred).unwrap();
        assert_eq!(first_three(&mask), [false, false, true]);
    }
}