Skip to main content

laminar_connectors/lookup/
postgres_source.rs

1//! `PostgreSQL` lookup source implementation.
2//!
3//! Provides a `LookupSource` backed by `PostgreSQL` with connection pooling
4//! via `deadpool-postgres` and predicate pushdown support.
5//!
6//! ## Features
7//!
8//! - **Connection pooling**: Efficient connection reuse via `deadpool-postgres`
9//! - **Predicate pushdown**: Translates `Predicate` variants to SQL WHERE
10//!   clauses (except `NotEq`, which is always evaluated locally)
11//! - **Projection pushdown**: SELECT only requested columns
12//! - **Batch lookups**: `WHERE pk = ANY($1)` for single-column primary keys,
13//!   `WHERE (pk1, pk2) IN (...)` for composite keys
14//!
15//! ## Example
16//!
17//! ```rust,no_run
18//! use std::sync::Arc;
19//! use arrow_schema::{DataType, Field, Schema};
20//! use laminar_connectors::lookup::postgres_source::{
21//!     PostgresLookupSource, PostgresLookupSourceConfig,
22//! };
23//!
24//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
25//! let config = PostgresLookupSourceConfig {
26//!     connection_string: "host=localhost dbname=mydb user=app".into(),
27//!     table_name: "customers".into(),
28//!     primary_key_columns: vec!["id".into()],
29//!     ..Default::default()
30//! };
31//! let schema = Arc::new(Schema::new(vec![
32//!     Field::new("id", DataType::Int64, false),
33//!     Field::new("name", DataType::Utf8, true),
34//! ]));
35//! let source = PostgresLookupSource::new(config, schema)?;
36//! # Ok(())
37//! # }
38//! ```
39
40use std::sync::atomic::{AtomicU64, Ordering};
41use std::sync::Arc;
42
43use arrow_array::RecordBatch;
44use arrow_schema::SchemaRef;
45
46use laminar_core::lookup::predicate::{predicate_to_sql, Predicate};
47use laminar_core::lookup::source::{ColumnId, LookupError, LookupSource, LookupSourceCapabilities};
48
/// Configuration for the `PostgreSQL` lookup source.
///
/// Controls connection pooling, query timeouts, and table metadata.
/// Obtain defaults via [`Default`] and override the connection fields
/// (see the module-level example).
#[derive(Debug, Clone)]
pub struct PostgresLookupSourceConfig {
    /// `PostgreSQL` connection string.
    ///
    /// Accepts both key-value format (`host=localhost dbname=mydb`)
    /// and URI format (`postgresql://user:pass@host/db`).
    pub connection_string: String,

    /// Table name to query (may be schema-qualified, e.g. `"public.customers"`).
    ///
    /// NOTE(review): interpolated verbatim into the generated SQL (no
    /// identifier quoting) — must come from trusted configuration, not
    /// user input.
    pub table_name: String,

    /// Primary key column name(s).
    ///
    /// Used for the `WHERE pk = ANY($1)` clause in batch lookups.
    /// For composite keys, provide multiple column names.
    /// `query()` reads the first entry to match rows back to keys, so
    /// this should be non-empty for lookups to work.
    pub primary_key_columns: Vec<String>,

    /// Column names to include when no projection is specified.
    ///
    /// When `None`, uses `SELECT *`. When `Some(cols)`, uses those
    /// columns as the default select list.
    pub column_names: Option<Vec<String>>,

    /// Maximum connections in the pool (default: 10).
    pub max_pool_size: usize,

    /// Query timeout in seconds (default: 30).
    pub query_timeout_secs: u64,

