Skip to main content

laminar_sql/datafusion/
lookup_join_exec.rs

1//! Physical execution plan for lookup joins.
2//!
3//! Bridges `LookupJoinNode` (logical) to a hash-probe executor that
4//! joins streaming input against a pre-indexed lookup table snapshot.
5//!
6//! ## Data flow
7//!
8//! ```text
9//! Stream input ──► LookupJoinExec ──► Output (stream + lookup columns)
10//!                       │
11//!                  HashIndex probe
12//!                       │
13//!                  LookupSnapshot (pre-indexed RecordBatch)
14//! ```
15
16use std::any::Any;
17use std::collections::HashMap;
18use std::fmt::{self, Debug, Formatter};
19use std::sync::Arc;
20
21use parking_lot::RwLock;
22
23use std::collections::BTreeMap;
24
25use arrow::compute::take;
26use arrow::row::{RowConverter, SortField};
27use arrow_array::{RecordBatch, UInt32Array};
28use arrow_schema::{Schema, SchemaRef};
29use async_trait::async_trait;
30use datafusion::execution::{SendableRecordBatchStream, SessionState, TaskContext};
31use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNode};
32use datafusion::physical_expr::{EquivalenceProperties, LexOrdering, Partitioning};
33use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
34use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
35use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
36use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
37use datafusion_common::{DataFusionError, Result};
38use datafusion_expr::Expr;
39use futures::StreamExt;
40use laminar_core::lookup::foyer_cache::FoyerMemoryCache;
41use laminar_core::lookup::source::LookupSourceDyn;
42use laminar_core::lookup::table::LookupTable;
43use tokio::sync::Semaphore;
44
45use super::lookup_join::{LookupJoinNode, LookupJoinType};
46
47// ── Registry ─────────────────────────────────────────────────────
48
/// Thread-safe registry of lookup table entries (snapshot or partial).
///
/// The db layer populates this when `CREATE LOOKUP TABLE` executes;
/// the [`LookupJoinExtensionPlanner`] reads it at physical plan time.
///
/// Table names are lowercased on insert and lookup, so access is
/// case-insensitive.
#[derive(Default)]
pub struct LookupTableRegistry {
    // parking_lot RwLock: many concurrent readers at plan time,
    // exclusive writers on (un)registration.
    tables: RwLock<HashMap<String, RegisteredLookup>>,
}
57
/// A registered lookup table entry — snapshot, partial (on-demand), or
/// versioned (temporal join with version history).
///
/// Each variant holds an `Arc`, so handing an entry out of the registry
/// is a cheap reference-count bump rather than a data copy.
pub enum RegisteredLookup {
    /// Full snapshot: all rows pre-loaded in a single batch.
    Snapshot(Arc<LookupSnapshot>),
    /// Partial (on-demand): bounded foyer cache with S3-FIFO eviction.
    Partial(Arc<PartialLookupState>),
    /// Versioned: all versions of all keys for temporal joins.
    Versioned(Arc<VersionedLookupState>),
}
68
/// Point-in-time snapshot of a lookup table for join execution.
///
/// Immutable once registered; updating the table means registering a
/// replacement snapshot under the same name.
pub struct LookupSnapshot {
    /// All rows concatenated into a single batch.
    pub batch: RecordBatch,
    /// Primary key column names used to build the hash index.
    pub key_columns: Vec<String>,
}
76
/// State for a versioned (temporal) lookup table.
///
/// Holds all versions of all keys in a single `RecordBatch`, plus a
/// pre-built `VersionedIndex` for efficient point-in-time lookups.
/// The index is built once at registration time and rebuilt only when
/// the table is updated via CDC.
pub struct VersionedLookupState {
    /// All rows (all versions) concatenated into a single batch.
    pub batch: RecordBatch,
    /// Pre-built versioned index (built at registration time, not per-cycle).
    pub index: Arc<VersionedIndex>,
    /// Primary key column names for the equi-join.
    pub key_columns: Vec<String>,
    /// Column containing the version timestamp in the table.
    /// Supported types: `Int64`, any `Timestamp` unit, or `Float64`
    /// (see `extract_i64_timestamps`).
    pub version_column: String,
    /// Stream-side column name for event time (the AS OF column).
    pub stream_time_column: String,
}
95
/// State for a partial (on-demand) lookup table.
///
/// Unlike a snapshot, rows are obtained lazily: cache misses fall back
/// to `source` when one is configured, with concurrency capped by
/// `fetch_semaphore`.
pub struct PartialLookupState {
    /// Bounded foyer memory cache with S3-FIFO eviction.
    pub foyer_cache: Arc<FoyerMemoryCache>,
    /// Schema of the lookup table.
    pub schema: SchemaRef,
    /// Key column names for row encoding.
    pub key_columns: Vec<String>,
    /// `SortField` descriptors for key encoding via `RowConverter`.
    pub key_sort_fields: Vec<SortField>,
    /// Async source for cache miss fallback (None = cache-only mode).
    pub source: Option<Arc<dyn LookupSourceDyn>>,
    /// Limits concurrent source queries to avoid overloading the source.
    pub fetch_semaphore: Arc<Semaphore>,
}
111
112impl LookupTableRegistry {
113    /// Creates an empty registry.
114    #[must_use]
115    pub fn new() -> Self {
116        Self::default()
117    }
118
119    /// Registers or replaces a lookup table snapshot.
120    ///
121    /// # Panics
122    ///
123    /// Panics if the internal lock is poisoned.
124    pub fn register(&self, name: &str, snapshot: LookupSnapshot) {
125        self.tables.write().insert(
126            name.to_lowercase(),
127            RegisteredLookup::Snapshot(Arc::new(snapshot)),
128        );
129    }
130
131    /// Registers or replaces a partial (on-demand) lookup table.
132    ///
133    /// # Panics
134    ///
135    /// Panics if the internal lock is poisoned.
136    pub fn register_partial(&self, name: &str, state: PartialLookupState) {
137        self.tables.write().insert(
138            name.to_lowercase(),
139            RegisteredLookup::Partial(Arc::new(state)),
140        );
141    }
142
143    /// Registers or replaces a versioned (temporal) lookup table.
144    ///
145    /// # Panics
146    ///
147    /// Panics if the internal lock is poisoned.
148    pub fn register_versioned(&self, name: &str, state: VersionedLookupState) {
149        self.tables.write().insert(
150            name.to_lowercase(),
151            RegisteredLookup::Versioned(Arc::new(state)),
152        );
153    }
154
155    /// Removes a lookup table from the registry.
156    ///
157    /// # Panics
158    ///
159    /// Panics if the internal lock is poisoned.
160    pub fn unregister(&self, name: &str) {
161        self.tables.write().remove(&name.to_lowercase());
162    }
163
164    /// Returns the current snapshot for a table, if registered as a snapshot.
165    ///
166    /// # Panics
167    ///
168    /// Panics if the internal lock is poisoned.
169    #[must_use]
170    pub fn get(&self, name: &str) -> Option<Arc<LookupSnapshot>> {
171        let tables = self.tables.read();
172        match tables.get(&name.to_lowercase())? {
173            RegisteredLookup::Snapshot(s) => Some(Arc::clone(s)),
174            RegisteredLookup::Partial(_) | RegisteredLookup::Versioned(_) => None,
175        }
176    }
177
178    /// Returns the registered lookup entry (snapshot, partial, or versioned).
179    ///
180    /// # Panics
181    ///
182    /// Panics if the internal lock is poisoned.
183    pub fn get_entry(&self, name: &str) -> Option<RegisteredLookup> {
184        let tables = self.tables.read();
185        tables.get(&name.to_lowercase()).map(|e| match e {
186            RegisteredLookup::Snapshot(s) => RegisteredLookup::Snapshot(Arc::clone(s)),
187            RegisteredLookup::Partial(p) => RegisteredLookup::Partial(Arc::clone(p)),
188            RegisteredLookup::Versioned(v) => RegisteredLookup::Versioned(Arc::clone(v)),
189        })
190    }
191}
192
193// ── Hash Index ───────────────────────────────────────────────────
194
/// Pre-built hash index mapping encoded key bytes to row indices.
///
/// Keys are `RowConverter`-encoded bytes; values are all row positions
/// in the indexed batch sharing that key (duplicates allowed).
struct HashIndex {
    map: HashMap<Box<[u8]>, Vec<u32>>,
}
199
200impl HashIndex {
201    /// Builds an index over `key_indices` columns in `batch`.
202    ///
203    /// Uses Arrow's `RowConverter` for binary-comparable key encoding
204    /// so any Arrow data type is handled without manual serialization.
205    fn build(batch: &RecordBatch, key_indices: &[usize]) -> Result<Self> {
206        if batch.num_rows() == 0 {
207            return Ok(Self {
208                map: HashMap::new(),
209            });
210        }
211
212        let sort_fields: Vec<SortField> = key_indices
213            .iter()
214            .map(|&i| SortField::new(batch.schema().field(i).data_type().clone()))
215            .collect();
216        let converter = RowConverter::new(sort_fields)?;
217
218        let key_cols: Vec<_> = key_indices
219            .iter()
220            .map(|&i| batch.column(i).clone())
221            .collect();
222        let rows = converter.convert_columns(&key_cols)?;
223
224        let num_rows = batch.num_rows();
225        let mut map: HashMap<Box<[u8]>, Vec<u32>> = HashMap::with_capacity(num_rows);
226        #[allow(clippy::cast_possible_truncation)] // batch row count fits u32
227        for i in 0..num_rows {
228            map.entry(Box::from(rows.row(i).as_ref()))
229                .or_default()
230                .push(i as u32);
231        }
232
233        Ok(Self { map })
234    }
235
236    fn probe(&self, key: &[u8]) -> Option<&[u32]> {
237        self.map.get(key).map(Vec::as_slice)
238    }
239}
240
241// ── Versioned Index ──────────────────────────────────────────────
242
/// Pre-built versioned index mapping encoded key bytes to a BTreeMap
/// of version timestamps to row indices. Supports point-in-time lookups
/// via `probe_at_time` for temporal joins.
///
/// The `BTreeMap`'s ordering over timestamps is what makes the
/// latest-version-at-or-before lookup an ordered range query.
#[derive(Default)]
pub struct VersionedIndex {
    map: HashMap<Box<[u8]>, BTreeMap<i64, Vec<u32>>>,
}
250
251impl VersionedIndex {
252    /// Builds a versioned index over `key_indices` and `version_col_idx`
253    /// columns in `batch`.
254    ///
255    /// Uses Arrow's `RowConverter` for binary-comparable key encoding.
256    /// Null keys and null version timestamps are skipped.
257    ///
258    /// # Errors
259    ///
260    /// Returns an error if key encoding or timestamp extraction fails.
261    pub fn build(
262        batch: &RecordBatch,
263        key_indices: &[usize],
264        version_col_idx: usize,
265    ) -> Result<Self> {
266        if batch.num_rows() == 0 {
267            return Ok(Self {
268                map: HashMap::new(),
269            });
270        }
271
272        let sort_fields: Vec<SortField> = key_indices
273            .iter()
274            .map(|&i| SortField::new(batch.schema().field(i).data_type().clone()))
275            .collect();
276        let converter = RowConverter::new(sort_fields)?;
277
278        let key_cols: Vec<_> = key_indices
279            .iter()
280            .map(|&i| batch.column(i).clone())
281            .collect();
282        let rows = converter.convert_columns(&key_cols)?;
283
284        let timestamps = extract_i64_timestamps(batch.column(version_col_idx))?;
285
286        let num_rows = batch.num_rows();
287        let mut map: HashMap<Box<[u8]>, BTreeMap<i64, Vec<u32>>> = HashMap::with_capacity(num_rows);
288        #[allow(clippy::cast_possible_truncation)]
289        for (i, ts_opt) in timestamps.iter().enumerate() {
290            // Skip rows with null keys or null version timestamps.
291            let Some(version_ts) = ts_opt else { continue };
292            if key_cols.iter().any(|c| c.is_null(i)) {
293                continue;
294            }
295            let key = Box::from(rows.row(i).as_ref());
296            map.entry(key)
297                .or_default()
298                .entry(*version_ts)
299                .or_default()
300                .push(i as u32);
301        }
302
303        Ok(Self { map })
304    }
305
306    /// Finds the row index for the latest version `<= event_ts` for the
307    /// given key. Returns the last row index at that version.
308    fn probe_at_time(&self, key: &[u8], event_ts: i64) -> Option<u32> {
309        let versions = self.map.get(key)?;
310        let (_, indices) = versions.range(..=event_ts).next_back()?;
311        indices.last().copied()
312    }
313}
314
315/// Extracts `Option<i64>` timestamp values from an Arrow array column.
316///
317/// Returns `None` for null entries (callers must handle nulls explicitly).
318/// Supports `Int64`, all `Timestamp` variants (scaled to milliseconds),
319/// and `Float64` (truncated to `i64`).
320fn extract_i64_timestamps(col: &dyn arrow_array::Array) -> Result<Vec<Option<i64>>> {
321    use arrow_array::{
322        Float64Array, Int64Array, TimestampMicrosecondArray, TimestampMillisecondArray,
323        TimestampNanosecondArray, TimestampSecondArray,
324    };
325    use arrow_schema::{DataType, TimeUnit};
326
327    let n = col.len();
328    let mut out = Vec::with_capacity(n);
329    macro_rules! extract_typed {
330        ($arr_type:ty, $scale:expr) => {{
331            let arr = col.as_any().downcast_ref::<$arr_type>().ok_or_else(|| {
332                DataFusionError::Internal(concat!("expected ", stringify!($arr_type)).into())
333            })?;
334            for i in 0..n {
335                out.push(if col.is_null(i) {
336                    None
337                } else {
338                    Some(arr.value(i) * $scale)
339                });
340            }
341        }};
342    }
343
344    match col.data_type() {
345        DataType::Int64 => extract_typed!(Int64Array, 1),
346        DataType::Timestamp(TimeUnit::Millisecond, _) => {
347            extract_typed!(TimestampMillisecondArray, 1);
348        }
349        DataType::Timestamp(TimeUnit::Microsecond, _) => {
350            let arr = col
351                .as_any()
352                .downcast_ref::<TimestampMicrosecondArray>()
353                .ok_or_else(|| {
354                    DataFusionError::Internal("expected TimestampMicrosecondArray".into())
355                })?;
356            for i in 0..n {
357                out.push(if col.is_null(i) {
358                    None
359                } else {
360                    Some(arr.value(i) / 1000)
361                });
362            }
363        }
364        DataType::Timestamp(TimeUnit::Second, _) => {
365            extract_typed!(TimestampSecondArray, 1000);
366        }
367        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
368            let arr = col
369                .as_any()
370                .downcast_ref::<TimestampNanosecondArray>()
371                .ok_or_else(|| {
372                    DataFusionError::Internal("expected TimestampNanosecondArray".into())
373                })?;
374            for i in 0..n {
375                out.push(if col.is_null(i) {
376                    None
377                } else {
378                    Some(arr.value(i) / 1_000_000)
379                });
380            }
381        }
382        DataType::Float64 => {
383            let arr = col
384                .as_any()
385                .downcast_ref::<Float64Array>()
386                .ok_or_else(|| DataFusionError::Internal("expected Float64Array".into()))?;
387            #[allow(clippy::cast_possible_truncation)]
388            for i in 0..n {
389                out.push(if col.is_null(i) {
390                    None
391                } else {
392                    Some(arr.value(i) as i64)
393                });
394            }
395        }
396        other => {
397            return Err(DataFusionError::Plan(format!(
398                "unsupported timestamp type for temporal join: {other:?}"
399            )));
400        }
401    }
402
403    Ok(out)
404}
405
406// ── Physical Execution Plan ──────────────────────────────────────
407
/// Physical plan that hash-probes a pre-indexed lookup table for
/// each batch from the streaming input.
pub struct LookupJoinExec {
    /// Streaming child plan providing probe-side batches.
    input: Arc<dyn ExecutionPlan>,
    /// Pre-built hash index over the lookup batch's key columns.
    index: Arc<HashIndex>,
    /// The lookup table rows the index entries point into.
    lookup_batch: Arc<RecordBatch>,
    /// Stream-side key column indices, pairwise aligned with the
    /// lookup keys the index was built over.
    stream_key_indices: Vec<usize>,
    join_type: LookupJoinType,
    /// Output schema: stream columns followed by lookup columns.
    schema: SchemaRef,
    properties: PlanProperties,
    /// `RowConverter` config for encoding probe keys identically to the index.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of `schema`.
    stream_field_count: usize,
}
422
423impl LookupJoinExec {
424    /// Creates a new lookup join executor.
425    ///
426    /// `stream_key_indices` and `lookup_key_indices` must be the same
427    /// length and correspond pairwise (stream key 0 matches lookup key 0).
428    ///
429    /// # Errors
430    ///
431    /// Returns an error if the hash index cannot be built (e.g., unsupported key type).
432    #[allow(clippy::needless_pass_by_value)] // lookup_batch is moved into Arc
433    pub fn try_new(
434        input: Arc<dyn ExecutionPlan>,
435        lookup_batch: RecordBatch,
436        stream_key_indices: Vec<usize>,
437        lookup_key_indices: Vec<usize>,
438        join_type: LookupJoinType,
439        output_schema: SchemaRef,
440    ) -> Result<Self> {
441        let index = HashIndex::build(&lookup_batch, &lookup_key_indices)?;
442
443        let key_sort_fields: Vec<SortField> = lookup_key_indices
444            .iter()
445            .map(|&i| SortField::new(lookup_batch.schema().field(i).data_type().clone()))
446            .collect();
447
448        // Left outer joins produce NULLs for non-matching lookup rows,
449        // so force all lookup columns nullable in the output schema.
450        let output_schema = if join_type == LookupJoinType::LeftOuter {
451            let stream_count = input.schema().fields().len();
452            let mut fields = output_schema.fields().to_vec();
453            for f in &mut fields[stream_count..] {
454                if !f.is_nullable() {
455                    *f = Arc::new(f.as_ref().clone().with_nullable(true));
456                }
457            }
458            Arc::new(Schema::new_with_metadata(
459                fields,
460                output_schema.metadata().clone(),
461            ))
462        } else {
463            output_schema
464        };
465
466        let properties = PlanProperties::new(
467            EquivalenceProperties::new(Arc::clone(&output_schema)),
468            Partitioning::UnknownPartitioning(1),
469            EmissionType::Incremental,
470            Boundedness::Unbounded {
471                requires_infinite_memory: false,
472            },
473        );
474
475        let stream_field_count = input.schema().fields().len();
476
477        Ok(Self {
478            input,
479            index: Arc::new(index),
480            lookup_batch: Arc::new(lookup_batch),
481            stream_key_indices,
482            join_type,
483            schema: output_schema,
484            properties,
485            key_sort_fields,
486            stream_field_count,
487        })
488    }
489}
490
impl Debug for LookupJoinExec {
    // Compact debug form: omits the index, batch contents, schema, and
    // properties (hence `finish_non_exhaustive`).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("LookupJoinExec")
            .field("join_type", &self.join_type)
            .field("stream_keys", &self.stream_key_indices)
            .field("lookup_rows", &self.lookup_batch.num_rows())
            .finish_non_exhaustive()
    }
}
500
impl DisplayAs for LookupJoinExec {
    // EXPLAIN rendering: one descriptive line for default/verbose formats,
    // just the node name for tree rendering.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "LookupJoinExec: type={}, stream_keys={:?}, lookup_rows={}",
                    self.join_type,
                    self.stream_key_indices,
                    self.lookup_batch.num_rows(),
                )
            }
            DisplayFormatType::TreeRender => write!(f, "LookupJoinExec"),
        }
    }
}
517
impl ExecutionPlan for LookupJoinExec {
    fn name(&self) -> &'static str {
        "LookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // Rebuilds the node around a new child; the pre-built index, lookup
    // batch, and plan properties are shared unchanged.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "LookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            index: Arc::clone(&self.index),
            lookup_batch: Arc::clone(&self.lookup_batch),
            stream_key_indices: self.stream_key_indices.clone(),
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
        }))
    }

    // Maps each input batch through `probe_batch`; empty input batches
    // short-circuit to an empty batch with the join's output schema.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // Probe keys must be encoded exactly like the index keys, so the
        // converter is built from the same `SortField` configuration.
        let converter = RowConverter::new(self.key_sort_fields.clone())?;
        let index = Arc::clone(&self.index);
        let lookup_batch = Arc::clone(&self.lookup_batch);
        let stream_key_indices = self.stream_key_indices.clone();
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;

        let output = input_stream.map(move |result| {
            let batch = result?;
            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
            }
            probe_batch(
                &batch,
                &converter,
                &index,
                &lookup_batch,
                &stream_key_indices,
                join_type,
                &schema,
                stream_field_count,
            )
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
598
// NOTE(review): `boundedness()` and `pipeline_behavior()` restate the
// values baked into `self.properties` at construction time instead of
// delegating to them — keep the two in sync if construction changes.
impl datafusion::physical_plan::ExecutionPlanProperties for LookupJoinExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    // Streaming source: the plan never terminates on its own, but memory
    // use stays bounded.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
622
623// ── Probe Logic ──────────────────────────────────────────────────
624
625/// Probes the hash index for each row in `stream_batch` and builds
626/// the joined output batch.
627#[allow(clippy::too_many_arguments)]
628fn probe_batch(
629    stream_batch: &RecordBatch,
630    converter: &RowConverter,
631    index: &HashIndex,
632    lookup_batch: &RecordBatch,
633    stream_key_indices: &[usize],
634    join_type: LookupJoinType,
635    output_schema: &SchemaRef,
636    stream_field_count: usize,
637) -> Result<RecordBatch> {
638    let key_cols: Vec<_> = stream_key_indices
639        .iter()
640        .map(|&i| stream_batch.column(i).clone())
641        .collect();
642    let rows = converter.convert_columns(&key_cols)?;
643
644    let num_rows = stream_batch.num_rows();
645    let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
646    let mut lookup_indices: Vec<Option<u32>> = Vec::with_capacity(num_rows);
647
648    #[allow(clippy::cast_possible_truncation)] // batch row count fits u32
649    for row in 0..num_rows {
650        // SQL semantics: NULL != NULL, so rows with any null key never match.
651        if key_cols.iter().any(|c| c.is_null(row)) {
652            if join_type == LookupJoinType::LeftOuter {
653                stream_indices.push(row as u32);
654                lookup_indices.push(None);
655            }
656            continue;
657        }
658
659        let key = rows.row(row);
660        match index.probe(key.as_ref()) {
661            Some(matches) => {
662                for &lookup_row in matches {
663                    stream_indices.push(row as u32);
664                    lookup_indices.push(Some(lookup_row));
665                }
666            }
667            None if join_type == LookupJoinType::LeftOuter => {
668                stream_indices.push(row as u32);
669                lookup_indices.push(None);
670            }
671            None => {}
672        }
673    }
674
675    if stream_indices.is_empty() {
676        return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
677    }
678
679    // Gather stream-side columns
680    let take_stream = UInt32Array::from(stream_indices);
681    let mut columns = Vec::with_capacity(output_schema.fields().len());
682
683    for col in stream_batch.columns() {
684        columns.push(take(col.as_ref(), &take_stream, None)?);
685    }
686
687    // Gather lookup-side columns (None → null in output)
688    let take_lookup: UInt32Array = lookup_indices.into_iter().collect();
689    for col in lookup_batch.columns() {
690        columns.push(take(col.as_ref(), &take_lookup, None)?);
691    }
692
693    debug_assert_eq!(
694        columns.len(),
695        stream_field_count + lookup_batch.num_columns(),
696        "output column count mismatch"
697    );
698
699    Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
700}
701
702// ── Versioned Lookup Join Exec ────────────────────────────────────
703
/// Physical plan that probes a versioned (temporal) index for each
/// batch from the streaming input. For each stream row, finds the
/// table row with the latest version timestamp `<= event_ts`.
pub struct VersionedLookupJoinExec {
    /// Streaming child plan providing probe-side batches.
    input: Arc<dyn ExecutionPlan>,
    /// Pre-built versioned index (shared, built at registration time).
    index: Arc<VersionedIndex>,
    /// All versions of all table rows the index entries point into.
    table_batch: Arc<RecordBatch>,
    /// Stream-side key column indices, pairwise aligned with the index keys.
    stream_key_indices: Vec<usize>,
    /// Stream-side column holding the event time (the AS OF timestamp).
    stream_time_col_idx: usize,
    join_type: LookupJoinType,
    /// Output schema: stream columns followed by table columns.
    schema: SchemaRef,
    properties: PlanProperties,
    /// `RowConverter` config for encoding probe keys identically to the index.
    key_sort_fields: Vec<SortField>,
    /// Number of stream-side fields at the front of `schema`.
    stream_field_count: usize,
}
719
impl VersionedLookupJoinExec {
    /// Creates a new versioned lookup join executor.
    ///
    /// The `index` should be pre-built via `VersionedIndex::build()` and
    /// cached in `VersionedLookupState`. The index is only rebuilt when
    /// the table data changes (CDC update), not per execution cycle.
    ///
    /// `key_sort_fields` must match the `SortField` configuration the
    /// index was built with, so probe keys encode identically.
    ///
    /// # Errors
    ///
    /// Returns an error if the output schema cannot be constructed.
    #[allow(clippy::too_many_arguments, clippy::needless_pass_by_value)]
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        table_batch: RecordBatch,
        index: Arc<VersionedIndex>,
        stream_key_indices: Vec<usize>,
        stream_time_col_idx: usize,
        join_type: LookupJoinType,
        output_schema: SchemaRef,
        key_sort_fields: Vec<SortField>,
    ) -> Result<Self> {
        // Left outer joins produce NULLs for non-matching table rows, so
        // force all table-side columns nullable in the output schema.
        let output_schema = if join_type == LookupJoinType::LeftOuter {
            let stream_count = input.schema().fields().len();
            let mut fields = output_schema.fields().to_vec();
            for f in &mut fields[stream_count..] {
                if !f.is_nullable() {
                    *f = Arc::new(f.as_ref().clone().with_nullable(true));
                }
            }
            Arc::new(Schema::new_with_metadata(
                fields,
                output_schema.metadata().clone(),
            ))
        } else {
            output_schema
        };

        // Single unbounded partition: streaming input, incremental output.
        let properties = PlanProperties::new(
            EquivalenceProperties::new(Arc::clone(&output_schema)),
            Partitioning::UnknownPartitioning(1),
            EmissionType::Incremental,
            Boundedness::Unbounded {
                requires_infinite_memory: false,
            },
        );

        let stream_field_count = input.schema().fields().len();

        Ok(Self {
            input,
            index,
            table_batch: Arc::new(table_batch),
            stream_key_indices,
            stream_time_col_idx,
            join_type,
            schema: output_schema,
            properties,
            key_sort_fields,
            stream_field_count,
        })
    }
}
782
impl Debug for VersionedLookupJoinExec {
    // Compact debug form: omits the index, batch contents, schema, and
    // properties (hence `finish_non_exhaustive`).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("VersionedLookupJoinExec")
            .field("join_type", &self.join_type)
            .field("stream_keys", &self.stream_key_indices)
            .field("table_rows", &self.table_batch.num_rows())
            .finish_non_exhaustive()
    }
}
792
impl DisplayAs for VersionedLookupJoinExec {
    // EXPLAIN rendering: one descriptive line for default/verbose formats,
    // just the node name for tree rendering.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "VersionedLookupJoinExec: type={}, stream_keys={:?}, table_rows={}",
                    self.join_type,
                    self.stream_key_indices,
                    self.table_batch.num_rows(),
                )
            }
            DisplayFormatType::TreeRender => write!(f, "VersionedLookupJoinExec"),
        }
    }
}
809
impl ExecutionPlan for VersionedLookupJoinExec {
    fn name(&self) -> &'static str {
        "VersionedLookupJoinExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // Rebuilds the node around a new child; the shared index, table
    // batch, and plan properties are reused unchanged.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Plan(
                "VersionedLookupJoinExec requires exactly one child".into(),
            ));
        }
        Ok(Arc::new(Self {
            input: children.swap_remove(0),
            index: Arc::clone(&self.index),
            table_batch: Arc::clone(&self.table_batch),
            stream_key_indices: self.stream_key_indices.clone(),
            stream_time_col_idx: self.stream_time_col_idx,
            join_type: self.join_type,
            schema: Arc::clone(&self.schema),
            properties: self.properties.clone(),
            key_sort_fields: self.key_sort_fields.clone(),
            stream_field_count: self.stream_field_count,
        }))
    }

    // Maps each input batch through `probe_versioned_batch`; empty input
    // batches short-circuit to an empty batch with the join's output schema.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        // Probe keys must be encoded exactly like the index keys, so the
        // converter is built from the same `SortField` configuration.
        let converter = RowConverter::new(self.key_sort_fields.clone())?;
        let index = Arc::clone(&self.index);
        let table_batch = Arc::clone(&self.table_batch);
        let stream_key_indices = self.stream_key_indices.clone();
        let stream_time_col_idx = self.stream_time_col_idx;
        let join_type = self.join_type;
        let schema = self.schema();
        let stream_field_count = self.stream_field_count;

        let output = input_stream.map(move |result| {
            let batch = result?;
            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
            }
            probe_versioned_batch(
                &batch,
                &converter,
                &index,
                &table_batch,
                &stream_key_indices,
                stream_time_col_idx,
                join_type,
                &schema,
                stream_field_count,
            )
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema(),
            output,
        )))
    }
}
893
/// Exposes stream-related plan properties; partitioning/ordering/equivalence
/// delegate to the cached [`PlanProperties`], and the hardcoded values below
/// match what the plan was constructed with.
impl datafusion::physical_plan::ExecutionPlanProperties for VersionedLookupJoinExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    // The stream side never ends, but probing a fixed snapshot needs only
    // bounded memory.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    // Each input batch is joined and emitted immediately; no buffering.
    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
