// laminar_connectors/lookup/delta_lookup.rs
//! Delta Lake on-demand lookup source for cache-miss fallback.
//!
//! Implements `LookupSource` backed by a `DataFusion` `TableProvider`.
//! Used as the `source` field in `PartialLookupState` for tables too
//! large to snapshot into memory.

#[cfg(feature = "delta-lake")]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "delta-lake")]
use std::sync::Arc;

#[cfg(feature = "delta-lake")]
use arrow_array::RecordBatch;
#[cfg(feature = "delta-lake")]
use arrow_row::{RowConverter, SortField};
#[cfg(feature = "delta-lake")]
use arrow_schema::SchemaRef;
#[cfg(feature = "delta-lake")]
use datafusion::prelude::SessionContext;

#[cfg(feature = "delta-lake")]
use laminar_core::lookup::predicate::Predicate;
#[cfg(feature = "delta-lake")]
use laminar_core::lookup::source::{ColumnId, LookupError, LookupSource, LookupSourceCapabilities};
/// Configuration for [`DeltaLookupSource`].
#[cfg(feature = "delta-lake")]
#[derive(Debug, Clone)]
pub struct DeltaLookupSourceConfig {
    /// Table path (resolved, post-catalog).
    pub table_path: String,
    /// Storage options (credentials, etc.) forwarded to the object store.
    pub storage_options: std::collections::HashMap<String, String>,
    /// Primary key column names, in the order keys are row-encoded.
    /// Must be non-empty (validated in `DeltaLookupSource::open`).
    pub primary_key_columns: Vec<String>,
    /// `DataFusion` table name (registered in session context).
    pub table_name: String,
}

/// Delta Lake lookup source for on-demand/partial cache mode.
#[cfg(feature = "delta-lake")]
pub struct DeltaLookupSource {
    // DataFusion session with the Delta table registered under
    // `config.table_name`.
    ctx: Arc<SessionContext>,
    // Configuration captured at `open` time; not mutated afterwards.
    config: DeltaLookupSourceConfig,
    // Arrow schema of the registered table, cloned once in `open`.
    schema: SchemaRef,
    // One SortField per PK column (config order); defines the row encoding
    // expected of incoming key bytes in `query`.
    pk_sort_fields: Vec<SortField>,
    // Metrics (Relaxed ordering — simple counters, no synchronization role).
    // `query_count` counts batched `query` calls, not individual keys.
    query_count: AtomicU64,
    // Rows returned across all queries (at most one per key).
    row_count: AtomicU64,
    // Failed SQL executions / stream errors.
    error_count: AtomicU64,
}

#[cfg(feature = "delta-lake")]
impl DeltaLookupSource {
    /// Opens the Delta table and registers it as a `DataFusion` `TableProvider`.
    ///
    /// # Errors
    ///
    /// Returns `LookupError` if `primary_key_columns` is empty, the table
    /// cannot be opened or registered, or a PK column is absent from the
    /// table schema.
    pub async fn open(config: DeltaLookupSourceConfig) -> Result<Self, LookupError> {
        if config.primary_key_columns.is_empty() {
            return Err(LookupError::Internal(
                "primary_key_columns must not be empty".into(),
            ));
        }
        let ctx = SessionContext::new();

        crate::lakehouse::delta_table_provider::register_delta_table(
            &ctx,
            &config.table_name,
            &config.table_path,
            config.storage_options.clone(),
        )
        .await
        .map_err(|e| LookupError::Connection(format!("register delta table: {e}")))?;

        let table = ctx
            .table(&config.table_name)
            .await
            .map_err(|e| LookupError::Internal(format!("get table: {e}")))?;
        let schema: SchemaRef = Arc::new(table.schema().as_arrow().clone());

        // One SortField per PK column, in config order; this fixes the row
        // encoding that `query` expects for incoming key bytes.
        let pk_sort_fields: Vec<SortField> = config
            .primary_key_columns
            .iter()
            .map(|col_name| {
                let idx = schema.index_of(col_name).map_err(|_| {
                    LookupError::Internal(format!("pk column not found: {col_name}"))
                })?;
                Ok(SortField::new(schema.field(idx).data_type().clone()))
            })
            .collect::<Result<Vec<_>, LookupError>>()?;

        Ok(Self {
            ctx: Arc::new(ctx),
            config,
            schema,
            pk_sort_fields,
            query_count: AtomicU64::new(0),
            row_count: AtomicU64::new(0),
            error_count: AtomicU64::new(0),
        })
    }

    /// Quote a SQL identifier, doubling any embedded `"` so a column name
    /// containing a double quote cannot break out of the identifier.
    /// Defense in depth: names come from config (post-catalog), not end
    /// users, but unescaped interpolation was still a latent injection path.
    fn quote_ident(name: &str) -> String {
        format!("\"{}\"", name.replace('"', "\"\""))
    }

