use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;

use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;

use laminar_core::lookup::predicate::{predicate_to_sql, Predicate};
use laminar_core::lookup::source::{ColumnId, LookupError, LookupSource, LookupSourceCapabilities};

/// Connection and query settings for a PostgreSQL-backed lookup source.
#[derive(Debug, Clone)]
pub struct PostgresLookupSourceConfig {
    /// Connection string; parsed into a `tokio_postgres::Config` when the
    /// source is constructed.
    pub connection_string: String,

    /// Name of the table that lookups read from. It is interpolated into the
    /// generated SQL verbatim.
    pub table_name: String,

    /// Primary-key column name(s) used to build the key filter. Result rows
    /// are matched back to keys through the first entry.
    pub primary_key_columns: Vec<String>,

    /// Columns to select when the caller passes no projection. `None` selects
    /// every column (`*`).
    pub column_names: Option<Vec<String>>,

    /// Maximum number of connections held by the pool.
    pub max_pool_size: usize,

    /// Per-query timeout, in seconds.
    pub query_timeout_secs: u64,

    /// Largest batch size advertised through the source's capabilities.
    pub max_batch_size: usize,
}

impl Default for PostgresLookupSourceConfig {
    /// Returns a configuration with empty connection/table/key settings, no
    /// column list, a pool of 10 connections, a 30 second query timeout and a
    /// maximum batch size of 1000.
    fn default() -> Self {
        Self {
            connection_string: String::new(),
            table_name: String::new(),
            primary_key_columns: Vec::new(),
            column_names: None,
            max_pool_size: 10,
            query_timeout_secs: 30,
            max_batch_size: 1000,
        }
    }
}