917
918/// Probes the versioned index for each row in `stream_batch`, finding
919/// the table row with the latest version `<= event_ts`.
920#[allow(clippy::too_many_arguments)]
921fn probe_versioned_batch(
922    stream_batch: &RecordBatch,
923    converter: &RowConverter,
924    index: &VersionedIndex,
925    table_batch: &RecordBatch,
926    stream_key_indices: &[usize],
927    stream_time_col_idx: usize,
928    join_type: LookupJoinType,
929    output_schema: &SchemaRef,
930    stream_field_count: usize,
931) -> Result<RecordBatch> {
932    let key_cols: Vec<_> = stream_key_indices
933        .iter()
934        .map(|&i| stream_batch.column(i).clone())
935        .collect();
936    let rows = converter.convert_columns(&key_cols)?;
937    let event_timestamps =
938        extract_i64_timestamps(stream_batch.column(stream_time_col_idx).as_ref())?;
939
940    let num_rows = stream_batch.num_rows();
941    let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
942    let mut lookup_indices: Vec<Option<u32>> = Vec::with_capacity(num_rows);
943
944    #[allow(clippy::cast_possible_truncation)]
945    for (row, event_ts_opt) in event_timestamps.iter().enumerate() {
946        // Null keys or null event timestamps cannot match.
947        if key_cols.iter().any(|c| c.is_null(row)) || event_ts_opt.is_none() {
948            if join_type == LookupJoinType::LeftOuter {
949                stream_indices.push(row as u32);
950                lookup_indices.push(None);
951            }
952            continue;
953        }
954
955        let key = rows.row(row);
956        let event_ts = event_ts_opt.unwrap();
957        match index.probe_at_time(key.as_ref(), event_ts) {
958            Some(table_row_idx) => {
959                stream_indices.push(row as u32);
960                lookup_indices.push(Some(table_row_idx));
961            }
962            None if join_type == LookupJoinType::LeftOuter => {
963                stream_indices.push(row as u32);
964                lookup_indices.push(None);
965            }
966            None => {}
967        }
968    }
969
970    if stream_indices.is_empty() {
971        return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
972    }
973
974    let take_stream = UInt32Array::from(stream_indices);
975    let mut columns = Vec::with_capacity(output_schema.fields().len());
976
977    for col in stream_batch.columns() {
978        columns.push(take(col.as_ref(), &take_stream, None)?);
979    }
980
981    let take_lookup: UInt32Array = lookup_indices.into_iter().collect();
982    for col in table_batch.columns() {
983        columns.push(take(col.as_ref(), &take_lookup, None)?);
984    }
985
986    debug_assert_eq!(
987        columns.len(),
988        stream_field_count + table_batch.num_columns(),
989        "output column count mismatch"
990    );
991
992    Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
993}
994
995// ── Partial Lookup Join Exec ──────────────────────────────────────
996
/// Physical plan that probes a bounded foyer cache per key for each
/// batch from the streaming input. Used for on-demand/partial tables
/// where the full dataset does not fit in memory.
pub struct PartialLookupJoinExec {
    /// Streaming probe-side input.
    input: Arc<dyn ExecutionPlan>,
    /// Bounded in-memory cache of lookup rows, probed by row-encoded key.
    foyer_cache: Arc<FoyerMemoryCache>,
    /// Indices of the join-key columns within the stream schema.
    stream_key_indices: Vec<usize>,
    /// Inner or left-outer join semantics.
    join_type: LookupJoinType,
    /// Output schema: stream columns followed by lookup columns.
    schema: SchemaRef,
    /// Cached plan properties (partitioning, boundedness, emission).
    properties: PlanProperties,
    /// Sort fields used to row-encode the join keys for cache probes.
    key_sort_fields: Vec<SortField>,
    /// Number of leading output columns contributed by the stream side.
    stream_field_count: usize,
    /// Schema of the lookup (right) side columns.
    lookup_schema: SchemaRef,
    /// Optional async source consulted on cache misses.
    source: Option<Arc<dyn LookupSourceDyn>>,
    /// Bounds the number of concurrent source fetches.
    fetch_semaphore: Arc<Semaphore>,
}
1013
1014impl PartialLookupJoinExec {
1015    /// Creates a new partial lookup join executor.
1016    ///
1017    /// # Errors
1018    ///
1019    /// Returns an error if the output schema cannot be constructed.
1020    pub fn try_new(
1021        input: Arc<dyn ExecutionPlan>,
1022        foyer_cache: Arc<FoyerMemoryCache>,
1023        stream_key_indices: Vec<usize>,
1024        key_sort_fields: Vec<SortField>,
1025        join_type: LookupJoinType,
1026        lookup_schema: SchemaRef,
1027        output_schema: SchemaRef,
1028    ) -> Result<Self> {
1029        Self::try_new_with_source(
1030            input,
1031            foyer_cache,
1032            stream_key_indices,
1033            key_sort_fields,
1034            join_type,
1035            lookup_schema,
1036            output_schema,
1037            None,
1038            Arc::new(Semaphore::new(64)),
1039        )
1040    }
1041
1042    /// Creates a new partial lookup join executor with optional source fallback.
1043    ///
1044    /// # Errors
1045    ///
1046    /// Returns an error if the output schema cannot be constructed.
1047    #[allow(clippy::too_many_arguments)]
1048    pub fn try_new_with_source(
1049        input: Arc<dyn ExecutionPlan>,
1050        foyer_cache: Arc<FoyerMemoryCache>,
1051        stream_key_indices: Vec<usize>,
1052        key_sort_fields: Vec<SortField>,
1053        join_type: LookupJoinType,
1054        lookup_schema: SchemaRef,
1055        output_schema: SchemaRef,
1056        source: Option<Arc<dyn LookupSourceDyn>>,
1057        fetch_semaphore: Arc<Semaphore>,
1058    ) -> Result<Self> {
1059        let output_schema = if join_type == LookupJoinType::LeftOuter {
1060            let stream_count = input.schema().fields().len();
1061            let mut fields = output_schema.fields().to_vec();
1062            for f in &mut fields[stream_count..] {
1063                if !f.is_nullable() {
1064                    *f = Arc::new(f.as_ref().clone().with_nullable(true));
1065                }
1066            }
1067            Arc::new(Schema::new_with_metadata(
1068                fields,
1069                output_schema.metadata().clone(),
1070            ))
1071        } else {
1072            output_schema
1073        };
1074
1075        let properties = PlanProperties::new(
1076            EquivalenceProperties::new(Arc::clone(&output_schema)),
1077            Partitioning::UnknownPartitioning(1),
1078            EmissionType::Incremental,
1079            Boundedness::Unbounded {
1080                requires_infinite_memory: false,
1081            },
1082        );
1083
1084        let stream_field_count = input.schema().fields().len();
1085
1086        Ok(Self {
1087            input,
1088            foyer_cache,
1089            stream_key_indices,
1090            join_type,
1091            schema: output_schema,
1092            properties,
1093            key_sort_fields,
1094            stream_field_count,
1095            lookup_schema,
1096            source,
1097            fetch_semaphore,
1098        })
1099    }
1100}
1101
1102impl Debug for PartialLookupJoinExec {
1103    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1104        f.debug_struct("PartialLookupJoinExec")
1105            .field("join_type", &self.join_type)
1106            .field("stream_keys", &self.stream_key_indices)
1107            .field("cache_table_id", &self.foyer_cache.table_id())
1108            .finish_non_exhaustive()
1109    }
1110}
1111
1112impl DisplayAs for PartialLookupJoinExec {
1113    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
1114        match t {
1115            DisplayFormatType::Default | DisplayFormatType::Verbose => {
1116                write!(
1117                    f,
1118                    "PartialLookupJoinExec: type={}, stream_keys={:?}, cache_entries={}",
1119                    self.join_type,
1120                    self.stream_key_indices,
1121                    self.foyer_cache.len(),
1122                )
1123            }
1124            DisplayFormatType::TreeRender => write!(f, "PartialLookupJoinExec"),
1125        }
1126    }
1127}
1128
1129impl ExecutionPlan for PartialLookupJoinExec {
1130    fn name(&self) -> &'static str {
1131        "PartialLookupJoinExec"
1132    }
1133
1134    fn as_any(&self) -> &dyn Any {
1135        self
1136    }
1137
1138    fn schema(&self) -> SchemaRef {
1139        Arc::clone(&self.schema)
1140    }
1141
1142    fn properties(&self) -> &PlanProperties {
1143        &self.properties
1144    }
1145
1146    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1147        vec![&self.input]
1148    }
1149
1150    fn with_new_children(
1151        self: Arc<Self>,
1152        mut children: Vec<Arc<dyn ExecutionPlan>>,
1153    ) -> Result<Arc<dyn ExecutionPlan>> {
1154        if children.len() != 1 {
1155            return Err(DataFusionError::Plan(
1156                "PartialLookupJoinExec requires exactly one child".into(),
1157            ));
1158        }
1159        Ok(Arc::new(Self {
1160            input: children.swap_remove(0),
1161            foyer_cache: Arc::clone(&self.foyer_cache),
1162            stream_key_indices: self.stream_key_indices.clone(),
1163            join_type: self.join_type,
1164            schema: Arc::clone(&self.schema),
1165            properties: self.properties.clone(),
1166            key_sort_fields: self.key_sort_fields.clone(),
1167            stream_field_count: self.stream_field_count,
1168            lookup_schema: Arc::clone(&self.lookup_schema),
1169            source: self.source.clone(),
1170            fetch_semaphore: Arc::clone(&self.fetch_semaphore),
1171        }))
1172    }
1173
1174    fn execute(
1175        &self,
1176        partition: usize,
1177        context: Arc<TaskContext>,
1178    ) -> Result<SendableRecordBatchStream> {
1179        let input_stream = self.input.execute(partition, context)?;
1180        let converter = Arc::new(RowConverter::new(self.key_sort_fields.clone())?);
1181        let foyer_cache = Arc::clone(&self.foyer_cache);
1182        let stream_key_indices = self.stream_key_indices.clone();
1183        let join_type = self.join_type;
1184        let schema = self.schema();
1185        let stream_field_count = self.stream_field_count;
1186        let lookup_schema = Arc::clone(&self.lookup_schema);
1187        let source = self.source.clone();
1188        let fetch_semaphore = Arc::clone(&self.fetch_semaphore);
1189
1190        let output = input_stream.then(move |result| {
1191            let foyer_cache = Arc::clone(&foyer_cache);
1192            let converter = Arc::clone(&converter);
1193            let stream_key_indices = stream_key_indices.clone();
1194            let schema = Arc::clone(&schema);
1195            let lookup_schema = Arc::clone(&lookup_schema);
1196            let source = source.clone();
1197            let fetch_semaphore = Arc::clone(&fetch_semaphore);
1198            async move {
1199                let batch = result?;
1200                if batch.num_rows() == 0 {
1201                    return Ok(RecordBatch::new_empty(Arc::clone(&schema)));
1202                }
1203                probe_partial_batch_with_fallback(
1204                    &batch,
1205                    &converter,
1206                    &foyer_cache,
1207                    &stream_key_indices,
1208                    join_type,
1209                    &schema,
1210                    stream_field_count,
1211                    &lookup_schema,
1212                    source.as_deref(),
1213                    &fetch_semaphore,
1214                )
1215                .await
1216            }
1217        });
1218
1219        Ok(Box::pin(RecordBatchStreamAdapter::new(
1220            self.schema(),
1221            output,
1222        )))
1223    }
1224}
1225
/// Exposes stream-related plan properties; partitioning/ordering/equivalence
/// delegate to the cached [`PlanProperties`], and the hardcoded values below
/// match what `try_new_with_source` constructs.
impl datafusion::physical_plan::ExecutionPlanProperties for PartialLookupJoinExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&LexOrdering> {
        self.properties.output_ordering()
    }

    // The stream side never ends; the cache is bounded, so memory is too.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Unbounded {
            requires_infinite_memory: false,
        }
    }

    // Each input batch is joined and emitted as soon as it is processed.
    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Incremental
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
1249
/// Probes the foyer cache for each row in `stream_batch`, falling back
/// to the async source for cache misses. Inserts source results into
/// the cache before building the output.
///
/// Phases:
/// 1. Probe the cache per row, recording misses with their output slot.
/// 2. Batch-fetch all misses from the source (if any) under a semaphore
///    permit; source failures degrade to cache-only results with a warning.
/// 3. For inner joins, compact away rows that still have no match.
/// 4. Assemble the output: `take` for stream columns, per-row `concat`
///    for lookup columns (NULL placeholders for unmatched outer rows).
#[allow(clippy::too_many_arguments)]
async fn probe_partial_batch_with_fallback(
    stream_batch: &RecordBatch,
    converter: &RowConverter,
    foyer_cache: &FoyerMemoryCache,
    stream_key_indices: &[usize],
    join_type: LookupJoinType,
    output_schema: &SchemaRef,
    stream_field_count: usize,
    lookup_schema: &SchemaRef,
    source: Option<&dyn LookupSourceDyn>,
    fetch_semaphore: &Semaphore,
) -> Result<RecordBatch> {
    // Row-encode the join keys so cache/source lookups use opaque bytes.
    let key_cols: Vec<_> = stream_key_indices
        .iter()
        .map(|&i| stream_batch.column(i).clone())
        .collect();
    let rows = converter.convert_columns(&key_cols)?;

    let num_rows = stream_batch.num_rows();
    // Parallel vectors: stream row to emit + its (optional) lookup batch.
    let mut stream_indices: Vec<u32> = Vec::with_capacity(num_rows);
    let mut lookup_batches: Vec<Option<RecordBatch>> = Vec::with_capacity(num_rows);
    // (output slot index, encoded key bytes) for every cache miss.
    let mut miss_keys: Vec<(usize, Vec<u8>)> = Vec::new();

    #[allow(clippy::cast_possible_truncation)]
    for row in 0..num_rows {
        // SQL semantics: NULL != NULL, so rows with any null key never match.
        if key_cols.iter().any(|c| c.is_null(row)) {
            if join_type == LookupJoinType::LeftOuter {
                stream_indices.push(row as u32);
                lookup_batches.push(None);
            }
            continue;
        }

        let key = rows.row(row);
        let result = foyer_cache.get_cached(key.as_ref());
        if let Some(batch) = result.into_batch() {
            stream_indices.push(row as u32);
            lookup_batches.push(Some(batch));
        } else {
            // Remember the slot so a later source hit can fill it in.
            let idx = stream_indices.len();
            stream_indices.push(row as u32);
            lookup_batches.push(None);
            miss_keys.push((idx, key.as_ref().to_vec()));
        }
    }

    // Fetch missed keys from the source in a single batch query
    if let Some(source) = source {
        if !miss_keys.is_empty() {
            // Bound concurrent source fetches across the whole plan.
            let _permit = fetch_semaphore
                .acquire()
                .await
                .map_err(|_| DataFusionError::Internal("fetch semaphore closed".into()))?;

            let key_refs: Vec<&[u8]> = miss_keys.iter().map(|(_, k)| k.as_slice()).collect();
            let source_results = source.query_batch(&key_refs, &[], &[]).await;

            match source_results {
                // Results are positionally aligned with `miss_keys`.
                Ok(results) => {
                    for ((idx, key_bytes), maybe_batch) in miss_keys.iter().zip(results.into_iter())
                    {
                        if let Some(batch) = maybe_batch {
                            // Warm the cache before emitting so subsequent
                            // batches hit without re-querying the source.
                            foyer_cache.insert(key_bytes, batch.clone());
                            lookup_batches[*idx] = Some(batch);
                        }
                    }
                }
                // Best-effort: keep cache hits; misses stay unmatched
                // (NULLs for left-outer, dropped below for inner).
                Err(e) => {
                    tracing::warn!(error = %e, "source fallback failed, serving cache-only results");
                }
            }
        }
    }

    // For inner joins, remove rows that still have no match
    // (in-place two-pointer compaction keeps both vectors aligned).
    if join_type == LookupJoinType::Inner {
        let mut write = 0;
        for read in 0..stream_indices.len() {
            if lookup_batches[read].is_some() {
                stream_indices[write] = stream_indices[read];
                lookup_batches.swap(write, read);
                write += 1;
            }
        }
        stream_indices.truncate(write);
        lookup_batches.truncate(write);
    }

    if stream_indices.is_empty() {
        return Ok(RecordBatch::new_empty(Arc::clone(output_schema)));
    }

    let take_indices = UInt32Array::from(stream_indices);
    let mut columns = Vec::with_capacity(output_schema.fields().len());

    for col in stream_batch.columns() {
        columns.push(take(col.as_ref(), &take_indices, None)?);
    }

    // NOTE(review): the concat below contributes each matched batch's rows
    // verbatim. Output alignment assumes every cached/source batch holds
    // exactly one row per key — TODO confirm that invariant in the cache
    // and `LookupSourceDyn` contracts; a multi-row batch would shift rows
    // and make `RecordBatch::try_new` fail on length mismatch.
    let lookup_col_count = lookup_schema.fields().len();
    for col_idx in 0..lookup_col_count {
        let arrays: Vec<_> = lookup_batches
            .iter()
            .map(|opt| match opt {
                Some(b) => b.column(col_idx).clone(),
                // Unmatched left-outer row: a single NULL placeholder.
                None => arrow_array::new_null_array(lookup_schema.field(col_idx).data_type(), 1),
            })
            .collect();
        let refs: Vec<&dyn arrow_array::Array> = arrays.iter().map(AsRef::as_ref).collect();
        columns.push(arrow::compute::concat(&refs)?);
    }

    debug_assert_eq!(
        columns.len(),
        stream_field_count + lookup_col_count,
        "output column count mismatch"
    );

    Ok(RecordBatch::try_new(Arc::clone(output_schema), columns)?)
}
1375
1376// ── Extension Planner ────────────────────────────────────────────
1377
/// Converts `LookupJoinNode` logical plans to [`LookupJoinExec`]
/// or [`PartialLookupJoinExec`] physical plans by resolving table
/// data from the registry.
pub struct LookupJoinExtensionPlanner {
    /// Shared registry mapping lookup table names to their registered state.
    registry: Arc<LookupTableRegistry>,
}

