1use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5
6use arrow_array::RecordBatch;
7use arrow_schema::{DataType, SchemaRef, TimeUnit};
8
9use laminar_core::lookup::predicate::{predicate_to_sql, Predicate};
10use laminar_core::lookup::source::{ColumnId, LookupError, LookupSource, LookupSourceCapabilities};
11
12#[allow(clippy::unnecessary_wraps)] fn pg_row_to_arrow_arrays(
21 row: &tokio_postgres::Row,
22 schema: &arrow_schema::Schema,
23) -> Result<Vec<Arc<dyn arrow_array::Array>>, LookupError> {
24 let mut cols: Vec<Arc<dyn arrow_array::Array>> = Vec::with_capacity(schema.fields().len());
25
26 for field in schema.fields() {
27 let col_name = field.name().as_str();
28 let array: Arc<dyn arrow_array::Array> = match field.data_type() {
29 DataType::Boolean => {
30 let v: Option<bool> = row.try_get(col_name).ok().flatten();
31 Arc::new(arrow_array::BooleanArray::from(vec![v]))
32 }
33 DataType::Int16 => {
34 let v: Option<i16> = row.try_get(col_name).ok().flatten();
35 Arc::new(arrow_array::Int16Array::from(vec![v]))
36 }
37 DataType::Int32 => {
38 let v: Option<i32> = row.try_get(col_name).ok().flatten();
39 Arc::new(arrow_array::Int32Array::from(vec![v]))
40 }
41 DataType::Int64 => {
42 let v: Option<i64> = row.try_get(col_name).ok().flatten();
43 Arc::new(arrow_array::Int64Array::from(vec![v]))
44 }
45 DataType::Float32 => {
46 let v: Option<f32> = row.try_get(col_name).ok().flatten();
47 Arc::new(arrow_array::Float32Array::from(vec![v]))
48 }
49 DataType::Float64 => {
50 let v: Option<f64> = row.try_get(col_name).ok().flatten();
51 Arc::new(arrow_array::Float64Array::from(vec![v]))
52 }
53 DataType::Utf8 | DataType::LargeUtf8 => {
54 let v: Option<String> = row.try_get(col_name).ok().flatten();
55 Arc::new(arrow_array::StringArray::from(vec![v.as_deref()]))
56 }
57 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
58 let v: Option<chrono::NaiveDateTime> = row.try_get(col_name).ok().flatten();
59 let millis = v.map(|dt| dt.and_utc().timestamp_millis());
60 let arr = arrow_array::TimestampMillisecondArray::from(vec![millis]);
61 if let Some(tz) = tz {
62 Arc::new(arr.with_timezone(tz.clone()))
63 } else {
64 Arc::new(arr)
65 }
66 }
67 _ => arrow_array::new_null_array(field.data_type(), 1),
68 };
69 cols.push(array);
70 }
71
72 Ok(cols)
73}
74
/// Configuration for a [`PostgresLookupSource`].
#[derive(Debug, Clone)]
pub struct PostgresLookupSourceConfig {
    /// Connection string; parsed into a `tokio_postgres::Config` when the
    /// source is created.
    pub connection_string: String,

    /// Table to read from; interpolated verbatim into the `FROM` clause.
    pub table_name: String,

    /// Primary-key column(s) matched against the lookup keys. The first entry
    /// is also the column used to map returned rows back to requested keys.
    pub primary_key_columns: Vec<String>,

    /// Columns to select when a query is issued without a projection.
    /// `None` selects `*`.
    pub column_names: Option<Vec<String>>,

    /// Maximum size of the connection pool (default: 10).
    pub max_pool_size: usize,

    /// Per-query timeout in seconds (default: 30).
    pub query_timeout_secs: u64,