    /// Maximum number of keys per batch lookup (default: 1000).
    /// Advertised through `capabilities()`; not enforced by `query()`.
    pub max_batch_size: usize,
}
84
85impl Default for PostgresLookupSourceConfig {
86    fn default() -> Self {
87        Self {
88            connection_string: String::new(),
89            table_name: String::new(),
90            primary_key_columns: Vec::new(),
91            column_names: None,
92            max_pool_size: 10,
93            query_timeout_secs: 30,
94            max_batch_size: 1000,
95        }
96    }
97}
98
/// `PostgreSQL` implementation of the `LookupSource` trait.
///
/// Provides full predicate and projection pushdown via parameterized
/// SQL queries. Uses `deadpool-postgres` for connection pooling.
///
/// # Pushdown Capabilities
///
/// All predicate types except `NotEq` are pushed down to `PostgreSQL`.
/// `NotEq` cannot use equality indexes and is always evaluated locally
/// (per project convention).
///
/// # Metrics
///
/// Three atomic counters (relaxed ordering) accumulate totals over the
/// source's lifetime; read them via `query_count()`, `row_count()`,
/// and `error_count()`.
pub struct PostgresLookupSource {
    /// Connection pool (`deadpool-postgres`).
    pool: deadpool_postgres::Pool,
    /// Configuration captured at construction time.
    config: PostgresLookupSourceConfig,
    /// Schema of the returned `RecordBatch` values.
    output_schema: SchemaRef,
    /// Total queries executed (for metrics).
    query_count: AtomicU64,
    /// Total rows returned (for metrics).
    row_count: AtomicU64,
    /// Total query errors (for metrics).
    error_count: AtomicU64,
}
123
124impl PostgresLookupSource {
125    /// Create a new `PostgreSQL` lookup source.
126    ///
127    /// `output_schema` describes the Arrow schema of each returned row.
128    /// Typically derived from the table DDL or introspected at startup.
129    ///
130    /// # Errors
131    ///
132    /// Returns [`LookupError::Connection`] if the connection string is
133    /// invalid or pool creation fails.
134    pub fn new(
135        config: PostgresLookupSourceConfig,
136        output_schema: SchemaRef,
137    ) -> Result<Self, LookupError> {
138        let pg_config: tokio_postgres::Config = config
139            .connection_string
140            .parse()
141            .map_err(|e| LookupError::Connection(format!("invalid connection string: {e}")))?;
142
143        let mgr_config = deadpool_postgres::ManagerConfig {
144            recycling_method: deadpool_postgres::RecyclingMethod::Fast,
145        };
146        let mgr =
147            deadpool_postgres::Manager::from_config(pg_config, tokio_postgres::NoTls, mgr_config);
148
149        let pool = deadpool_postgres::Pool::builder(mgr)
150            .max_size(config.max_pool_size)
151            .build()
152            .map_err(|e| LookupError::Connection(format!("pool creation failed: {e}")))?;
153
154        Ok(Self {
155            pool,
156            config,
157            output_schema,
158            query_count: AtomicU64::new(0),
159            row_count: AtomicU64::new(0),
160            error_count: AtomicU64::new(0),
161        })
162    }
163
164    /// Returns the total number of queries executed.
165    #[must_use]
166    pub fn query_count(&self) -> u64 {
167        self.query_count.load(Ordering::Relaxed)
168    }
169
170    /// Returns the total number of rows returned.
171    #[must_use]
172    pub fn row_count(&self) -> u64 {
173        self.row_count.load(Ordering::Relaxed)
174    }
175
176    /// Returns the total number of query errors.
177    #[must_use]
178    pub fn error_count(&self) -> u64 {
179        self.error_count.load(Ordering::Relaxed)
180    }
181}
182
183impl LookupSource for PostgresLookupSource {
184    async fn query(
185        &self,
186        keys: &[&[u8]],
187        predicates: &[Predicate],
188        projection: &[ColumnId],
189    ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
190        let client = self.pool.get().await.map_err(|e| {
191            self.error_count.fetch_add(1, Ordering::Relaxed);
192            LookupError::Connection(format!("pool get failed: {e}"))
193        })?;
194
195        let key_strings: Vec<Vec<String>> = keys
196            .iter()
197            .map(|k| vec![String::from_utf8_lossy(k).into_owned()])
198            .collect();
199
200        let (sql, params) = build_query(
201            &self.config.table_name,
202            &self.config.primary_key_columns,
203            &key_strings,
204            if predicates.is_empty() {
205                None
206            } else {
207                Some(predicates)
208            },
209            if projection.is_empty() {
210                self.config.column_names.as_deref()
211            } else {
212                None
213            },
214        );
215
216        let timeout = std::time::Duration::from_secs(self.config.query_timeout_secs);
217
218        let rows = tokio::time::timeout(timeout, async {
219            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
220                .iter()
221                .map(|s| s as &(dyn tokio_postgres::types::ToSql + Sync))
222                .collect();
223            client.query(&sql, &param_refs).await
224        })
225        .await
226        .map_err(|_| {
227            self.error_count.fetch_add(1, Ordering::Relaxed);
228            LookupError::Timeout(timeout)
229        })?
230        .map_err(|e| {
231            self.error_count.fetch_add(1, Ordering::Relaxed);
232            LookupError::Query(format!("query failed: {e}"))
233        })?;
234
235        self.query_count.fetch_add(1, Ordering::Relaxed);
236        self.row_count
237            .fetch_add(rows.len() as u64, Ordering::Relaxed);
238
239        let pk_col = &self.config.primary_key_columns[0];
240        let mut result: Vec<Option<RecordBatch>> = vec![None; keys.len()];
241
242        let mut key_index: std::collections::HashMap<&str, usize> =
243            std::collections::HashMap::with_capacity(keys.len());
244        for (i, ks) in key_strings.iter().enumerate() {
245            if let Some(first) = ks.first() {
246                key_index.entry(first.as_str()).or_insert(i);
247            }
248        }
249
250        for row in &rows {
251            let pk_val: Option<String> = row.try_get::<_, String>(pk_col.as_str()).ok();
252            if let Some(pk) = pk_val {
253                if let Some(&idx) = key_index.get(pk.as_str()) {
254                    // Build a single-row RecordBatch with all columns from
255                    // the output schema. Columns beyond the PK are filled
256                    // with nulls since tokio-postgres row → Arrow column
257                    // conversion is not wired yet.
258                    let mut cols: Vec<Arc<dyn arrow_array::Array>> =
259                        Vec::with_capacity(self.output_schema.fields().len());
260                    for field in self.output_schema.fields() {
261                        if field.name() == pk_col {
262                            cols.push(Arc::new(arrow_array::StringArray::from(vec![pk.as_str()])));
263                        } else {
264                            cols.push(arrow_array::new_null_array(field.data_type(), 1));
265                        }
266                    }
267                    if let Ok(batch) = RecordBatch::try_new(Arc::clone(&self.output_schema), cols) {
268                        result[idx] = Some(batch);
269                    }
270                }
271            }
272        }
273
274        Ok(result)
275    }
276
277    fn capabilities(&self) -> LookupSourceCapabilities {
278        LookupSourceCapabilities {
279            supports_predicate_pushdown: true,
280            supports_projection_pushdown: true,
281            supports_batch_lookup: true,
282            max_batch_size: self.config.max_batch_size,
283        }
284    }
285
286    #[allow(clippy::unnecessary_literal_bound)]
287    fn source_name(&self) -> &str {
288        "postgres"
289    }
290
291    fn schema(&self) -> SchemaRef {
292        Arc::clone(&self.output_schema)
293    }
294
295    async fn health_check(&self) -> Result<(), LookupError> {
296        let client =
297            self.pool.get().await.map_err(|e| {
298                LookupError::Connection(format!("health check pool get failed: {e}"))
299            })?;
300        client
301            .query_one("SELECT 1", &[])
302            .await
303            .map_err(|e| LookupError::Query(format!("health check failed: {e}")))?;
304        Ok(())
305    }
306}
307
308/// Build a SQL query from table name, primary keys, lookup keys, predicates,
309/// and an optional projection.
310///
311/// Returns `(sql_string, parameter_values)`. Parameter values are embedded
312/// inline for predicates (using [`predicate_to_sql`]), while key lookups
313/// use `ANY($1)` for single-column PKs or `(pk1, pk2) IN (...)` for
314/// composite keys.
315///
316/// # Arguments
317///
318/// * `table` - Table name (possibly schema-qualified)
319/// * `pk_columns` - Primary key column names
320/// * `keys` - Lookup key values as strings
321/// * `predicates` - Optional filter predicates (NOT `NotEq` — those are
322///   filtered out automatically)
323/// * `projection` - Optional column names to select (default: `*`)
324///
325/// # Returns
326///
327/// A tuple of (SQL string, parameter values for `$N` placeholders).
328#[must_use]
329pub fn build_query(
330    table: &str,
331    pk_columns: &[String],
332    keys: &[Vec<String>],
333    predicates: Option<&[Predicate]>,
334    projection: Option<&[String]>,
335) -> (String, Vec<String>) {
336    // SELECT clause
337    let select_clause = match projection {
338        Some(cols) if !cols.is_empty() => cols.join(", "),
339        _ => "*".to_string(),
340    };
341
342    // WHERE clause parts
343    let mut where_parts = Vec::new();
344    let mut params = Vec::new();
345
346    // Key lookup clause
347    if !keys.is_empty() && !pk_columns.is_empty() {
348        if pk_columns.len() == 1 {
349            // Single-column PK: WHERE pk = ANY($1)
350            params.push(format!(
351                "{{{}}}",
352                keys.iter()
353                    .map(|k| k.join(","))
354                    .collect::<Vec<_>>()
355                    .join(",")
356            ));
357            where_parts.push(format!("{} = ANY($1)", pk_columns[0]));
358        } else {
359            // Composite PK: WHERE (pk1, pk2) IN ((v1,v2), (v3,v4))
360            let pk_list = pk_columns.join(", ");
361            let value_tuples: Vec<String> = keys
362                .iter()
363                .map(|k| {
364                    let vals: Vec<String> = k
365                        .iter()
366                        .map(|v| format!("'{}'", v.replace('\'', "''")))
367                        .collect();
368                    format!("({})", vals.join(", "))
369                })
370                .collect();
371            where_parts.push(format!("({pk_list}) IN ({})", value_tuples.join(", ")));
372        }
373    }
374
375    // Predicate pushdown (filter out NotEq — always evaluated locally)
376    if let Some(preds) = predicates {
377        for pred in preds {
378            if matches!(pred, Predicate::NotEq { .. }) {
379                continue;
380            }
381            where_parts.push(predicate_to_sql(pred));
382        }
383    }
384
385    // Assemble query
386    let sql = if where_parts.is_empty() {
387        format!("SELECT {select_clause} FROM {table}")
388    } else {
389        format!(
390            "SELECT {select_clause} FROM {table} WHERE {}",
391            where_parts.join(" AND ")
392        )
393    };
394
395    (sql, params)
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use laminar_core::lookup::predicate::{Predicate, ScalarValue};

    /// Construct a real `PostgresLookupSource` without a running database.
    ///
    /// `deadpool` creates connections lazily (only on `pool.get()`), so
    /// `new()` succeeds as long as the connection string parses — no
    /// server is contacted here.
    fn test_source() -> PostgresLookupSource {
        let config = PostgresLookupSourceConfig {
            connection_string: "host=localhost".into(),
            table_name: "test".into(),
            primary_key_columns: vec!["id".into()],
            ..Default::default()
        };
        let schema = Arc::new(Schema::new(vec![Field::new(
            "id",
            DataType::Int64,
            false,
        )]));
        PostgresLookupSource::new(config, schema).expect("lazy pool creation must succeed")
    }

    #[test]
    fn test_build_query_single_pk() {
        let (sql, params) = build_query(
            "customers",
            &["id".into()],
            &[vec!["1".into()], vec!["2".into()], vec!["3".into()]],
            None,
            None,
        );
        assert_eq!(sql, "SELECT * FROM customers WHERE id = ANY($1)");
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], "{1,2,3}");
    }

    #[test]
    fn test_build_query_with_eq_predicate() {
        let (sql, _) = build_query(
            "customers",
            &["id".into()],
            &[vec!["42".into()]],
            Some(&[Predicate::Eq {
                column: "region".into(),
                value: ScalarValue::Utf8("APAC".into()),
            }]),
            None,
        );
        assert!(sql.contains("id = ANY($1)"));
        assert!(sql.contains("\"region\" = 'APAC'"));
        assert!(sql.contains(" AND "));
    }

    #[test]
    fn test_build_query_with_projection() {
        let (sql, _) = build_query(
            "customers",
            &["id".into()],
            &[vec!["1".into()]],
            None,
            Some(&["id".into(), "name".into(), "region".into()]),
        );
        assert!(sql.starts_with("SELECT id, name, region FROM"));
    }

    #[test]
    fn test_build_query_batch_keys() {
        let keys: Vec<Vec<String>> = (1..=5).map(|i| vec![i.to_string()]).collect();
        let (sql, params) = build_query("orders", &["order_id".into()], &keys, None, None);
        assert_eq!(sql, "SELECT * FROM orders WHERE order_id = ANY($1)");
        assert_eq!(params[0], "{1,2,3,4,5}");
    }

    #[test]
    fn test_capabilities_all_true() {
        // Pool creation is lazy, so we can exercise the real trait
        // method rather than hand-building the expected struct.
        let caps = test_source().capabilities();
        assert!(caps.supports_predicate_pushdown);
        assert!(caps.supports_projection_pushdown);
        assert!(caps.supports_batch_lookup);
        assert_eq!(caps.max_batch_size, 1000);
    }

    #[test]
    fn test_config_defaults() {
        let config = PostgresLookupSourceConfig::default();
        assert_eq!(config.max_pool_size, 10);
        assert_eq!(config.query_timeout_secs, 30);
        assert_eq!(config.max_batch_size, 1000);
        assert!(config.connection_string.is_empty());
        assert!(config.table_name.is_empty());
        assert!(config.primary_key_columns.is_empty());
        assert!(config.column_names.is_none());
    }

    #[test]
    fn test_not_eq_not_pushed_down() {
        let (sql, _) = build_query(
            "customers",
            &["id".into()],
            &[vec!["1".into()]],
            Some(&[
                Predicate::Eq {
                    column: "status".into(),
                    value: ScalarValue::Utf8("active".into()),
                },
                Predicate::NotEq {
                    column: "region".into(),
                    value: ScalarValue::Utf8("EU".into()),
                },
                Predicate::Gt {
                    column: "score".into(),
                    value: ScalarValue::Int64(100),
                },
            ]),
            None,
        );
        // NotEq should NOT appear in the SQL
        assert!(!sql.contains("!="));
        assert!(!sql.contains("region"));
        // Eq and Gt should be present (column names are double-quoted)
        assert!(sql.contains("\"status\" = 'active'"));
        assert!(sql.contains("\"score\" > 100"));
    }

    #[test]
    fn test_build_query_composite_pk() {
        let (sql, params) = build_query(
            "order_items",
            &["order_id".into(), "item_id".into()],
            &[
                vec!["100".into(), "1".into()],
                vec!["100".into(), "2".into()],
            ],
            None,
            None,
        );
        assert!(sql.contains("(order_id, item_id) IN"));
        assert!(sql.contains("('100', '1')"));
        assert!(sql.contains("('100', '2')"));
        // Composite PK does not use $1 params
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_query_all_pushable_predicate_types() {
        let predicates = vec![
            Predicate::Eq {
                column: "a".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::Lt {
                column: "b".into(),
                value: ScalarValue::Int64(10),
            },
            Predicate::LtEq {
                column: "c".into(),
                value: ScalarValue::Int64(20),
            },
            Predicate::Gt {
                column: "d".into(),
                value: ScalarValue::Int64(30),
            },
            Predicate::GtEq {
                column: "e".into(),
                value: ScalarValue::Int64(40),
            },
            Predicate::In {
                column: "f".into(),
                values: vec![ScalarValue::Utf8("x".into()), ScalarValue::Utf8("y".into())],
            },
            Predicate::IsNull { column: "g".into() },
            Predicate::IsNotNull { column: "h".into() },
        ];

        let (sql, _) = build_query("t", &[], &[], Some(&predicates), None);

        assert!(sql.contains("\"a\" = 1"), "got: {sql}");
        assert!(sql.contains("\"b\" < 10"), "got: {sql}");
        assert!(sql.contains("\"c\" <= 20"), "got: {sql}");
        assert!(sql.contains("\"d\" > 30"), "got: {sql}");
        assert!(sql.contains("\"e\" >= 40"), "got: {sql}");
        assert!(sql.contains("\"f\" IN ('x', 'y')"), "got: {sql}");
        assert!(sql.contains("\"g\" IS NULL"), "got: {sql}");
        assert!(sql.contains("\"h\" IS NOT NULL"), "got: {sql}");
    }

    #[test]
    fn test_build_query_no_keys_no_predicates() {
        let (sql, params) = build_query("t", &[], &[], None, None);
        assert_eq!(sql, "SELECT * FROM t");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_query_escapes_single_quotes_in_composite_pk() {
        let (sql, _) = build_query(
            "t",
            &["name".into(), "region".into()],
            &[vec!["O'Brien".into(), "EU".into()]],
            None,
            None,
        );
        // Single quote in "O'Brien" must be escaped to "O''Brien"
        assert!(sql.contains("'O''Brien'"));
    }

    #[test]
    fn test_source_name_is_postgres() {
        // Previously a tautological `assert_eq!("postgres", "postgres")`;
        // now exercises the real trait method on a lazily-pooled source.
        assert_eq!(test_source().source_name(), "postgres");
    }
}
609}