impl LookupJoinExtensionPlanner {
    /// Creates a planner backed by the given registry.
    pub fn new(registry: Arc<LookupTableRegistry>) -> Self {
        Self { registry }
    }
}
1391
1392#[async_trait]
1393impl ExtensionPlanner for LookupJoinExtensionPlanner {
1394    #[allow(clippy::too_many_lines)]
1395    async fn plan_extension(
1396        &self,
1397        _planner: &dyn PhysicalPlanner,
1398        node: &dyn UserDefinedLogicalNode,
1399        _logical_inputs: &[&LogicalPlan],
1400        physical_inputs: &[Arc<dyn ExecutionPlan>],
1401        session_state: &SessionState,
1402    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
1403        let Some(lookup_node) = node.as_any().downcast_ref::<LookupJoinNode>() else {
1404            return Ok(None);
1405        };
1406
1407        let entry = self
1408            .registry
1409            .get_entry(lookup_node.lookup_table_name())
1410            .ok_or_else(|| {
1411                DataFusionError::Plan(format!(
1412                    "lookup table '{}' not registered",
1413                    lookup_node.lookup_table_name()
1414                ))
1415            })?;
1416
1417        let input = Arc::clone(&physical_inputs[0]);
1418        let stream_schema = input.schema();
1419
1420        match entry {
1421            RegisteredLookup::Partial(partial_state) => {
1422                let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;
1423
1424                let mut output_fields = stream_schema.fields().to_vec();
1425                output_fields.extend(partial_state.schema.fields().iter().cloned());
1426                let output_schema = Arc::new(Schema::new(output_fields));
1427
1428                let exec = PartialLookupJoinExec::try_new_with_source(
1429                    input,
1430                    Arc::clone(&partial_state.foyer_cache),
1431                    stream_key_indices,
1432                    partial_state.key_sort_fields.clone(),
1433                    lookup_node.join_type(),
1434                    Arc::clone(&partial_state.schema),
1435                    output_schema,
1436                    partial_state.source.clone(),
1437                    Arc::clone(&partial_state.fetch_semaphore),
1438                )?;
1439                Ok(Some(Arc::new(exec)))
1440            }
1441            RegisteredLookup::Snapshot(snapshot) => {
1442                let lookup_schema = snapshot.batch.schema();
1443                let lookup_key_indices = resolve_lookup_keys(lookup_node, &lookup_schema)?;
1444
1445                let lookup_batch = if lookup_node.pushdown_predicates().is_empty()
1446                    || snapshot.batch.num_rows() == 0
1447                {
1448                    snapshot.batch.clone()
1449                } else {
1450                    apply_pushdown_predicates(
1451                        &snapshot.batch,
1452                        lookup_node.pushdown_predicates(),
1453                        session_state,
1454                    )?
1455                };
1456
1457                let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;
1458
1459                // Validate join key types are compatible
1460                for (si, li) in stream_key_indices.iter().zip(&lookup_key_indices) {
1461                    let st = stream_schema.field(*si).data_type();
1462                    let lt = lookup_schema.field(*li).data_type();
1463                    if st != lt {
1464                        return Err(DataFusionError::Plan(format!(
1465                            "Lookup join key type mismatch: stream '{}' is {st:?} \
1466                             but lookup '{}' is {lt:?}",
1467                            stream_schema.field(*si).name(),
1468                            lookup_schema.field(*li).name(),
1469                        )));
1470                    }
1471                }
1472
1473                let mut output_fields = stream_schema.fields().to_vec();
1474                output_fields.extend(lookup_batch.schema().fields().iter().cloned());
1475                let output_schema = Arc::new(Schema::new(output_fields));
1476
1477                let exec = LookupJoinExec::try_new(
1478                    input,
1479                    lookup_batch,
1480                    stream_key_indices,
1481                    lookup_key_indices,
1482                    lookup_node.join_type(),
1483                    output_schema,
1484                )?;
1485
1486                Ok(Some(Arc::new(exec)))
1487            }
1488            RegisteredLookup::Versioned(versioned_state) => {
1489                let table_schema = versioned_state.batch.schema();
1490                let lookup_key_indices = resolve_lookup_keys(lookup_node, &table_schema)?;
1491                let stream_key_indices = resolve_stream_keys(lookup_node, &stream_schema)?;
1492
1493                // Validate key type compatibility.
1494                for (si, li) in stream_key_indices.iter().zip(&lookup_key_indices) {
1495                    let st = stream_schema.field(*si).data_type();
1496                    let lt = table_schema.field(*li).data_type();
1497                    if st != lt {
1498                        return Err(DataFusionError::Plan(format!(
1499                            "Temporal join key type mismatch: stream '{}' is {st:?} \
1500                             but table '{}' is {lt:?}",
1501                            stream_schema.field(*si).name(),
1502                            table_schema.field(*li).name(),
1503                        )));
1504                    }
1505                }
1506
1507                let stream_time_col_idx = stream_schema
1508                    .index_of(&versioned_state.stream_time_column)
1509                    .map_err(|_| {
1510                        DataFusionError::Plan(format!(
1511                            "stream time column '{}' not found in stream schema",
1512                            versioned_state.stream_time_column
1513                        ))
1514                    })?;
1515
1516                let key_sort_fields: Vec<SortField> = lookup_key_indices
1517                    .iter()
1518                    .map(|&i| SortField::new(table_schema.field(i).data_type().clone()))
1519                    .collect();
1520
1521                let mut output_fields = stream_schema.fields().to_vec();
1522                output_fields.extend(table_schema.fields().iter().cloned());
1523                let output_schema = Arc::new(Schema::new(output_fields));
1524
1525                let exec = VersionedLookupJoinExec::try_new(
1526                    input,
1527                    versioned_state.batch.clone(),
1528                    Arc::clone(&versioned_state.index),
1529                    stream_key_indices,
1530                    stream_time_col_idx,
1531                    lookup_node.join_type(),
1532                    output_schema,
1533                    key_sort_fields,
1534                )?;
1535
1536                Ok(Some(Arc::new(exec)))
1537            }
1538        }
1539    }
1540}
1541
1542/// Evaluates pushdown predicates against the lookup snapshot, returning
1543/// only the rows that pass all predicates. This shrinks the hash index.
1544fn apply_pushdown_predicates(
1545    batch: &RecordBatch,
1546    predicates: &[Expr],
1547    session_state: &SessionState,
1548) -> Result<RecordBatch> {
1549    use arrow::compute::filter_record_batch;
1550    use datafusion::physical_expr::create_physical_expr;
1551
1552    let schema = batch.schema();
1553    let df_schema = datafusion::common::DFSchema::try_from(schema.as_ref().clone())?;
1554
1555    let mut mask = None::<arrow_array::BooleanArray>;
1556    for pred in predicates {
1557        let phys_expr = create_physical_expr(pred, &df_schema, session_state.execution_props())?;
1558        let result = phys_expr.evaluate(batch)?;
1559        let bool_arr = result
1560            .into_array(batch.num_rows())?
1561            .as_any()
1562            .downcast_ref::<arrow_array::BooleanArray>()
1563            .ok_or_else(|| {
1564                DataFusionError::Internal("pushdown predicate did not evaluate to boolean".into())
1565            })?
1566            .clone();
1567        mask = Some(match mask {
1568            Some(existing) => arrow::compute::and(&existing, &bool_arr)?,
1569            None => bool_arr,
1570        });
1571    }
1572
1573    match mask {
1574        Some(m) => Ok(filter_record_batch(batch, &m)?),
1575        None => Ok(batch.clone()),
1576    }
1577}
1578
1579fn resolve_stream_keys(node: &LookupJoinNode, schema: &SchemaRef) -> Result<Vec<usize>> {
1580    node.join_keys()
1581        .iter()
1582        .map(|pair| match &pair.stream_expr {
1583            Expr::Column(col) => schema.index_of(&col.name).map_err(|_| {
1584                DataFusionError::Plan(format!(
1585                    "stream key column '{}' not found in physical schema",
1586                    col.name
1587                ))
1588            }),
1589            other => Err(DataFusionError::NotImplemented(format!(
1590                "lookup join requires column references as stream keys, got: {other}"
1591            ))),
1592        })
1593        .collect()
1594}
1595
1596fn resolve_lookup_keys(node: &LookupJoinNode, schema: &SchemaRef) -> Result<Vec<usize>> {
1597    node.join_keys()
1598        .iter()
1599        .map(|pair| {
1600            schema.index_of(&pair.lookup_column).map_err(|_| {
1601                DataFusionError::Plan(format!(
1602                    "lookup key column '{}' not found in lookup table schema",
1603                    pair.lookup_column
1604                ))
1605            })
1606        })
1607        .collect()
1608}
1609
1610// ── Tests ────────────────────────────────────────────────────────
1611
1612#[cfg(test)]
1613mod tests {
1614    use super::*;
1615    use arrow_array::{Array, Float64Array, Int64Array, StringArray};
1616    use arrow_schema::{DataType, Field};
1617    use datafusion::physical_plan::stream::RecordBatchStreamAdapter as TestStreamAdapter;
1618    use futures::TryStreamExt;
1619
1620    /// Creates a bounded `ExecutionPlan` from a single `RecordBatch`.
1621    fn batch_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
1622        let schema = batch.schema();
1623        let batches = vec![batch];
1624        let stream_schema = Arc::clone(&schema);
1625        Arc::new(StreamExecStub {
1626            schema,
1627            batches: std::sync::Mutex::new(Some(batches)),
1628            stream_schema,
1629        })
1630    }
1631
    /// Minimal bounded exec for tests — produces one partition of batches.
    struct StreamExecStub {
        /// Schema reported via `ExecutionPlan::schema`.
        schema: SchemaRef,
        /// Batches handed out by `execute`; taken on the first call, so a
        /// second execution yields an empty stream.
        batches: std::sync::Mutex<Option<Vec<RecordBatch>>>,
        /// Schema attached to the emitted stream adapter.
        stream_schema: SchemaRef,
    }