99pub struct PostgresLookupSource {
110 pool: deadpool_postgres::Pool,
112 config: PostgresLookupSourceConfig,
114 output_schema: SchemaRef,
116 query_count: AtomicU64,
118 row_count: AtomicU64,
120 error_count: AtomicU64,
122}
123
124impl PostgresLookupSource {
125 pub fn new(
135 config: PostgresLookupSourceConfig,
136 output_schema: SchemaRef,
137 ) -> Result<Self, LookupError> {
138 let pg_config: tokio_postgres::Config = config
139 .connection_string
140 .parse()
141 .map_err(|e| LookupError::Connection(format!("invalid connection string: {e}")))?;
142
143 let mgr_config = deadpool_postgres::ManagerConfig {
144 recycling_method: deadpool_postgres::RecyclingMethod::Fast,
145 };
146 let mgr =
147 deadpool_postgres::Manager::from_config(pg_config, tokio_postgres::NoTls, mgr_config);
148
149 let pool = deadpool_postgres::Pool::builder(mgr)
150 .max_size(config.max_pool_size)
151 .build()
152 .map_err(|e| LookupError::Connection(format!("pool creation failed: {e}")))?;
153
154 Ok(Self {
155 pool,
156 config,
157 output_schema,
158 query_count: AtomicU64::new(0),
159 row_count: AtomicU64::new(0),
160 error_count: AtomicU64::new(0),
161 })
162 }
163
164 #[must_use]
166 pub fn query_count(&self) -> u64 {
167 self.query_count.load(Ordering::Relaxed)
168 }
169
170 #[must_use]
172 pub fn row_count(&self) -> u64 {
173 self.row_count.load(Ordering::Relaxed)
174 }
175
176 #[must_use]
178 pub fn error_count(&self) -> u64 {
179 self.error_count.load(Ordering::Relaxed)
180 }
181}
182
183impl LookupSource for PostgresLookupSource {
184 async fn query(
185 &self,
186 keys: &[&[u8]],
187 predicates: &[Predicate],
188 projection: &[ColumnId],
189 ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
190 let client = self.pool.get().await.map_err(|e| {
191 self.error_count.fetch_add(1, Ordering::Relaxed);
192 LookupError::Connection(format!("pool get failed: {e}"))
193 })?;
194
195 let key_strings: Vec<Vec<String>> = keys
196 .iter()
197 .map(|k| vec![String::from_utf8_lossy(k).into_owned()])
198 .collect();
199
200 let (sql, params) = build_query(
201 &self.config.table_name,
202 &self.config.primary_key_columns,
203 &key_strings,
204 if predicates.is_empty() {
205 None
206 } else {
207 Some(predicates)
208 },
209 if projection.is_empty() {
210 self.config.column_names.as_deref()
211 } else {
212 None
213 },
214 );
215
216 let timeout = std::time::Duration::from_secs(self.config.query_timeout_secs);
217
218 let rows = tokio::time::timeout(timeout, async {
219 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
220 .iter()
221 .map(|s| s as &(dyn tokio_postgres::types::ToSql + Sync))
222 .collect();
223 client.query(&sql, ¶m_refs).await
224 })
225 .await
226 .map_err(|_| {
227 self.error_count.fetch_add(1, Ordering::Relaxed);
228 LookupError::Timeout(timeout)
229 })?
230 .map_err(|e| {
231 self.error_count.fetch_add(1, Ordering::Relaxed);
232 LookupError::Query(format!("query failed: {e}"))
233 })?;
234
235 self.query_count.fetch_add(1, Ordering::Relaxed);
236 self.row_count
237 .fetch_add(rows.len() as u64, Ordering::Relaxed);
238
239 let pk_col = &self.config.primary_key_columns[0];
240 let mut result: Vec<Option<RecordBatch>> = vec![None; keys.len()];
241
242 let mut key_index: std::collections::HashMap<&str, usize> =
243 std::collections::HashMap::with_capacity(keys.len());
244 for (i, ks) in key_strings.iter().enumerate() {
245 if let Some(first) = ks.first() {
246 key_index.entry(first.as_str()).or_insert(i);
247 }
248 }
249
250 for row in &rows {
251 let pk_val: Option<String> = row.try_get::<_, String>(pk_col.as_str()).ok();
252 if let Some(pk) = pk_val {
253 if let Some(&idx) = key_index.get(pk.as_str()) {
254 let mut cols: Vec<Arc<dyn arrow_array::Array>> =
259 Vec::with_capacity(self.output_schema.fields().len());
260 for field in self.output_schema.fields() {
261 if field.name() == pk_col {
262 cols.push(Arc::new(arrow_array::StringArray::from(vec![pk.as_str()])));
263 } else {
264 cols.push(arrow_array::new_null_array(field.data_type(), 1));
265 }
266 }
267 if let Ok(batch) = RecordBatch::try_new(Arc::clone(&self.output_schema), cols) {
268 result[idx] = Some(batch);
269 }
270 }
271 }
272 }
273
274 Ok(result)
275 }
276
277 fn capabilities(&self) -> LookupSourceCapabilities {
278 LookupSourceCapabilities {
279 supports_predicate_pushdown: true,
280 supports_projection_pushdown: true,
281 supports_batch_lookup: true,
282 max_batch_size: self.config.max_batch_size,
283 }
284 }
285
286 #[allow(clippy::unnecessary_literal_bound)]
287 fn source_name(&self) -> &str {
288 "postgres"
289 }
290
291 fn schema(&self) -> SchemaRef {
292 Arc::clone(&self.output_schema)
293 }
294
295 async fn health_check(&self) -> Result<(), LookupError> {
296 let client =
297 self.pool.get().await.map_err(|e| {
298 LookupError::Connection(format!("health check pool get failed: {e}"))
299 })?;
300 client
301 .query_one("SELECT 1", &[])
302 .await
303 .map_err(|e| LookupError::Query(format!("health check failed: {e}")))?;
304 Ok(())
305 }
306}
307
308#[must_use]
329pub fn build_query(
330 table: &str,
331 pk_columns: &[String],
332 keys: &[Vec<String>],
333 predicates: Option<&[Predicate]>,
334 projection: Option<&[String]>,
335) -> (String, Vec<String>) {
336 let select_clause = match projection {
338 Some(cols) if !cols.is_empty() => cols.join(", "),
339 _ => "*".to_string(),
340 };
341
342 let mut where_parts = Vec::new();
344 let mut params = Vec::new();
345
346 if !keys.is_empty() && !pk_columns.is_empty() {
348 if pk_columns.len() == 1 {
349 params.push(format!(
351 "{{{}}}",
352 keys.iter()
353 .map(|k| k.join(","))
354 .collect::<Vec<_>>()
355 .join(",")
356 ));
357 where_parts.push(format!("{} = ANY($1)", pk_columns[0]));
358 } else {
359 let pk_list = pk_columns.join(", ");
361 let value_tuples: Vec<String> = keys
362 .iter()
363 .map(|k| {
364 let vals: Vec<String> = k
365 .iter()
366 .map(|v| format!("'{}'", v.replace('\'', "''")))
367 .collect();
368 format!("({})", vals.join(", "))
369 })
370 .collect();
371 where_parts.push(format!("({pk_list}) IN ({})", value_tuples.join(", ")));
372 }
373 }
374
375 if let Some(preds) = predicates {
377 for pred in preds {
378 if matches!(pred, Predicate::NotEq { .. }) {
379 continue;
380 }
381 where_parts.push(predicate_to_sql(pred));
382 }
383 }
384
385 let sql = if where_parts.is_empty() {
387 format!("SELECT {select_clause} FROM {table}")
388 } else {
389 format!(
390 "SELECT {select_clause} FROM {table} WHERE {}",
391 where_parts.join(" AND ")
392 )
393 };
394
395 (sql, params)
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use laminar_core::lookup::predicate::{Predicate, ScalarValue};

    /// A single key column yields `= ANY($1)` with one array-literal parameter.
    #[test]
    fn test_build_query_single_pk() {
        let (sql, params) = build_query(
            "customers",
            &["id".into()],
            &[vec!["1".into()], vec!["2".into()], vec!["3".into()]],
            None,
            None,
        );
        assert_eq!(sql, "SELECT * FROM customers WHERE id = ANY($1)");
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], "{1,2,3}");
    }

    /// An equality predicate is AND-ed onto the key filter.
    #[test]
    fn test_build_query_with_eq_predicate() {
        let (sql, _) = build_query(
            "customers",
            &["id".into()],
            &[vec!["42".into()]],
            Some(&[Predicate::Eq {
                column: "region".into(),
                value: ScalarValue::Utf8("APAC".into()),
            }]),
            None,
        );
        assert!(sql.contains("id = ANY($1)"));
        assert!(sql.contains("\"region\" = 'APAC'"));
        assert!(sql.contains(" AND "));
    }

    /// A projection replaces `*` in the select list.
    #[test]
    fn test_build_query_with_projection() {
        let (sql, _) = build_query(
            "customers",
            &["id".into()],
            &[vec!["1".into()]],
            None,
            Some(&["id".into(), "name".into(), "region".into()]),
        );
        assert!(sql.starts_with("SELECT id, name, region FROM"));
    }

    /// Several keys are folded into one array-literal parameter.
    #[test]
    fn test_build_query_batch_keys() {
        let keys: Vec<Vec<String>> = (1..=5).map(|i| vec![i.to_string()]).collect();
        let (sql, params) = build_query("orders", &["order_id".into()], &keys, None, None);
        assert_eq!(sql, "SELECT * FROM orders WHERE order_id = ANY($1)");
        assert_eq!(params[0], "{1,2,3,4,5}");
    }

    /// Mirrors the capability values the source reports; the struct is built
    /// by hand here rather than through a live `PostgresLookupSource`.
    #[test]
    fn test_capabilities_all_true() {
        let config = PostgresLookupSourceConfig {
            connection_string: "host=localhost".into(),
            table_name: "test".into(),
            primary_key_columns: vec!["id".into()],
            ..Default::default()
        };
        let caps = LookupSourceCapabilities {
            supports_predicate_pushdown: true,
            supports_projection_pushdown: true,
            supports_batch_lookup: true,
            max_batch_size: config.max_batch_size,
        };
        assert!(caps.supports_predicate_pushdown);
        assert!(caps.supports_projection_pushdown);
        assert!(caps.supports_batch_lookup);
        assert_eq!(caps.max_batch_size, 1000);
    }

    /// The default configuration carries the documented values.
    #[test]
    fn test_config_defaults() {
        let config = PostgresLookupSourceConfig::default();
        assert_eq!(config.max_pool_size, 10);
        assert_eq!(config.query_timeout_secs, 30);
        assert_eq!(config.max_batch_size, 1000);
        assert!(config.connection_string.is_empty());
        assert!(config.table_name.is_empty());
        assert!(config.primary_key_columns.is_empty());
        assert!(config.column_names.is_none());
    }

    /// `NotEq` predicates are dropped while the others are rendered.
    #[test]
    fn test_not_eq_not_pushed_down() {
        let (sql, _) = build_query(
            "customers",
            &["id".into()],
            &[vec!["1".into()]],
            Some(&[
                Predicate::Eq {
                    column: "status".into(),
                    value: ScalarValue::Utf8("active".into()),
                },
                Predicate::NotEq {
                    column: "region".into(),
                    value: ScalarValue::Utf8("EU".into()),
                },
                Predicate::Gt {
                    column: "score".into(),
                    value: ScalarValue::Int64(100),
                },
            ]),
            None,
        );
        assert!(!sql.contains("!="));
        assert!(!sql.contains("region"));
        assert!(sql.contains("\"status\" = 'active'"));
        assert!(sql.contains("\"score\" > 100"));
    }

    /// Several key columns produce a row-wise `IN` list with inlined values
    /// and no bind parameters.
    #[test]
    fn test_build_query_composite_pk() {
        let (sql, params) = build_query(
            "order_items",
            &["order_id".into(), "item_id".into()],
            &[
                vec!["100".into(), "1".into()],
                vec!["100".into(), "2".into()],
            ],
            None,
            None,
        );
        assert!(sql.contains("(order_id, item_id) IN"));
        assert!(sql.contains("('100', '1')"));
        assert!(sql.contains("('100', '2')"));
        assert!(params.is_empty());
    }

    /// Every predicate kind other than `NotEq` appears in the generated SQL.
    #[test]
    fn test_build_query_all_pushable_predicate_types() {
        let predicates = vec![
            Predicate::Eq {
                column: "a".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::Lt {
                column: "b".into(),
                value: ScalarValue::Int64(10),
            },
            Predicate::LtEq {
                column: "c".into(),
                value: ScalarValue::Int64(20),
            },
            Predicate::Gt {
                column: "d".into(),
                value: ScalarValue::Int64(30),
            },
            Predicate::GtEq {
                column: "e".into(),
                value: ScalarValue::Int64(40),
            },
            Predicate::In {
                column: "f".into(),
                values: vec![ScalarValue::Utf8("x".into()), ScalarValue::Utf8("y".into())],
            },
            Predicate::IsNull { column: "g".into() },
            Predicate::IsNotNull { column: "h".into() },
        ];

        let (sql, _) = build_query("t", &[], &[], Some(&predicates), None);

        assert!(sql.contains("\"a\" = 1"), "got: {sql}");
        assert!(sql.contains("\"b\" < 10"), "got: {sql}");
        assert!(sql.contains("\"c\" <= 20"), "got: {sql}");
        assert!(sql.contains("\"d\" > 30"), "got: {sql}");
        assert!(sql.contains("\"e\" >= 40"), "got: {sql}");
        assert!(sql.contains("\"f\" IN ('x', 'y')"), "got: {sql}");
        assert!(sql.contains("\"g\" IS NULL"), "got: {sql}");
        assert!(sql.contains("\"h\" IS NOT NULL"), "got: {sql}");
    }

    /// Without keys or predicates the statement has no `WHERE` clause.
    #[test]
    fn test_build_query_no_keys_no_predicates() {
        let (sql, params) = build_query("t", &[], &[], None, None);
        assert_eq!(sql, "SELECT * FROM t");
        assert!(params.is_empty());
    }

    /// Single quotes in inlined composite-key values are doubled.
    #[test]
    fn test_build_query_escapes_single_quotes_in_composite_pk() {
        let (sql, _) = build_query(
            "t",
            &["name".into(), "region".into()],
            &[vec!["O'Brien".into(), "EU".into()]],
            None,
            None,
        );
        assert!(sql.contains("'O''Brien'"));
    }

    /// Placeholder: compares the expected name literal with itself and does
    /// not construct a source.
    #[test]
    fn test_source_name_is_postgres() {
        assert_eq!("postgres", "postgres");
    }
}