    /// Build a SQL WHERE clause from decoded PK column arrays.
    ///
    /// Each array holds exactly one value (row 0) for the corresponding PK
    /// column. Values are rendered as SQL literals by type class; string
    /// literals get single quotes doubled to keep the statement well-formed.
    fn build_where_clause(
        &self,
        pk_arrays: &[Arc<dyn arrow_array::Array>],
    ) -> Result<String, LookupError> {
        use arrow_cast::display::{ArrayFormatter, FormatOptions};
        use arrow_schema::DataType;

        let mut conditions = Vec::with_capacity(self.config.primary_key_columns.len());
        for (col_name, array) in self.config.primary_key_columns.iter().zip(pk_arrays) {
            let ident = Self::quote_ident(col_name);
            if array.is_null(0) {
                // `= NULL` never matches in SQL; null keys need IS NULL.
                conditions.push(format!("{ident} IS NULL"));
                continue;
            }
            let formatter = ArrayFormatter::try_new(array.as_ref(), &FormatOptions::default())
                .map_err(|e| LookupError::Internal(format!("format pk: {e}")))?;
            let value = formatter.value(0).to_string();
            match array.data_type() {
                // Numeric and boolean: unquoted literals.
                dt if dt.is_numeric() || matches!(dt, DataType::Boolean) => {
                    conditions.push(format!("{ident} = {value}"));
                }
                DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
                    let escaped = value.replace('\'', "''");
                    conditions.push(format!("{ident} = '{escaped}'"));
                }
                DataType::Date32 | DataType::Date64 | DataType::Timestamp(..) => {
                    conditions.push(format!("{ident} = '{value}'"));
                }
                dt => {
                    return Err(LookupError::Internal(format!(
                        "unsupported PK data type for lookup: {dt} (column \"{col_name}\")"
                    )));
                }
            }
        }
        Ok(conditions.join(" AND "))
    }

    /// Total `query` calls executed (one per key batch, not per key).
    #[must_use]
    pub fn query_count(&self) -> u64 {
        self.query_count.load(Ordering::Relaxed)
    }

    /// Total rows returned across all queries.
    #[must_use]
    pub fn row_count(&self) -> u64 {
        self.row_count.load(Ordering::Relaxed)
    }

    /// Total query errors.
    #[must_use]
    pub fn error_count(&self) -> u64 {
        self.error_count.load(Ordering::Relaxed)
    }
}

#[cfg(feature = "delta-lake")]
impl LookupSource for DeltaLookupSource {
    /// Point-lookup each row-encoded key via `SELECT * ... LIMIT 1`.
    ///
    /// `keys` must be encoded with a `RowConverter` whose fields match
    /// `pk_sort_fields` (PK columns in config order). `_predicates` and
    /// `_projection` are ignored — no pushdown (see `capabilities`).
    /// Returns one entry per input key: `Some(single-row batch)` on hit,
    /// `None` on miss.
    async fn query(
        &self,
        keys: &[&[u8]],
        _predicates: &[Predicate],
        _projection: &[ColumnId],
    ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
        use tokio_stream::StreamExt;

        let mut results = Vec::with_capacity(keys.len());
        // Decoder for the row-encoded key bytes; its layout must mirror the
        // encoder used by the caller.
        let converter = RowConverter::new(self.pk_sort_fields.clone())
            .map_err(|e| LookupError::Internal(format!("row converter: {e}")))?;
        let parser = converter.parser();

        for key_bytes in keys {
            // Decode one key into single-element arrays (one per PK column).
            let row = parser.parse(key_bytes);
            let pk_arrays = converter
                .convert_rows(std::iter::once(row))
                .map_err(|e| LookupError::Internal(format!("decode key: {e}")))?;
            let where_clause = self.build_where_clause(&pk_arrays)?;

            let sql = format!(
                "SELECT * FROM \"{}\" WHERE {} LIMIT 1",
                self.config.table_name, where_clause
            );

            let df = self.ctx.sql(&sql).await.map_err(|e| {
                self.error_count.fetch_add(1, Ordering::Relaxed);
                LookupError::Query(format!("delta lookup query failed: {e}"))
            })?;

            let mut stream = df.execute_stream().await.map_err(|e| {
                self.error_count.fetch_add(1, Ordering::Relaxed);
                LookupError::Query(format!("execute stream: {e}"))
            })?;

            // Scan the stream for the first non-empty batch and keep only its
            // first row; empty batches are skipped, a stream error aborts the
            // whole batched call.
            let mut found = None;
            while let Some(batch_result) = stream.next().await {
                match batch_result {
                    Ok(batch) if batch.num_rows() > 0 => {
                        self.row_count.fetch_add(1, Ordering::Relaxed);
                        found = Some(batch.slice(0, 1));
                        break;
                    }
                    Err(e) => {
                        self.error_count.fetch_add(1, Ordering::Relaxed);
                        return Err(LookupError::Query(format!("stream error: {e}")));
                    }
                    _ => {}
                }
            }

            results.push(found);
        }

        // Counts one per batched `query` call, not per individual key.
        self.query_count.fetch_add(1, Ordering::Relaxed);
        Ok(results)
    }