1638
1639    impl Debug for StreamExecStub {
1640        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1641            write!(f, "StreamExecStub")
1642        }
1643    }
1644
1645    impl DisplayAs for StreamExecStub {
1646        fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter<'_>) -> fmt::Result {
1647            write!(f, "StreamExecStub")
1648        }
1649    }
1650
    impl ExecutionPlan for StreamExecStub {
        fn name(&self) -> &'static str {
            "StreamExecStub"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            // Leak a static PlanProperties for test simplicity
            // NOTE(review): this leaks a fresh allocation on EVERY call, not
            // just once. Fine for short-lived test processes, but a cached
            // field would be required outside of tests.
            Box::leak(Box::new(PlanProperties::new(
                EquivalenceProperties::new(Arc::clone(&self.schema)),
                Partitioning::UnknownPartitioning(1),
                EmissionType::Final,
                Boundedness::Bounded,
            )))
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        // Takes the stored batches out of the mutex; a second call (or a
        // poisoned lock via `unwrap`) yields an empty stream — acceptable
        // for single-shot test plans.
        fn execute(&self, _: usize, _: Arc<TaskContext>) -> Result<SendableRecordBatchStream> {
            let batches = self.batches.lock().unwrap().take().unwrap_or_default();
            let schema = Arc::clone(&self.stream_schema);
            let stream = futures::stream::iter(batches.into_iter().map(Ok));
            Ok(Box::pin(TestStreamAdapter::new(schema, stream)))
        }
    }