    /// Maximum batch size reported through the source's capabilities
    /// (default: 1000).
    pub max_batch_size: usize,
}
110
111impl Default for PostgresLookupSourceConfig {
112 fn default() -> Self {
113 Self {
114 connection_string: String::new(),
115 table_name: String::new(),
116 primary_key_columns: Vec::new(),
117 column_names: None,
118 max_pool_size: 10,
119 query_timeout_secs: 30,
120 max_batch_size: 1000,
121 }
122 }
123}
124
/// Lookup source that reads rows from a Postgres table through a connection pool.
pub struct PostgresLookupSource {
    /// Pool from which a client is taken for every query and health check.
    pool: deadpool_postgres::Pool,
    /// Settings this source was created with.
    config: PostgresLookupSourceConfig,
    /// Arrow schema of every record batch this source returns.
    output_schema: SchemaRef,
    /// Number of queries that completed successfully.
    query_count: AtomicU64,
    /// Total number of rows returned by successful queries.
    row_count: AtomicU64,
    /// Number of failed pool acquisitions, timeouts and query errors.
    error_count: AtomicU64,
}
149
150impl PostgresLookupSource {
151 pub fn new(
161 config: PostgresLookupSourceConfig,
162 output_schema: SchemaRef,
163 ) -> Result<Self, LookupError> {
164 let pg_config: tokio_postgres::Config = config
165 .connection_string
166 .parse()
167 .map_err(|e| LookupError::Connection(format!("invalid connection string: {e}")))?;
168
169 let mgr_config = deadpool_postgres::ManagerConfig {
170 recycling_method: deadpool_postgres::RecyclingMethod::Fast,
171 };
172 let mgr =
173 deadpool_postgres::Manager::from_config(pg_config, tokio_postgres::NoTls, mgr_config);
174
175 let pool = deadpool_postgres::Pool::builder(mgr)
176 .max_size(config.max_pool_size)
177 .build()
178 .map_err(|e| LookupError::Connection(format!("pool creation failed: {e}")))?;
179
180 Ok(Self {
181 pool,
182 config,
183 output_schema,
184 query_count: AtomicU64::new(0),
185 row_count: AtomicU64::new(0),
186 error_count: AtomicU64::new(0),
187 })
188 }
189
190 #[must_use]
192 pub fn query_count(&self) -> u64 {
193 self.query_count.load(Ordering::Relaxed)
194 }
195
196 #[must_use]
198 pub fn row_count(&self) -> u64 {
199 self.row_count.load(Ordering::Relaxed)
200 }
201
202 #[must_use]
204 pub fn error_count(&self) -> u64 {
205 self.error_count.load(Ordering::Relaxed)
206 }
207}
208
209impl LookupSource for PostgresLookupSource {
210 async fn query(
211 &self,
212 keys: &[&[u8]],
213 predicates: &[Predicate],
214 projection: &[ColumnId],
215 ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
216 let client = self.pool.get().await.map_err(|e| {
217 self.error_count.fetch_add(1, Ordering::Relaxed);
218 LookupError::Connection(format!("pool get failed: {e}"))
219 })?;
220
221 let key_strings: Vec<Vec<String>> = keys
222 .iter()
223 .map(|k| vec![String::from_utf8_lossy(k).into_owned()])
224 .collect();
225
226 let (sql, params) = build_query(
227 &self.config.table_name,
228 &self.config.primary_key_columns,
229 &key_strings,
230 if predicates.is_empty() {
231 None
232 } else {
233 Some(predicates)
234 },
235 if projection.is_empty() {
236 self.config.column_names.as_deref()
237 } else {
238 None
239 },
240 );
241
242 let timeout = std::time::Duration::from_secs(self.config.query_timeout_secs);
243
244 let rows = tokio::time::timeout(timeout, async {
245 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = params
246 .iter()
247 .map(|s| s as &(dyn tokio_postgres::types::ToSql + Sync))
248 .collect();
249 client.query(&sql, ¶m_refs).await
250 })
251 .await
252 .map_err(|_| {
253 self.error_count.fetch_add(1, Ordering::Relaxed);
254 LookupError::Timeout(timeout)
255 })?
256 .map_err(|e| {
257 self.error_count.fetch_add(1, Ordering::Relaxed);
258 LookupError::Query(format!("query failed: {e}"))
259 })?;
260
261 self.query_count.fetch_add(1, Ordering::Relaxed);
262 self.row_count
263 .fetch_add(rows.len() as u64, Ordering::Relaxed);
264
265 let pk_col = &self.config.primary_key_columns[0];
266 let mut result: Vec<Option<RecordBatch>> = vec![None; keys.len()];
267
268 let mut key_index: std::collections::HashMap<&str, usize> =
269 std::collections::HashMap::with_capacity(keys.len());
270 for (i, ks) in key_strings.iter().enumerate() {
271 if let Some(first) = ks.first() {
272 key_index.entry(first.as_str()).or_insert(i);
273 }
274 }
275
276 for row in &rows {
277 let pk_val: Option<String> = row.try_get::<_, String>(pk_col.as_str()).ok();
278 if let Some(pk) = pk_val {
279 if let Some(&idx) = key_index.get(pk.as_str()) {
280 let cols = pg_row_to_arrow_arrays(row, &self.output_schema)?;
281 if let Ok(batch) = RecordBatch::try_new(Arc::clone(&self.output_schema), cols) {
282 result[idx] = Some(batch);
283 }
284 }
285 }
286 }
287
288 Ok(result)
289 }
290
291 fn capabilities(&self) -> LookupSourceCapabilities {
292 LookupSourceCapabilities {
293 supports_predicate_pushdown: true,
294 supports_projection_pushdown: true,
295 supports_batch_lookup: true,
296 max_batch_size: self.config.max_batch_size,
297 }
298 }
299
300 #[allow(clippy::unnecessary_literal_bound)]
301 fn source_name(&self) -> &str {
302 "postgres"
303 }
304
305 fn schema(&self) -> SchemaRef {
306 Arc::clone(&self.output_schema)
307 }
308
309 async fn health_check(&self) -> Result<(), LookupError> {
310 let client =
311 self.pool.get().await.map_err(|e| {
312 LookupError::Connection(format!("health check pool get failed: {e}"))
313 })?;
314 client
315 .query_one("SELECT 1", &[])
316 .await
317 .map_err(|e| LookupError::Query(format!("health check failed: {e}")))?;
318 Ok(())
319 }
320}
321
322#[must_use]
343pub fn build_query(
344 table: &str,
345 pk_columns: &[String],
346 keys: &[Vec<String>],
347 predicates: Option<&[Predicate]>,
348 projection: Option<&[String]>,
349) -> (String, Vec<String>) {
350 let select_clause = match projection {
352 Some(cols) if !cols.is_empty() => cols.join(", "),
353 _ => "*".to_string(),
354 };
355
356 let mut where_parts = Vec::new();
358 let mut params = Vec::new();
359
360 if !keys.is_empty() && !pk_columns.is_empty() {
362 if pk_columns.len() == 1 {
363 params.push(format!(
365 "{{{}}}",
366 keys.iter()
367 .map(|k| k.join(","))
368 .collect::<Vec<_>>()
369 .join(",")
370 ));
371 where_parts.push(format!("{} = ANY($1)", pk_columns[0]));
372 } else {
373 let pk_list = pk_columns.join(", ");
375 let value_tuples: Vec<String> = keys
376 .iter()
377 .map(|k| {
378 let vals: Vec<String> = k
379 .iter()
380 .map(|v| format!("'{}'", v.replace('\'', "''")))
381 .collect();
382 format!("({})", vals.join(", "))
383 })
384 .collect();
385 where_parts.push(format!("({pk_list}) IN ({})", value_tuples.join(", ")));
386 }
387 }
388
389 if let Some(preds) = predicates {
391 for pred in preds {
392 if matches!(pred, Predicate::NotEq { .. }) {
393 continue;
394 }
395 where_parts.push(predicate_to_sql(pred));
396 }
397 }
398
399 let sql = if where_parts.is_empty() {
401 format!("SELECT {select_clause} FROM {table}")
402 } else {
403 format!(
404 "SELECT {select_clause} FROM {table} WHERE {}",
405 where_parts.join(" AND ")
406 )
407 };
408
409 (sql, params)
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use laminar_core::lookup::predicate::{Predicate, ScalarValue};

    /// Turns string literals into the owned name list `build_query` expects.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| (*s).to_string()).collect()
    }

    /// Builds one owned key tuple per inner slice.
    fn key_rows(rows: &[&[&str]]) -> Vec<Vec<String>> {
        rows.iter().map(|row| names(row)).collect()
    }

    /// A single primary-key column yields one `ANY($1)` array parameter.
    #[test]
    fn test_build_query_single_pk() {
        let pk = names(&["id"]);
        let keys = key_rows(&[&["1"], &["2"], &["3"]]);

        let (sql, params) = build_query("customers", &pk, &keys, None, None);

        assert_eq!(sql, "SELECT * FROM customers WHERE id = ANY($1)");
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], "{1,2,3}");
    }

    /// An equality predicate is ANDed onto the key filter.
    #[test]
    fn test_build_query_with_eq_predicate() {
        let pk = names(&["id"]);
        let keys = key_rows(&[&["42"]]);
        let region_is_apac = Predicate::Eq {
            column: "region".into(),
            value: ScalarValue::Utf8("APAC".into()),
        };

        let (sql, _) = build_query("customers", &pk, &keys, Some(&[region_is_apac]), None);

        assert!(sql.contains("id = ANY($1)"));
        assert!(sql.contains("\"region\" = 'APAC'"));
        assert!(sql.contains(" AND "));
    }

    /// An explicit projection replaces `*` in the select list.
    #[test]
    fn test_build_query_with_projection() {
        let pk = names(&["id"]);
        let keys = key_rows(&[&["1"]]);
        let columns = names(&["id", "name", "region"]);

        let (sql, _) = build_query("customers", &pk, &keys, None, Some(&columns));

        assert!(sql.starts_with("SELECT id, name, region FROM"));
    }

    /// Several keys for one column are packed into a single array literal.
    #[test]
    fn test_build_query_batch_keys() {
        let mut keys: Vec<Vec<String>> = Vec::new();
        for i in 1..=5 {
            keys.push(vec![i.to_string()]);
        }

        let (sql, params) = build_query("orders", &["order_id".into()], &keys, None, None);

        assert_eq!(sql, "SELECT * FROM orders WHERE order_id = ANY($1)");
        assert_eq!(params[0], "{1,2,3,4,5}");
    }

    /// Checks the capability values against a hand-built capabilities struct;
    /// no `PostgresLookupSource` is constructed here.
    #[test]
    fn test_capabilities_all_true() {
        let config = PostgresLookupSourceConfig {
            connection_string: "host=localhost".into(),
            table_name: "test".into(),
            primary_key_columns: vec!["id".into()],
            ..Default::default()
        };

        let caps = LookupSourceCapabilities {
            supports_predicate_pushdown: true,
            supports_projection_pushdown: true,
            supports_batch_lookup: true,
            max_batch_size: config.max_batch_size,
        };

        assert!(caps.supports_predicate_pushdown);
        assert!(caps.supports_projection_pushdown);
        assert!(caps.supports_batch_lookup);
        assert_eq!(caps.max_batch_size, 1000);
    }

    /// The default configuration carries the documented limits and is
    /// otherwise empty.
    #[test]
    fn test_config_defaults() {
        let defaults = PostgresLookupSourceConfig::default();

        // Tunables.
        assert_eq!(defaults.max_pool_size, 10);
        assert_eq!(defaults.query_timeout_secs, 30);
        assert_eq!(defaults.max_batch_size, 1000);

        // Everything identifying the target starts out empty.
        assert!(defaults.connection_string.is_empty());
        assert!(defaults.table_name.is_empty());
        assert!(defaults.primary_key_columns.is_empty());
        assert!(defaults.column_names.is_none());
    }

    /// `NotEq` predicates are left out of the SQL while their neighbours are kept.
    #[test]
    fn test_not_eq_not_pushed_down() {
        let predicates = [
            Predicate::Eq {
                column: "status".into(),
                value: ScalarValue::Utf8("active".into()),
            },
            Predicate::NotEq {
                column: "region".into(),
                value: ScalarValue::Utf8("EU".into()),
            },
            Predicate::Gt {
                column: "score".into(),
                value: ScalarValue::Int64(100),
            },
        ];
        let pk = names(&["id"]);
        let keys = key_rows(&[&["1"]]);

        let (sql, _) = build_query("customers", &pk, &keys, Some(&predicates), None);

        assert!(!sql.contains("!="));
        assert!(!sql.contains("region"));
        assert!(sql.contains("\"status\" = 'active'"));
        assert!(sql.contains("\"score\" > 100"));
    }

    /// Composite keys are inlined as value tuples and use no bind parameters.
    #[test]
    fn test_build_query_composite_pk() {
        let pk = names(&["order_id", "item_id"]);
        let keys = key_rows(&[&["100", "1"], &["100", "2"]]);

        let (sql, params) = build_query("order_items", &pk, &keys, None, None);

        assert!(sql.contains("(order_id, item_id) IN"));
        assert!(sql.contains("('100', '1')"));
        assert!(sql.contains("('100', '2')"));
        assert!(params.is_empty());
    }

    /// Every pushable predicate kind is rendered into the WHERE clause.
    #[test]
    fn test_build_query_all_pushable_predicate_types() {
        let predicates = vec![
            Predicate::Eq {
                column: "a".into(),
                value: ScalarValue::Int64(1),
            },
            Predicate::Lt {
                column: "b".into(),
                value: ScalarValue::Int64(10),
            },
            Predicate::LtEq {
                column: "c".into(),
                value: ScalarValue::Int64(20),
            },
            Predicate::Gt {
                column: "d".into(),
                value: ScalarValue::Int64(30),
            },
            Predicate::GtEq {
                column: "e".into(),
                value: ScalarValue::Int64(40),
            },
            Predicate::In {
                column: "f".into(),
                values: vec![ScalarValue::Utf8("x".into()), ScalarValue::Utf8("y".into())],
            },
            Predicate::IsNull { column: "g".into() },
            Predicate::IsNotNull { column: "h".into() },
        ];

        let (sql, _) = build_query("t", &[], &[], Some(&predicates), None);

        let expected_fragments = [
            "\"a\" = 1",
            "\"b\" < 10",
            "\"c\" <= 20",
            "\"d\" > 30",
            "\"e\" >= 40",
            "\"f\" IN ('x', 'y')",
            "\"g\" IS NULL",
            "\"h\" IS NOT NULL",
        ];
        for fragment in expected_fragments {
            assert!(sql.contains(fragment), "got: {sql}");
        }
    }

    /// With nothing to filter on, the statement is a bare select.
    #[test]
    fn test_build_query_no_keys_no_predicates() {
        let (sql, params) = build_query("t", &[], &[], None, None);

        assert_eq!(sql, "SELECT * FROM t");
        assert!(params.is_empty());
    }

    /// Single quotes inside composite key values are doubled.
    #[test]
    fn test_build_query_escapes_single_quotes_in_composite_pk() {
        let pk = names(&["name", "region"]);
        let keys = key_rows(&[&["O'Brien", "EU"]]);

        let (sql, _) = build_query("t", &pk, &keys, None, None);

        assert!(sql.contains("'O''Brien'"));
    }

    /// Placeholder check of the expected source name; it compares the literal
    /// with itself and does not construct a `PostgresLookupSource`.
    #[test]
    fn test_source_name_is_postgres() {
        assert_eq!("postgres", "postgres");
    }
}