    /// No predicate/projection pushdown; keys are accepted in batches.
    /// NOTE(review): `max_batch_size: 0` — presumably means "unlimited";
    /// confirm against `LookupSourceCapabilities` docs.
    fn capabilities(&self) -> LookupSourceCapabilities {
        LookupSourceCapabilities {
            supports_predicate_pushdown: false,
            supports_projection_pushdown: false,
            supports_batch_lookup: true,
            max_batch_size: 0,
        }
    }

    /// Static identifier for logs/metrics.
    #[allow(clippy::unnecessary_literal_bound)]
    fn source_name(&self) -> &str {
        "delta-lake"
    }

    /// Arrow schema of the underlying table, captured at `open` time.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// Verifies the table is still resolvable in the session context.
    async fn health_check(&self) -> Result<(), LookupError> {
        self.ctx
            .table(&self.config.table_name)
            .await
            .map_err(|e| LookupError::Connection(format!("health check: {e}")))?;
        Ok(())
    }
}

#[cfg(all(test, feature = "delta-lake"))]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Two-column schema shared by every test table: non-null id, nullable name.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Build a record batch of (id, name) pairs against `test_schema`.
    fn test_batch(ids: &[i64], names: &[&str]) -> RecordBatch {
        let id_col: Arc<dyn arrow_array::Array> = Arc::new(Int64Array::from(ids.to_vec()));
        let name_col: Arc<dyn arrow_array::Array> = Arc::new(StringArray::from(names.to_vec()));
        RecordBatch::try_new(test_schema(), vec![id_col, name_col]).unwrap()
    }

    /// Create a Delta table at `path` and append `batches` to it.
    async fn create_delta_table(path: &str, batches: Vec<RecordBatch>) {
        use crate::lakehouse::delta_io;
        use deltalake::protocol::SaveMode;

        let schema = test_schema();
        let table = delta_io::open_or_create_table(path, HashMap::new(), Some(&schema))
            .await
            .unwrap();

        delta_io::write_batches(
            table,
            batches,
            "test-writer",
            1,
            SaveMode::Append,
            None,
            false,
            None,
            false,
            None,
        )
        .await
        .unwrap();
    }

    /// Row-encode a single Int64 key the way `DeltaLookupSource` expects.
    fn encode_i64_key(id: i64) -> Vec<u8> {
        let conv = RowConverter::new(vec![SortField::new(DataType::Int64)]).unwrap();
        let col = Arc::new(Int64Array::from(vec![id]));
        let rows = conv.convert_columns(&[col]).unwrap();
        rows.row(0).as_ref().to_vec()
    }

    #[tokio::test]
    async fn test_open_and_query() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().to_str().unwrap();
        create_delta_table(path, vec![test_batch(&[1, 2, 3], &["A", "B", "C"])]).await;

        let source = DeltaLookupSource::open(DeltaLookupSourceConfig {
            table_path: path.to_string(),
            storage_options: HashMap::new(),
            primary_key_columns: vec!["id".into()],
            table_name: "test_lookup".to_string(),
        })
        .await
        .unwrap();

        // Hit: key 2 exists and must come back as a single-row batch.
        let key = encode_i64_key(2);
        let results = source.query(&[key.as_slice()], &[], &[]).await.unwrap();
        assert_eq!(results.len(), 1);
        let batch = results[0].as_ref().unwrap();
        assert_eq!(batch.num_rows(), 1);

        let ids = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ids.value(0), 2);
        assert_eq!(source.query_count(), 1);
        assert_eq!(source.row_count(), 1);
    }

    #[tokio::test]
    async fn test_query_miss() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().to_str().unwrap();
        create_delta_table(path, vec![test_batch(&[1], &["A"])]).await;

        let source = DeltaLookupSource::open(DeltaLookupSourceConfig {
            table_path: path.to_string(),
            storage_options: HashMap::new(),
            primary_key_columns: vec!["id".into()],
            table_name: "test_miss".to_string(),
        })
        .await
        .unwrap();

        // Miss: key 999 is absent, so the slot is None and no rows counted.
        let key = encode_i64_key(999);
        let results = source.query(&[key.as_slice()], &[], &[]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].is_none());
        assert_eq!(source.row_count(), 0);
    }
}