1686
1687    impl datafusion::physical_plan::ExecutionPlanProperties for StreamExecStub {
1688        fn output_partitioning(&self) -> &Partitioning {
1689            self.properties().output_partitioning()
1690        }
1691        fn output_ordering(&self) -> Option<&LexOrdering> {
1692            None
1693        }
1694        fn boundedness(&self) -> Boundedness {
1695            Boundedness::Bounded
1696        }
1697        fn pipeline_behavior(&self) -> EmissionType {
1698            EmissionType::Final
1699        }
1700        fn equivalence_properties(&self) -> &EquivalenceProperties {
1701            self.properties().equivalence_properties()
1702        }
1703    }
1704
1705    fn orders_schema() -> SchemaRef {
1706        Arc::new(Schema::new(vec![
1707            Field::new("order_id", DataType::Int64, false),
1708            Field::new("customer_id", DataType::Int64, false),
1709            Field::new("amount", DataType::Float64, false),
1710        ]))
1711    }
1712
1713    fn customers_schema() -> SchemaRef {
1714        Arc::new(Schema::new(vec![
1715            Field::new("id", DataType::Int64, false),
1716            Field::new("name", DataType::Utf8, true),
1717        ]))
1718    }
1719
1720    fn output_schema() -> SchemaRef {
1721        Arc::new(Schema::new(vec![
1722            Field::new("order_id", DataType::Int64, false),
1723            Field::new("customer_id", DataType::Int64, false),
1724            Field::new("amount", DataType::Float64, false),
1725            Field::new("id", DataType::Int64, false),
1726            Field::new("name", DataType::Utf8, true),
1727        ]))
1728    }
1729
1730    fn customers_batch() -> RecordBatch {
1731        RecordBatch::try_new(
1732            customers_schema(),
1733            vec![
1734                Arc::new(Int64Array::from(vec![1, 2, 3])),
1735                Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
1736            ],
1737        )
1738        .unwrap()
1739    }
1740
1741    fn orders_batch() -> RecordBatch {
1742        RecordBatch::try_new(
1743            orders_schema(),
1744            vec![
1745                Arc::new(Int64Array::from(vec![100, 101, 102, 103])),
1746                Arc::new(Int64Array::from(vec![1, 2, 99, 3])),
1747                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0])),
1748            ],
1749        )
1750        .unwrap()
1751    }
1752
1753    fn make_exec(join_type: LookupJoinType) -> LookupJoinExec {
1754        let input = batch_exec(orders_batch());
1755        LookupJoinExec::try_new(
1756            input,
1757            customers_batch(),
1758            vec![1], // customer_id
1759            vec![0], // id
1760            join_type,
1761            output_schema(),
1762        )
1763        .unwrap()
1764    }
1765
1766    #[tokio::test]
1767    async fn inner_join_filters_non_matches() {
1768        let exec = make_exec(LookupJoinType::Inner);
1769        let ctx = Arc::new(TaskContext::default());
1770        let stream = exec.execute(0, ctx).unwrap();
1771        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
1772
1773        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1774        assert_eq!(total, 3, "customer_id=99 has no match, filtered by inner");
1775
1776        let names = batches[0]
1777            .column(4)
1778            .as_any()
1779            .downcast_ref::<StringArray>()
1780            .unwrap();
1781        assert_eq!(names.value(0), "Alice");
1782        assert_eq!(names.value(1), "Bob");
1783        assert_eq!(names.value(2), "Charlie");
1784    }
1785
1786    #[tokio::test]
1787    async fn left_outer_preserves_non_matches() {
1788        let exec = make_exec(LookupJoinType::LeftOuter);
1789        let ctx = Arc::new(TaskContext::default());
1790        let stream = exec.execute(0, ctx).unwrap();
1791        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
1792
1793        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1794        assert_eq!(total, 4, "all 4 stream rows preserved in left outer");
1795
1796        let names = batches[0]
1797            .column(4)
1798            .as_any()
1799            .downcast_ref::<StringArray>()
1800            .unwrap();
1801        // Row 2 (customer_id=99) should have null name
1802        assert!(names.is_null(2));
1803    }
1804
1805    #[tokio::test]
1806    async fn empty_lookup_inner_produces_no_rows() {
1807        let empty = RecordBatch::new_empty(customers_schema());
1808        let input = batch_exec(orders_batch());
1809        let exec = LookupJoinExec::try_new(
1810            input,
1811            empty,
1812            vec![1],
1813            vec![0],
1814            LookupJoinType::Inner,
1815            output_schema(),
1816        )
1817        .unwrap();
1818
1819        let ctx = Arc::new(TaskContext::default());
1820        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
1821        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1822        assert_eq!(total, 0);
1823    }
1824
1825    #[tokio::test]
1826    async fn empty_lookup_left_outer_preserves_all_stream_rows() {
1827        let empty = RecordBatch::new_empty(customers_schema());
1828        let input = batch_exec(orders_batch());
1829        let exec = LookupJoinExec::try_new(
1830            input,
1831            empty,
1832            vec![1],
1833            vec![0],
1834            LookupJoinType::LeftOuter,
1835            output_schema(),
1836        )
1837        .unwrap();
1838
1839        let ctx = Arc::new(TaskContext::default());
1840        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
1841        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1842        assert_eq!(total, 4);
1843    }
1844
1845    #[tokio::test]
1846    async fn duplicate_keys_produce_multiple_rows() {
1847        let lookup = RecordBatch::try_new(
1848            customers_schema(),
1849            vec![
1850                Arc::new(Int64Array::from(vec![1, 1])),
1851                Arc::new(StringArray::from(vec!["Alice-A", "Alice-B"])),
1852            ],
1853        )
1854        .unwrap();
1855
1856        let stream = RecordBatch::try_new(
1857            orders_schema(),
1858            vec![
1859                Arc::new(Int64Array::from(vec![100])),
1860                Arc::new(Int64Array::from(vec![1])),
1861                Arc::new(Float64Array::from(vec![10.0])),
1862            ],
1863        )
1864        .unwrap();
1865
1866        let input = batch_exec(stream);
1867        let exec = LookupJoinExec::try_new(
1868            input,
1869            lookup,
1870            vec![1],
1871            vec![0],
1872            LookupJoinType::Inner,
1873            output_schema(),
1874        )
1875        .unwrap();
1876
1877        let ctx = Arc::new(TaskContext::default());
1878        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
1879        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
1880        assert_eq!(total, 2, "one stream row matched two lookup rows");
1881    }
1882
1883    #[test]
1884    fn with_new_children_preserves_state() {
1885        let exec = Arc::new(make_exec(LookupJoinType::Inner));
1886        let expected_schema = exec.schema();
1887        let children = exec.children().into_iter().cloned().collect();
1888        let rebuilt = exec.with_new_children(children).unwrap();
1889        assert_eq!(rebuilt.schema(), expected_schema);
1890        assert_eq!(rebuilt.name(), "LookupJoinExec");
1891    }
1892
1893    #[test]
1894    fn display_format() {
1895        let exec = make_exec(LookupJoinType::Inner);
1896        let s = format!("{exec:?}");
1897        assert!(s.contains("LookupJoinExec"));
1898        assert!(s.contains("lookup_rows: 3"));
1899    }
1900
1901    #[test]
1902    fn registry_crud() {
1903        let reg = LookupTableRegistry::new();
1904        assert!(reg.get("customers").is_none());
1905
1906        reg.register(
1907            "customers",
1908            LookupSnapshot {
1909                batch: customers_batch(),
1910                key_columns: vec!["id".into()],
1911            },
1912        );
1913        assert!(reg.get("customers").is_some());
1914        assert!(reg.get("CUSTOMERS").is_some(), "case-insensitive");
1915
1916        reg.unregister("customers");
1917        assert!(reg.get("customers").is_none());
1918    }
1919
1920    #[test]
1921    fn registry_update_replaces() {
1922        let reg = LookupTableRegistry::new();
1923        reg.register(
1924            "t",
1925            LookupSnapshot {
1926                batch: RecordBatch::new_empty(customers_schema()),
1927                key_columns: vec![],
1928            },
1929        );
1930        assert_eq!(reg.get("t").unwrap().batch.num_rows(), 0);
1931
1932        reg.register(
1933            "t",
1934            LookupSnapshot {
1935                batch: customers_batch(),
1936                key_columns: vec![],
1937            },
1938        );
1939        assert_eq!(reg.get("t").unwrap().batch.num_rows(), 3);
1940    }
1941
1942    #[test]
1943    fn pushdown_predicates_filter_snapshot() {
1944        use datafusion::logical_expr::{col, lit};
1945
1946        let batch = customers_batch(); // id=[1,2,3], name=[Alice,Bob,Charlie]
1947        let ctx = datafusion::prelude::SessionContext::new();
1948        let state = ctx.state();
1949
1950        // Filter: id > 1 (should keep rows 2 and 3)
1951        let predicates = vec![col("id").gt(lit(1i64))];
1952        let filtered = apply_pushdown_predicates(&batch, &predicates, &state).unwrap();
1953        assert_eq!(filtered.num_rows(), 2);
1954
1955        let ids = filtered
1956            .column(0)
1957            .as_any()
1958            .downcast_ref::<Int64Array>()
1959            .unwrap();
1960        assert_eq!(ids.value(0), 2);
1961        assert_eq!(ids.value(1), 3);
1962    }
1963
1964    #[test]
1965    fn pushdown_predicates_empty_passes_all() {
1966        let batch = customers_batch();
1967        let ctx = datafusion::prelude::SessionContext::new();
1968        let state = ctx.state();
1969
1970        let filtered = apply_pushdown_predicates(&batch, &[], &state).unwrap();
1971        assert_eq!(filtered.num_rows(), 3);
1972    }
1973
1974    #[test]
1975    fn pushdown_predicates_multiple_and() {
1976        use datafusion::logical_expr::{col, lit};
1977
1978        let batch = customers_batch(); // id=[1,2,3]
1979        let ctx = datafusion::prelude::SessionContext::new();
1980        let state = ctx.state();
1981
1982        // id >= 2 AND id < 3 → only row with id=2
1983        let predicates = vec![col("id").gt_eq(lit(2i64)), col("id").lt(lit(3i64))];
1984        let filtered = apply_pushdown_predicates(&batch, &predicates, &state).unwrap();
1985        assert_eq!(filtered.num_rows(), 1);
1986    }
1987
1988    // ── PartialLookupJoinExec Tests ──────────────────────────────
1989
1990    use laminar_core::lookup::foyer_cache::FoyerMemoryCacheConfig;
1991
1992    fn make_foyer_cache() -> Arc<FoyerMemoryCache> {
1993        Arc::new(FoyerMemoryCache::new(
1994            1,
1995            FoyerMemoryCacheConfig {
1996                capacity: 64,
1997                shards: 4,
1998            },
1999        ))
2000    }
2001
2002    fn customer_row(id: i64, name: &str) -> RecordBatch {
2003        RecordBatch::try_new(
2004            customers_schema(),
2005            vec![
2006                Arc::new(Int64Array::from(vec![id])),
2007                Arc::new(StringArray::from(vec![name])),
2008            ],
2009        )
2010        .unwrap()
2011    }
2012
2013    fn warm_cache(cache: &FoyerMemoryCache) {
2014        let converter = RowConverter::new(vec![SortField::new(DataType::Int64)]).unwrap();
2015
2016        for (id, name) in [(1, "Alice"), (2, "Bob"), (3, "Charlie")] {
2017            let key_col = Int64Array::from(vec![id]);
2018            let rows = converter.convert_columns(&[Arc::new(key_col)]).unwrap();
2019            let key = rows.row(0);
2020            cache.insert(key.as_ref(), customer_row(id, name));
2021        }
2022    }
2023
2024    fn make_partial_exec(join_type: LookupJoinType) -> PartialLookupJoinExec {
2025        let cache = make_foyer_cache();
2026        warm_cache(&cache);
2027
2028        let input = batch_exec(orders_batch());
2029        let key_sort_fields = vec![SortField::new(DataType::Int64)];
2030
2031        PartialLookupJoinExec::try_new(
2032            input,
2033            cache,
2034            vec![1], // customer_id
2035            key_sort_fields,
2036            join_type,
2037            customers_schema(),
2038            output_schema(),
2039        )
2040        .unwrap()
2041    }
2042
2043    #[tokio::test]
2044    async fn partial_inner_join_filters_non_matches() {
2045        let exec = make_partial_exec(LookupJoinType::Inner);
2046        let ctx = Arc::new(TaskContext::default());
2047        let stream = exec.execute(0, ctx).unwrap();
2048        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2049
2050        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2051        assert_eq!(total, 3, "customer_id=99 has no match, filtered by inner");
2052
2053        let names = batches[0]
2054            .column(4)
2055            .as_any()
2056            .downcast_ref::<StringArray>()
2057            .unwrap();
2058        assert_eq!(names.value(0), "Alice");
2059        assert_eq!(names.value(1), "Bob");
2060        assert_eq!(names.value(2), "Charlie");
2061    }
2062
2063    #[tokio::test]
2064    async fn partial_left_outer_preserves_non_matches() {
2065        let exec = make_partial_exec(LookupJoinType::LeftOuter);
2066        let ctx = Arc::new(TaskContext::default());
2067        let stream = exec.execute(0, ctx).unwrap();
2068        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2069
2070        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2071        assert_eq!(total, 4, "all 4 stream rows preserved in left outer");
2072
2073        let names = batches[0]
2074            .column(4)
2075            .as_any()
2076            .downcast_ref::<StringArray>()
2077            .unwrap();
2078        assert!(names.is_null(2), "customer_id=99 should have null name");
2079    }
2080
2081    #[tokio::test]
2082    async fn partial_empty_cache_inner_produces_no_rows() {
2083        let cache = make_foyer_cache();
2084        let input = batch_exec(orders_batch());
2085        let key_sort_fields = vec![SortField::new(DataType::Int64)];
2086
2087        let exec = PartialLookupJoinExec::try_new(
2088            input,
2089            cache,
2090            vec![1],
2091            key_sort_fields,
2092            LookupJoinType::Inner,
2093            customers_schema(),
2094            output_schema(),
2095        )
2096        .unwrap();
2097
2098        let ctx = Arc::new(TaskContext::default());
2099        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2100        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2101        assert_eq!(total, 0);
2102    }
2103
2104    #[tokio::test]
2105    async fn partial_empty_cache_left_outer_preserves_all() {
2106        let cache = make_foyer_cache();
2107        let input = batch_exec(orders_batch());
2108        let key_sort_fields = vec![SortField::new(DataType::Int64)];
2109
2110        let exec = PartialLookupJoinExec::try_new(
2111            input,
2112            cache,
2113            vec![1],
2114            key_sort_fields,
2115            LookupJoinType::LeftOuter,
2116            customers_schema(),
2117            output_schema(),
2118        )
2119        .unwrap();
2120
2121        let ctx = Arc::new(TaskContext::default());
2122        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
2123        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2124        assert_eq!(total, 4);
2125    }
2126
2127    #[test]
2128    fn partial_with_new_children_preserves_state() {
2129        let exec = Arc::new(make_partial_exec(LookupJoinType::Inner));
2130        let expected_schema = exec.schema();
2131        let children = exec.children().into_iter().cloned().collect();
2132        let rebuilt = exec.with_new_children(children).unwrap();
2133        assert_eq!(rebuilt.schema(), expected_schema);
2134        assert_eq!(rebuilt.name(), "PartialLookupJoinExec");
2135    }
2136
2137    #[test]
2138    fn partial_display_format() {
2139        let exec = make_partial_exec(LookupJoinType::Inner);
2140        let s = format!("{exec:?}");
2141        assert!(s.contains("PartialLookupJoinExec"));
2142        assert!(s.contains("cache_table_id: 1"));
2143    }
2144
2145    #[test]
2146    fn registry_partial_entry() {
2147        let reg = LookupTableRegistry::new();
2148        let cache = make_foyer_cache();
2149        let key_sort_fields = vec![SortField::new(DataType::Int64)];
2150
2151        reg.register_partial(
2152            "customers",
2153            PartialLookupState {
2154                foyer_cache: cache,
2155                schema: customers_schema(),
2156                key_columns: vec!["id".into()],
2157                key_sort_fields,
2158                source: None,
2159                fetch_semaphore: Arc::new(Semaphore::new(64)),
2160            },
2161        );
2162
2163        assert!(reg.get("customers").is_none());
2164
2165        let entry = reg.get_entry("customers");
2166        assert!(entry.is_some());
2167        assert!(matches!(entry.unwrap(), RegisteredLookup::Partial(_)));
2168    }
2169
    #[tokio::test]
    async fn partial_source_fallback_on_miss() {
        use laminar_core::lookup::source::LookupError;
        use laminar_core::lookup::source::LookupSourceDyn;

        // Stub source that answers every key with a fixed (99, "FromSource") row.
        struct TestSource;

        #[async_trait]
        impl LookupSourceDyn for TestSource {
            async fn query_batch(
                &self,
                keys: &[&[u8]],
                _predicates: &[laminar_core::lookup::predicate::Predicate],
                _projection: &[laminar_core::lookup::source::ColumnId],
            ) -> std::result::Result<Vec<Option<RecordBatch>>, LookupError> {
                Ok(keys
                    .iter()
                    .map(|_| Some(customer_row(99, "FromSource")))
                    .collect())
            }

            fn schema(&self) -> SchemaRef {
                customers_schema()
            }
        }

        let cache = make_foyer_cache();
        // Cache holds ids 1..=3; the probed id=99 misses and falls back to the source.
        warm_cache(&cache);

        // Single order keyed on id=99, which is absent from the cache.
        let orders = RecordBatch::try_new(
            orders_schema(),
            vec![
                Arc::new(Int64Array::from(vec![200])),
                Arc::new(Int64Array::from(vec![99])), // not in cache
                Arc::new(Float64Array::from(vec![50.0])),
            ],
        )
        .unwrap();

        let input = batch_exec(orders);
        let key_sort_fields = vec![SortField::new(DataType::Int64)];
        let source: Arc<dyn LookupSourceDyn> = Arc::new(TestSource);

        let exec = PartialLookupJoinExec::try_new_with_source(
            input,
            cache,
            vec![1],
            key_sort_fields,
            LookupJoinType::Inner,
            customers_schema(),
            output_schema(),
            Some(source),
            Arc::new(Semaphore::new(64)),
        )
        .unwrap();

        let ctx = Arc::new(TaskContext::default());
        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total, 1, "source fallback should produce 1 row");

        // The joined name must come from the source, not the cache.
        let names = batches[0]
            .column(4)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "FromSource");
    }
2239
    #[tokio::test]
    async fn partial_source_error_graceful_degradation() {
        use laminar_core::lookup::source::LookupError;
        use laminar_core::lookup::source::LookupSourceDyn;

        // Source that always fails, simulating an unavailable backend.
        struct FailingSource;

        #[async_trait]
        impl LookupSourceDyn for FailingSource {
            async fn query_batch(
                &self,
                _keys: &[&[u8]],
                _predicates: &[laminar_core::lookup::predicate::Predicate],
                _projection: &[laminar_core::lookup::source::ColumnId],
            ) -> std::result::Result<Vec<Option<RecordBatch>>, LookupError> {
                Err(LookupError::Internal("source unavailable".into()))
            }

            fn schema(&self) -> SchemaRef {
                customers_schema()
            }
        }

        // Cold cache: every key misses and is routed to the failing source.
        let cache = make_foyer_cache();
        let input = batch_exec(orders_batch());
        let key_sort_fields = vec![SortField::new(DataType::Int64)];
        let source: Arc<dyn LookupSourceDyn> = Arc::new(FailingSource);

        let exec = PartialLookupJoinExec::try_new_with_source(
            input,
            cache,
            vec![1],
            key_sort_fields,
            LookupJoinType::LeftOuter,
            customers_schema(),
            output_schema(),
            Some(source),
            Arc::new(Semaphore::new(64)),
        )
        .unwrap();

        // The source error must not fail the query under left-outer semantics.
        let ctx = Arc::new(TaskContext::default());
        let batches: Vec<RecordBatch> = exec.execute(0, ctx).unwrap().try_collect().await.unwrap();
        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
        // All rows preserved in left outer, but all lookup columns null
        assert_eq!(total, 4);
    }
2287
2288    #[test]
2289    fn registry_snapshot_entry_via_get_entry() {
2290        let reg = LookupTableRegistry::new();
2291        reg.register(
2292            "t",
2293            LookupSnapshot {
2294                batch: customers_batch(),
2295                key_columns: vec!["id".into()],
2296            },
2297        );
2298
2299        let entry = reg.get_entry("t");
2300        assert!(matches!(entry.unwrap(), RegisteredLookup::Snapshot(_)));
2301        assert!(reg.get("t").is_some());
2302    }
2303
2304    // ── NULL key tests ────────────────────────────────────────────────
2305
2306    fn nullable_orders_schema() -> SchemaRef {
2307        Arc::new(Schema::new(vec![
2308            Field::new("order_id", DataType::Int64, false),
2309            Field::new("customer_id", DataType::Int64, true), // nullable key
2310            Field::new("amount", DataType::Float64, false),
2311        ]))
2312    }
2313
2314    fn nullable_output_schema(join_type: LookupJoinType) -> SchemaRef {
2315        let lookup_nullable = join_type == LookupJoinType::LeftOuter;
2316        Arc::new(Schema::new(vec![
2317            Field::new("order_id", DataType::Int64, false),
2318            Field::new("customer_id", DataType::Int64, true),
2319            Field::new("amount", DataType::Float64, false),
2320            Field::new("id", DataType::Int64, lookup_nullable),
2321            Field::new("name", DataType::Utf8, true),
2322        ]))
2323    }
2324
2325    #[tokio::test]
2326    async fn null_key_inner_join_no_match() {
2327        // Stream: customer_id = [1, NULL, 2]
2328        let stream_batch = RecordBatch::try_new(
2329            nullable_orders_schema(),
2330            vec![
2331                Arc::new(Int64Array::from(vec![100, 101, 102])),
2332                Arc::new(Int64Array::from(vec![Some(1), None, Some(2)])),
2333                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
2334            ],
2335        )
2336        .unwrap();
2337
2338        let input = batch_exec(stream_batch);
2339        let exec = LookupJoinExec::try_new(
2340            input,
2341            customers_batch(),
2342            vec![1],
2343            vec![0],
2344            LookupJoinType::Inner,
2345            nullable_output_schema(LookupJoinType::Inner),
2346        )
2347        .unwrap();
2348
2349        let ctx = Arc::new(TaskContext::default());
2350        let stream = exec.execute(0, ctx).unwrap();
2351        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2352
2353        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2354        // Only customer_id=1 and customer_id=2 match; NULL is skipped
2355        assert_eq!(total, 2, "NULL key row should not match in inner join");
2356    }
2357
2358    #[tokio::test]
2359    async fn null_key_left_outer_produces_nulls() {
2360        // Stream: customer_id = [1, NULL, 2]
2361        let stream_batch = RecordBatch::try_new(
2362            nullable_orders_schema(),
2363            vec![
2364                Arc::new(Int64Array::from(vec![100, 101, 102])),
2365                Arc::new(Int64Array::from(vec![Some(1), None, Some(2)])),
2366                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
2367            ],
2368        )
2369        .unwrap();
2370
2371        let input = batch_exec(stream_batch);
2372        let out_schema = nullable_output_schema(LookupJoinType::LeftOuter);
2373        let exec = LookupJoinExec::try_new(
2374            input,
2375            customers_batch(),
2376            vec![1],
2377            vec![0],
2378            LookupJoinType::LeftOuter,
2379            out_schema,
2380        )
2381        .unwrap();
2382
2383        let ctx = Arc::new(TaskContext::default());
2384        let stream = exec.execute(0, ctx).unwrap();
2385        let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
2386
2387        let total: usize = batches.iter().map(RecordBatch::num_rows).sum();
2388        // All 3 rows preserved; NULL key row has null lookup columns
2389        assert_eq!(total, 3, "all rows preserved in left outer");
2390
2391        let names = batches[0]
2392            .column(4)
2393            .as_any()
2394            .downcast_ref::<StringArray>()
2395            .unwrap();
2396        assert_eq!(names.value(0), "Alice");
2397        assert!(
2398            names.is_null(1),
2399            "NULL key row should have null lookup name"
2400        );
2401        assert_eq!(names.value(2), "Bob");
2402    }
2403
2404    // ── Versioned Lookup Join Tests ────────────────────────────────
2405
2406    fn versioned_table_batch() -> RecordBatch {
2407        // Table with key=currency, version_ts=valid_from, rate=value
2408        // Two currencies with multiple versions each
2409        let schema = Arc::new(Schema::new(vec![
2410            Field::new("currency", DataType::Utf8, false),
2411            Field::new("valid_from", DataType::Int64, false),
2412            Field::new("rate", DataType::Float64, false),
2413        ]));
2414        RecordBatch::try_new(
2415            schema,
2416            vec![
2417                Arc::new(StringArray::from(vec!["USD", "USD", "EUR", "EUR", "EUR"])),
2418                Arc::new(Int64Array::from(vec![100, 200, 100, 150, 300])),
2419                Arc::new(Float64Array::from(vec![1.0, 1.1, 0.85, 0.90, 0.88])),
2420            ],
2421        )
2422        .unwrap()
2423    }
2424
2425    fn stream_batch_with_time() -> RecordBatch {
2426        let schema = Arc::new(Schema::new(vec![
2427            Field::new("order_id", DataType::Int64, false),
2428            Field::new("currency", DataType::Utf8, false),
2429            Field::new("event_ts", DataType::Int64, false),
2430        ]));
2431        RecordBatch::try_new(
2432            schema,
2433            vec![
2434                Arc::new(Int64Array::from(vec![1, 2, 3, 4])),
2435                Arc::new(StringArray::from(vec!["USD", "EUR", "USD", "EUR"])),
2436                Arc::new(Int64Array::from(vec![150, 160, 250, 50])),
2437            ],
2438        )
2439        .unwrap()
2440    }
2441
2442    #[test]
2443    fn test_versioned_index_build_and_probe() {
2444        let batch = versioned_table_batch();
2445        let index = VersionedIndex::build(&batch, &[0], 1).unwrap();
2446
2447        // USD has versions at 100 and 200
2448        // Probe at 150 → should find version 100 (latest <= 150)
2449        let key_sf = vec![SortField::new(DataType::Utf8)];
2450        let converter = RowConverter::new(key_sf).unwrap();
2451        let usd_col = Arc::new(StringArray::from(vec!["USD"]));
2452        let usd_rows = converter.convert_columns(&[usd_col]).unwrap();
2453        let usd_key = usd_rows.row(0);
2454
2455        let result = index.probe_at_time(usd_key.as_ref(), 150);
2456        assert!(result.is_some());
2457        // Row 0 is USD@100, Row 1 is USD@200. At time 150, should get row 0.
2458        assert_eq!(result.unwrap(), 0);
2459
2460        // Probe at 250 → should find version 200 (row 1)
2461        let result = index.probe_at_time(usd_key.as_ref(), 250);
2462        assert_eq!(result.unwrap(), 1);
2463    }
2464
2465    #[test]
2466    fn test_versioned_index_no_version_before_ts() {
2467        let batch = versioned_table_batch();
2468        let index = VersionedIndex::build(&batch, &[0], 1).unwrap();
2469
2470        let key_sf = vec![SortField::new(DataType::Utf8)];
2471        let converter = RowConverter::new(key_sf).unwrap();
2472        let eur_col = Arc::new(StringArray::from(vec!["EUR"]));
2473        let eur_rows = converter.convert_columns(&[eur_col]).unwrap();
2474        let eur_key = eur_rows.row(0);
2475
2476        // EUR versions start at 100. Probe at 50 → None
2477        let result = index.probe_at_time(eur_key.as_ref(), 50);
2478        assert!(result.is_none());
2479    }
2480
2481    /// Helper to build a VersionedLookupJoinExec for tests.
2482    fn build_versioned_exec(
2483        table: RecordBatch,
2484        stream: &RecordBatch,
2485        join_type: LookupJoinType,
2486    ) -> VersionedLookupJoinExec {
2487        let input = batch_exec(stream.clone());
2488        let index = Arc::new(VersionedIndex::build(&table, &[0], 1).unwrap());
2489        let key_sort_fields = vec![SortField::new(DataType::Utf8)];
2490        let mut output_fields = stream.schema().fields().to_vec();
2491        output_fields.extend(table.schema().fields().iter().cloned());
2492        let output_schema = Arc::new(Schema::new(output_fields));
2493        VersionedLookupJoinExec::try_new(
2494            input,
2495            table,
2496            index,
2497            vec![1], // stream key col: currency
2498            2,       // stream time col: event_ts
2499            join_type,
2500            output_schema,
2501            key_sort_fields,
2502        )
2503        .unwrap()
2504    }
2505
2506    #[tokio::test]
2507    async fn test_versioned_join_exec_inner() {
2508        let table = versioned_table_batch();
2509        let stream = stream_batch_with_time();
2510        let exec = build_versioned_exec(table, &stream, LookupJoinType::Inner);
2511
2512        let ctx = Arc::new(TaskContext::default());
2513        let stream_out = exec.execute(0, ctx).unwrap();
2514        let batches: Vec<RecordBatch> = stream_out.try_collect().await.unwrap();
2515
2516        assert_eq!(batches.len(), 1);
2517        let batch = &batches[0];
2518        // Row 1: order_id=1, USD, ts=150 → USD@100 (rate=1.0)
2519        // Row 2: order_id=2, EUR, ts=160 → EUR@150 (rate=0.90)
2520        // Row 3: order_id=3, USD, ts=250 → USD@200 (rate=1.1)
2521        // Row 4: order_id=4, EUR, ts=50 → no EUR version <= 50 → SKIP (inner)
2522        assert_eq!(batch.num_rows(), 3);
2523
2524        let rates = batch
2525            .column(5) // rate is 6th column (3 stream + 3 table, rate is table col 2)
2526            .as_any()
2527            .downcast_ref::<Float64Array>()
2528            .unwrap();
2529        assert!((rates.value(0) - 1.0).abs() < f64::EPSILON); // USD@100
2530        assert!((rates.value(1) - 0.90).abs() < f64::EPSILON); // EUR@150
2531        assert!((rates.value(2) - 1.1).abs() < f64::EPSILON); // USD@200
2532    }
2533
2534    #[tokio::test]
2535    async fn test_versioned_join_exec_left_outer() {
2536        let table = versioned_table_batch();
2537        let stream = stream_batch_with_time();
2538        let exec = build_versioned_exec(table, &stream, LookupJoinType::LeftOuter);
2539
2540        let ctx = Arc::new(TaskContext::default());
2541        let stream_out = exec.execute(0, ctx).unwrap();
2542        let batches: Vec<RecordBatch> = stream_out.try_collect().await.unwrap();
2543
2544        assert_eq!(batches.len(), 1);
2545        let batch = &batches[0];
2546        // All 4 rows present (left outer)
2547        assert_eq!(batch.num_rows(), 4);
2548
2549        // Row 4 (EUR@50): no version → null rate
2550        let rates = batch
2551            .column(5)
2552            .as_any()
2553            .downcast_ref::<Float64Array>()
2554            .unwrap();
2555        assert!(rates.is_null(3), "EUR@50 should have null rate");
2556    }
2557
2558    #[test]
2559    fn test_versioned_index_empty_batch() {
2560        let schema = Arc::new(Schema::new(vec![
2561            Field::new("k", DataType::Utf8, false),
2562            Field::new("v", DataType::Int64, false),
2563        ]));
2564        let batch = RecordBatch::new_empty(schema);
2565        let index = VersionedIndex::build(&batch, &[0], 1).unwrap();
2566        assert!(index.map.is_empty());
2567    }
2568
2569    #[test]
2570    fn test_versioned_lookup_registry() {
2571        let registry = LookupTableRegistry::new();
2572        let table = versioned_table_batch();
2573        let index = Arc::new(VersionedIndex::build(&table, &[0], 1).unwrap());
2574
2575        registry.register_versioned(
2576            "rates",
2577            VersionedLookupState {
2578                batch: table,
2579                index,
2580                key_columns: vec!["currency".to_string()],
2581                version_column: "valid_from".to_string(),
2582                stream_time_column: "event_ts".to_string(),
2583            },
2584        );
2585
2586        let entry = registry.get_entry("rates");
2587        assert!(entry.is_some());
2588        assert!(matches!(entry.unwrap(), RegisteredLookup::Versioned(_)));
2589
2590        // get() should return None for versioned entries (snapshot-only)
2591        assert!(registry.get("rates").is_none());
2592    }
2593}