1#[cfg(feature = "postgres-cdc")]
16use std::collections::HashMap;
17#[cfg(feature = "postgres-cdc")]
18use std::sync::Arc;
19
20#[cfg(feature = "postgres-cdc")]
21use arrow_array::{Array, RecordBatch};
22#[cfg(feature = "postgres-cdc")]
23use arrow_row::SortField;
24#[cfg(feature = "postgres-cdc")]
25use arrow_schema::{DataType, Field, Schema, SchemaRef};
26#[cfg(feature = "postgres-cdc")]
27use deadpool_postgres::Pool;
28#[cfg(feature = "postgres-cdc")]
29use tokio_postgres::types::{ToSql, Type};
30
31#[cfg(feature = "postgres-cdc")]
32use laminar_core::lookup::predicate::Predicate;
33#[cfg(feature = "postgres-cdc")]
34use laminar_core::lookup::source::{
35 projection_names, ColumnId, LookupError, LookupSource, LookupSourceCapabilities,
36};
37#[cfg(feature = "postgres-cdc")]
38use laminar_core::lookup::KeyAligner;
39
#[cfg(feature = "postgres-cdc")]
/// Configuration consumed by [`PostgresLookupSource::open`].
#[derive(Debug, Clone)]
pub struct PostgresLookupSourceConfig {
    /// Connection properties.
    ///
    /// Either a full `connection` / `connection_string` (URI or keyword/value
    /// form) or the discrete `host`, `port`, `database` (or `dbname`), `user`
    /// and `password` keys. TLS is controlled by `sslmode` / `ssl.mode`, with an
    /// optional CA bundle path in `sslrootcert` / `ssl.ca`.
    pub properties: HashMap<String, String>,
    /// Table to look rows up in. A dotted name such as `schema.table` is quoted
    /// part by part.
    pub table: String,
    /// Primary-key column names; `open` currently requires exactly one.
    pub primary_key_columns: Vec<String>,
    /// Connection-pool size; values below 1 are treated as 1.
    pub pool_size: usize,
}
54
#[cfg(feature = "postgres-cdc")]
/// Lookup source that fetches rows from a Postgres table by primary key,
/// issuing one `= ANY($1)` query per batch of keys over a pooled connection.
pub struct PostgresLookupSource {
    /// Connection pool shared by lookups and health checks.
    pool: Pool,
    /// Base lookup statement built by `open`; `$1` is bound to the array of
    /// primary-key values.
    select_sql: String,
    /// Table name as configured (possibly schema-qualified).
    table: String,
    /// The single primary-key column name.
    pk_column: String,
    /// Arrow schema derived from the table's columns when the source was opened.
    schema: SchemaRef,
    /// Decodes incoming row-encoded keys and aligns fetched rows back to them.
    aligner: KeyAligner,
}
67
68#[cfg(feature = "postgres-cdc")]
/// Quotes a possibly dot-qualified SQL identifier for Postgres.
///
/// Every `.`-separated part (e.g. the schema and table of `schema.table`) is
/// wrapped in double quotes on its own, and embedded double quotes are escaped
/// by doubling them. A name without dots therefore yields one quoted part.
fn quote_identifier(name: &str) -> String {
    let quoted_parts: Vec<String> = name
        .split('.')
        .map(|part| {
            let escaped = part.replace('"', "\"\"");
            format!("\"{escaped}\"")
        })
        .collect();
    quoted_parts.join(".")
}
79
#[cfg(feature = "postgres-cdc")]
impl PostgresLookupSource {
    /// Opens a lookup source over `config.table`.
    ///
    /// Builds the connection pool, then probes the table with
    /// `SELECT * ... LIMIT 0` to derive the Arrow schema. It also prepares the
    /// SQL text used by lookups.
    ///
    /// Columns that map to Arrow `Utf8` but cannot be decoded as a Rust `String`
    /// by tokio-postgres (e.g. `timestamp`, `numeric`, `uuid`) are cast to `text`
    /// server-side. Without the cast, their values would be read back as NULL.
    ///
    /// # Errors
    ///
    /// - [`LookupError::Internal`] if `primary_key_columns` does not contain
    ///   exactly one column, or if that column is not present in the table.
    /// - [`LookupError::Connection`] if the pool cannot be built, a connection
    ///   cannot be obtained, or the schema probe fails.
    /// - Any error returned by `KeyAligner::new`.
    pub async fn open(config: PostgresLookupSourceConfig) -> Result<Self, LookupError> {
        if config.primary_key_columns.len() != 1 {
            return Err(LookupError::Internal(format!(
                "postgres lookup requires exactly one primary key column, got {}",
                config.primary_key_columns.len()
            )));
        }
        let pk_column = config.primary_key_columns[0].clone();

        let pool = build_pool(&config.properties, config.pool_size)?;
        let table = quote_identifier(&config.table);

        let client = pool
            .get()
            .await
            .map_err(|e| LookupError::Connection(format!("postgres pool: {e}")))?;
        let stmt = client
            .prepare(&format!("SELECT * FROM {table} LIMIT 0"))
            .await
            .map_err(|e| LookupError::Connection(format!("prepare schema probe: {e}")))?;
        let fields: Vec<Field> = stmt
            .columns()
            .iter()
            .map(|c| Field::new(c.name(), pg_type_to_arrow(c.type_()), true))
            .collect();
        let schema: SchemaRef = Arc::new(Schema::new(fields));

        let pk_idx = schema.index_of(&pk_column).map_err(|_| {
            LookupError::Internal(format!("pk column not found in table: {pk_column}"))
        })?;
        let pk_sort_fields = vec![SortField::new(schema.field(pk_idx).data_type().clone())];
        let aligner = KeyAligner::new(pk_sort_fields, config.primary_key_columns)?;

        // Explicit select list in probe order (matching `schema`), so that
        // text-fallback columns can be cast server-side.
        let select_list = stmt
            .columns()
            .iter()
            .map(|c| Self::select_expr(c.name(), c.type_()))
            .collect::<Vec<_>>()
            .join(", ");
        let select_sql = format!(
            "SELECT {select_list} FROM {table} WHERE {} = ANY($1)",
            Self::quote_column(&pk_column)
        );

        Ok(Self {
            pool,
            select_sql,
            table: config.table,
            pk_column,
            schema,
            aligner,
        })
    }

    /// Quotes one column name as a single SQL identifier.
    ///
    /// Unlike `quote_identifier`, a `.` is treated as part of the name, not as
    /// a qualifier separator.
    fn quote_column(name: &str) -> String {
        format!("\"{}\"", name.replace('"', "\"\""))
    }

    /// Returns the select-list expression for one probed column.
    ///
    /// `rows_to_batch` reads every `Utf8` column as `Option<String>`, and
    /// tokio-postgres only decodes text-like types that way. Any other type that
    /// falls back to `Utf8` is therefore selected as `"col"::text AS "col"`.
    /// Columns that are natively mapped, or already text-like, are selected
    /// unchanged. This preserves, for example, `bpchar` padding.
    fn select_expr(name: &str, pg_type: &Type) -> String {
        let quoted = Self::quote_column(name);
        let decodes_as_is = pg_type_to_arrow(pg_type) != DataType::Utf8
            || <String as tokio_postgres::types::FromSql<'_>>::accepts(pg_type);
        if decodes_as_is {
            quoted
        } else {
            format!("{quoted}::text AS {quoted}")
        }
    }

    /// Builds the `$1` array parameter from the decoded primary-key column.
    ///
    /// Null keys are skipped. Integer, float and boolean columns become vectors
    /// of the matching Rust primitive. All string-like columns (`Utf8`,
    /// `LargeUtf8`, `Utf8View`) become `Vec<String>`.
    ///
    /// # Errors
    ///
    /// [`LookupError::Internal`] for any other Arrow data type, or if the array
    /// does not downcast to the type its `data_type()` reports.
    fn build_any_param(pk_array: &dyn Array) -> Result<Box<dyn ToSql + Sync + Send>, LookupError> {
        use arrow_array::{
            BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
            LargeStringArray, StringArray, StringViewArray,
        };

        fn downcast<T: 'static>(array: &dyn Array) -> Result<&T, LookupError> {
            array
                .as_any()
                .downcast_ref::<T>()
                .ok_or_else(|| LookupError::Internal("pk column downcast failed".into()))
        }
        fn non_null<A: Array, T>(a: &A, get: impl Fn(usize) -> T) -> Vec<T> {
            (0..a.len()).filter(|&i| !a.is_null(i)).map(get).collect()
        }

        let param: Box<dyn ToSql + Sync + Send> = match pk_array.data_type() {
            DataType::Int16 => {
                let a = downcast::<Int16Array>(pk_array)?;
                Box::new(non_null(a, |i| a.value(i)))
            }
            DataType::Int32 => {
                let a = downcast::<Int32Array>(pk_array)?;
                Box::new(non_null(a, |i| a.value(i)))
            }
            DataType::Int64 => {
                let a = downcast::<Int64Array>(pk_array)?;
                Box::new(non_null(a, |i| a.value(i)))
            }
            DataType::Float32 => {
                let a = downcast::<Float32Array>(pk_array)?;
                Box::new(non_null(a, |i| a.value(i)))
            }
            DataType::Float64 => {
                let a = downcast::<Float64Array>(pk_array)?;
                Box::new(non_null(a, |i| a.value(i)))
            }
            DataType::Boolean => {
                let a = downcast::<BooleanArray>(pk_array)?;
                Box::new(non_null(a, |i| a.value(i)))
            }
            DataType::Utf8 => {
                let a = downcast::<StringArray>(pk_array)?;
                Box::new(non_null(a, |i| a.value(i).to_string()))
            }
            DataType::LargeUtf8 => {
                let a = downcast::<LargeStringArray>(pk_array)?;
                Box::new(non_null(a, |i| a.value(i).to_string()))
            }
            DataType::Utf8View => {
                let a = downcast::<StringViewArray>(pk_array)?;
                Box::new(non_null(a, |i| a.value(i).to_string()))
            }
            dt => {
                return Err(LookupError::Internal(format!(
                    "unsupported PK data type for postgres lookup: {dt}"
                )));
            }
        };
        Ok(param)
    }
}
203
#[cfg(feature = "postgres-cdc")]
impl LookupSource for PostgresLookupSource {
    /// Looks up rows for a batch of row-encoded primary keys.
    ///
    /// All keys are fetched with a single `= ANY($1)` query. The result is then
    /// aligned back to `keys` by the key aligner, giving one entry per key.
    /// `_predicates` are not pushed down and are ignored.
    ///
    /// With a non-empty `projection`, only those columns are fetched, plus the
    /// primary key, which alignment needs. The primary key is dropped again
    /// afterwards if the caller did not ask for it.
    async fn query(
        &self,
        keys: &[&[u8]],
        _predicates: &[Predicate],
        projection: &[ColumnId],
    ) -> Result<Vec<Option<RecordBatch>>, LookupError> {
        /// Quotes one column name as a single identifier; a `.` is part of the
        /// name, not a qualifier.
        fn quote_column(name: &str) -> String {
            format!("\"{}\"", name.replace('"', "\"\""))
        }

        if keys.is_empty() {
            return Ok(Vec::new());
        }

        let pk_arrays = self.aligner.decode_keys(keys)?;
        let pk_array = pk_arrays.first().ok_or_else(|| {
            LookupError::Internal("key decoding produced no primary key column".into())
        })?;
        let param = Self::build_any_param(pk_array.as_ref())?;

        // Holds the caller's column list when the PK was added only for alignment
        // and must be projected away again afterwards.
        let mut project_back = None;
        let (sql, out_schema) = if projection.is_empty() {
            (self.select_sql.clone(), Arc::clone(&self.schema))
        } else {
            let mut fetch_names = projection_names(&self.schema, projection)?;
            let mut idx: Vec<usize> = projection.iter().map(|&c| c as usize).collect();

            if !fetch_names.contains(&self.pk_column) {
                project_back = Some(fetch_names.clone());
                fetch_names.push(self.pk_column.clone());
                let pk_idx = self
                    .schema
                    .index_of(&self.pk_column)
                    .map_err(|e| LookupError::Internal(format!("pk column index: {e}")))?;
                idx.push(pk_idx);
            }

            let cols = fetch_names
                .iter()
                .map(|n| quote_column(n))
                .collect::<Vec<_>>()
                .join(", ");
            // Project over the base statement so that the WHERE clause and any
            // per-column select expressions are defined in one place. Postgres
            // flattens this simple subquery, so the PK index is still used.
            let sql = format!("SELECT {cols} FROM ({}) AS lookup_base", self.select_sql);
            let proj_schema = Arc::new(
                self.schema
                    .project(&idx)
                    .map_err(|e| LookupError::Internal(format!("project postgres schema: {e}")))?,
            );
            (sql, proj_schema)
        };

        let client = self
            .pool
            .get()
            .await
            .map_err(|e| LookupError::Connection(format!("postgres pool: {e}")))?;
        let pg_rows = client.query(&sql, &[&*param]).await.map_err(|e| {
            LookupError::Query(format!("postgres lookup query on {}: {e}", self.table))
        })?;

        let batches = if pg_rows.is_empty() {
            Vec::new()
        } else {
            vec![rows_to_batch(&out_schema, &pg_rows)?]
        };
        let aligned = self.aligner.align(keys, &batches)?;

        let Some(requested) = project_back else {
            return Ok(aligned);
        };
        aligned
            .into_iter()
            .map(|maybe_batch| {
                maybe_batch
                    .map(|batch| {
                        let batch_schema = batch.schema();
                        let indices = requested
                            .iter()
                            .map(|name| {
                                batch_schema.index_of(name).map_err(|e| {
                                    LookupError::Internal(format!(
                                        "column not found in aligned schema: {e}"
                                    ))
                                })
                            })
                            .collect::<Result<Vec<usize>, LookupError>>()?;
                        batch.project(&indices).map_err(|e| {
                            LookupError::Internal(format!("project aligned batch: {e}"))
                        })
                    })
                    .transpose()
            })
            .collect()
    }

    /// Advertises batch lookups and projection pushdown; nothing else.
    fn capabilities(&self) -> LookupSourceCapabilities {
        LookupSourceCapabilities {
            supports_batch_lookup: true,
            supports_projection_pushdown: true,
            ..LookupSourceCapabilities::none()
        }
    }

    // The trait ties the returned `&str` to `&self`; a literal is always valid.
    #[allow(clippy::unnecessary_literal_bound)]
    fn source_name(&self) -> &str {
        "postgres"
    }

    /// Full table schema as probed when the source was opened.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// Checks out a pooled connection and runs `SELECT 1`.
    async fn health_check(&self) -> Result<(), LookupError> {
        let client = self
            .pool
            .get()
            .await
            .map_err(|e| LookupError::Connection(format!("health check pool: {e}")))?;
        client
            .query_one("SELECT 1", &[])
            .await
            .map(|_| ())
            .map_err(|e| LookupError::Connection(format!("health check: {e}")))
    }
}
333
334#[cfg(feature = "postgres-cdc")]
/// Extracts connection parameters from a libpq-style connection string.
///
/// Both libpq forms are accepted:
///
/// - **URI** (`postgres://…` / `postgresql://…`). Only the `?k=v&…` query part
///   is read. Keys and values are percent-decoded (`%XX`), and malformed escapes
///   are kept verbatim. Pairs without `=` are ignored.
/// - **keyword/value** (`host=db sslmode = require sslrootcert='/a b/ca.pem'`).
///   Whitespace around `=` is optional, and values may be single-quoted. A
///   backslash escapes the next character (`\'`, `\\`) both inside and outside
///   quotes.
///
/// Parsing is lenient and never fails. A keyword without `=` gets an empty
/// value. Parsing stops at a stray `=` with no keyword in front of it. Later
/// duplicates overwrite earlier ones.
fn parse_conn_string_params(conn: &str) -> HashMap<String, String> {
    /// Value of one ASCII hex digit.
    fn hex_digit(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    /// Decodes `%XX` escapes; invalid escapes pass through unchanged and
    /// non-UTF-8 results are decoded lossily.
    fn percent_decode(s: &str) -> String {
        let bytes = s.as_bytes();
        let mut out = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            if bytes[i] == b'%' && i + 2 < bytes.len() {
                if let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2])) {
                    out.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
            }
            out.push(bytes[i]);
            i += 1;
        }
        String::from_utf8(out).unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned())
    }

    /// Parses the query string of a `postgres://` / `postgresql://` URI.
    fn parse_uri_query(conn: &str) -> HashMap<String, String> {
        let Some((_, query)) = conn.split_once('?') else {
            return HashMap::new();
        };
        query
            .split('&')
            .filter_map(|pair| pair.split_once('='))
            .map(|(k, v)| (percent_decode(k), percent_decode(v)))
            .collect()
    }

    /// Parses the libpq keyword/value form.
    fn parse_key_value(conn: &str) -> HashMap<String, String> {
        let mut params = HashMap::new();
        let mut chars = conn.chars().peekable();
        loop {
            while chars.next_if(|c| c.is_whitespace()).is_some() {}

            let mut key = String::new();
            while let Some(c) = chars.next_if(|&c| c != '=' && !c.is_whitespace()) {
                key.push(c);
            }
            if key.is_empty() {
                // End of input, or a stray `=` with no keyword.
                break;
            }

            // libpq allows whitespace on either side of `=`.
            while chars.next_if(|c| c.is_whitespace()).is_some() {}
            let mut val = String::new();
            if chars.next_if_eq(&'=').is_some() {
                while chars.next_if(|c| c.is_whitespace()).is_some() {}
                if chars.next_if_eq(&'\'').is_some() {
                    while let Some(c) = chars.next() {
                        match c {
                            '\'' => break,
                            '\\' => {
                                if let Some(escaped) = chars.next() {
                                    val.push(escaped);
                                }
                            }
                            _ => val.push(c),
                        }
                    }
                } else {
                    while let Some(c) = chars.next_if(|c| !c.is_whitespace()) {
                        if c == '\\' {
                            if let Some(escaped) = chars.next() {
                                val.push(escaped);
                            }
                        } else {
                            val.push(c);
                        }
                    }
                }
            }
            params.insert(key, val);
        }
        params
    }

    if conn.starts_with("postgresql://") || conn.starts_with("postgres://") {
        parse_uri_query(conn)
    } else {
        parse_key_value(conn)
    }
}
392
#[cfg(feature = "postgres-cdc")]
/// Builds the deadpool connection pool from lookup-source properties.
///
/// Connection settings come from one of two places:
///
/// - a `connection` / `connection_string` property, parsed as a
///   `tokio_postgres::Config`; or
/// - the discrete `host`, `port`, `database`/`dbname`, `user` and `password`
///   keys.
///
/// TLS is decided from the properties merged with any parameters found in the
/// connection string; the connection string wins on conflicts. When TLS is
/// enabled, it is required on the wire rather than merely preferred.
///
/// # Errors
///
/// [`LookupError::Connection`] if the connection string does not parse, the
/// TLS settings are invalid, or the pool cannot be created.
#[allow(clippy::match_wildcard_for_single_variants)]
fn build_pool(props: &HashMap<String, String>, pool_size: usize) -> Result<Pool, LookupError> {
    let mut cfg = deadpool_postgres::Config::new();
    let mut merged_props = props.clone();

    if let Some(conn) = props
        .get("connection")
        .or_else(|| props.get("connection_string"))
    {
        merged_props.extend(parse_conn_string_params(conn));

        let pg: tokio_postgres::Config = conn
            .parse()
            .map_err(|e| LookupError::Connection(format!("parse connection string: {e}")))?;
        cfg.host = pg.get_hosts().iter().find_map(|h| match h {
            tokio_postgres::config::Host::Tcp(s) => Some(s.clone()),
            // A host path starting with `/` is handed back to tokio-postgres,
            // which treats it as a Unix-socket directory.
            #[cfg(unix)]
            tokio_postgres::config::Host::Unix(path) => path.to_str().map(str::to_string),
            #[allow(unreachable_patterns)]
            _ => None,
        });
        cfg.port = pg.get_ports().first().copied();
        cfg.dbname = pg.get_dbname().map(str::to_string);
        cfg.user = pg.get_user().map(str::to_string);
        cfg.password = pg
            .get_password()
            .map(|p| String::from_utf8_lossy(p).into_owned());
    } else {
        cfg.host = props.get("host").cloned();
        cfg.port = props.get("port").and_then(|p| p.parse().ok());
        cfg.dbname = props
            .get("database")
            .or_else(|| props.get("dbname"))
            .cloned();
        cfg.user = props.get("user").cloned();
        cfg.password = props.get("password").cloned();
    }

    cfg.pool = Some(deadpool_postgres::PoolConfig::new(pool_size.max(1)));
    let runtime = Some(deadpool_postgres::Runtime::Tokio1);

    if tls_enabled(&merged_props)? {
        // Without an explicit mode, tokio-postgres uses `prefer`. That silently
        // continues in plaintext if the server declines TLS, defeating
        // `require` / `verify-*`.
        cfg.ssl_mode = Some(deadpool_postgres::SslMode::Require);
        let connector = build_rustls_connector(&merged_props)?;
        cfg.create_pool(runtime, connector)
            .map_err(|e| LookupError::Connection(format!("create pool: {e}")))
    } else {
        cfg.create_pool(runtime, tokio_postgres::NoTls)
            .map_err(|e| LookupError::Connection(format!("create pool: {e}")))
    }
}
449
#[cfg(feature = "postgres-cdc")]
/// Decides whether connections must use TLS, based on `sslmode` (or `ssl.mode`).
///
/// The check is case-insensitive. A missing key or `disable` means plaintext;
/// `require`, `verify-ca` and `verify-full` mean TLS.
///
/// # Errors
///
/// [`LookupError::Connection`] for any other mode value.
fn tls_enabled(props: &HashMap<String, String>) -> Result<bool, LookupError> {
    let normalized = match props.get("sslmode").or_else(|| props.get("ssl.mode")) {
        None => return Ok(false),
        Some(raw) => raw.to_ascii_lowercase(),
    };

    if normalized == "disable" {
        return Ok(false);
    }
    if matches!(normalized.as_str(), "require" | "verify-ca" | "verify-full") {
        return Ok(true);
    }
    Err(LookupError::Connection(format!(
        "unsupported sslmode '{normalized}' (use disable/require/verify-ca/verify-full)"
    )))
}
467
#[cfg(feature = "postgres-cdc")]
/// Builds the rustls-based connector used when TLS is enabled.
///
/// Trust anchors come from the PEM file named by `sslrootcert` (or `ssl.ca`)
/// when one is given, and otherwise from the bundled `webpki_roots` set. No
/// client certificate is presented.
///
/// # Errors
///
/// [`LookupError::Connection`] in any of these cases:
///
/// - the CA file cannot be read or parsed;
/// - the file contains no certificates;
/// - a certificate is rejected by the root store.
fn build_rustls_connector(
    props: &HashMap<String, String>,
) -> Result<tokio_postgres_rustls::MakeRustlsConnect, LookupError> {
    use tokio_rustls::rustls::{ClientConfig, RootCertStore};

    // Best effort: installing only fails when a process-wide default provider
    // already exists, and that provider is then used instead.
    let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default();

    let mut roots = RootCertStore::empty();
    match props.get("sslrootcert").or_else(|| props.get("ssl.ca")) {
        None => roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()),
        Some(path) => {
            let pem = std::fs::read(path)
                .map_err(|e| LookupError::Connection(format!("read sslrootcert '{path}': {e}")))?;
            let certs = rustls_pemfile::certs(&mut std::io::Cursor::new(pem))
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| LookupError::Connection(format!("parse sslrootcert: {e}")))?;
            if certs.is_empty() {
                return Err(LookupError::Connection(
                    "sslrootcert contained no certificates".into(),
                ));
            }
            for cert in certs {
                roots
                    .add(cert)
                    .map_err(|e| LookupError::Connection(format!("add CA cert: {e}")))?;
            }
        }
    }

    let tls_config = ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();
    Ok(tokio_postgres_rustls::MakeRustlsConnect::new(tls_config))
}
506
#[cfg(feature = "postgres-cdc")]
/// Converts Postgres result rows into one Arrow [`RecordBatch`] shaped by `schema`.
///
/// Columns are read by field name.
///
/// - `Boolean`, `Int16/32/64` and `Float32/64` fields are decoded with their
///   native Rust types; a decode failure is an error.
/// - Every other field is read as `Option<String>`. A value that cannot be
///   decoded that way becomes NULL instead of failing the batch.
///
/// # Errors
///
/// [`LookupError::Internal`] if a native column fails to decode, or if the
/// arrays do not fit `schema`.
fn rows_to_batch(
    schema: &SchemaRef,
    rows: &[tokio_postgres::Row],
) -> Result<RecordBatch, LookupError> {
    use arrow_array::{
        BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, StringArray,
    };

    let columns = schema
        .fields()
        .iter()
        .map(|field| -> Result<Arc<dyn Array>, LookupError> {
            let col = field.name().as_str();
            let array: Arc<dyn Array> = match field.data_type() {
                DataType::Boolean => Arc::new(BooleanArray::from(collect_col::<bool>(rows, col)?)),
                DataType::Int16 => Arc::new(Int16Array::from(collect_col::<i16>(rows, col)?)),
                DataType::Int32 => Arc::new(Int32Array::from(collect_col::<i32>(rows, col)?)),
                DataType::Int64 => Arc::new(Int64Array::from(collect_col::<i64>(rows, col)?)),
                DataType::Float32 => Arc::new(Float32Array::from(collect_col::<f32>(rows, col)?)),
                DataType::Float64 => Arc::new(Float64Array::from(collect_col::<f64>(rows, col)?)),
                _ => {
                    // Best effort: a value that does not decode as text becomes NULL.
                    let text: Vec<Option<String>> = rows
                        .iter()
                        .map(|row| row.try_get::<_, Option<String>>(col).ok().flatten())
                        .collect();
                    Arc::new(StringArray::from(text))
                }
            };
            Ok(array)
        })
        .collect::<Result<Vec<_>, LookupError>>()?;

    RecordBatch::try_new(Arc::clone(schema), columns)
        .map_err(|e| LookupError::Internal(format!("arrow batch construction: {e}")))
}
546
#[cfg(feature = "postgres-cdc")]
/// Reads column `name` from every row as a nullable `T`, preserving row order.
///
/// # Errors
///
/// [`LookupError::Internal`] naming the column on the first row whose value
/// cannot be read as `Option<T>`.
fn collect_col<'a, T>(
    rows: &'a [tokio_postgres::Row],
    name: &str,
) -> Result<Vec<Option<T>>, LookupError>
where
    T: tokio_postgres::types::FromSql<'a>,
{
    let mut values = Vec::with_capacity(rows.len());
    for row in rows {
        match row.try_get::<_, Option<T>>(name) {
            Ok(value) => values.push(value),
            Err(err) => return Err(LookupError::Internal(format!("column '{name}': {err}"))),
        }
    }
    Ok(values)
}
563
#[cfg(feature = "postgres-cdc")]
/// Maps a Postgres column type to the Arrow type used for it in lookup batches.
///
/// Only `bool`, `int2/4/8` and `float4/8` have native mappings. Every other
/// type is represented as `Utf8`.
fn pg_type_to_arrow(pg_type: &Type) -> DataType {
    let native = match *pg_type {
        Type::BOOL => Some(DataType::Boolean),
        Type::INT2 => Some(DataType::Int16),
        Type::INT4 => Some(DataType::Int32),
        Type::INT8 => Some(DataType::Int64),
        Type::FLOAT4 => Some(DataType::Float32),
        Type::FLOAT8 => Some(DataType::Float64),
        _ => None,
    };
    native.unwrap_or(DataType::Utf8)
}
578
#[cfg(all(test, feature = "postgres-cdc"))]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};

    #[test]
    fn pg_type_map_native_and_text_fallback() {
        assert_eq!(pg_type_to_arrow(&Type::INT8), DataType::Int64);
        assert_eq!(pg_type_to_arrow(&Type::FLOAT8), DataType::Float64);
        assert_eq!(pg_type_to_arrow(&Type::BOOL), DataType::Boolean);
        assert_eq!(pg_type_to_arrow(&Type::TIMESTAMP), DataType::Utf8);
        assert_eq!(pg_type_to_arrow(&Type::NUMERIC), DataType::Utf8);
        assert_eq!(pg_type_to_arrow(&Type::UUID), DataType::Utf8);
    }

    #[test]
    fn any_param_built_for_supported_types_skipping_nulls() {
        assert!(
            PostgresLookupSource::build_any_param(&Int64Array::from(vec![
                Some(1i64),
                None,
                Some(3)
            ]))
            .is_ok()
        );
        assert!(PostgresLookupSource::build_any_param(&StringArray::from(vec!["a", "b"])).is_ok());
    }

    #[test]
    fn any_param_rejects_unsupported_type() {
        assert!(
            PostgresLookupSource::build_any_param(&arrow_array::Date32Array::from(vec![1]))
                .is_err()
        );
    }

    /// Builds a property map from string pairs.
    fn props(kv: &[(&str, &str)]) -> HashMap<String, String> {
        kv.iter().map(|(k, v)| ((*k).into(), (*v).into())).collect()
    }

    #[test]
    fn tls_mode_parsing() {
        assert!(!tls_enabled(&HashMap::new()).unwrap());
        assert!(!tls_enabled(&props(&[("sslmode", "disable")])).unwrap());
        assert!(tls_enabled(&props(&[("sslmode", "require")])).unwrap());
        assert!(tls_enabled(&props(&[("ssl.mode", "verify-full")])).unwrap());
        assert!(tls_enabled(&props(&[("sslmode", "bogus")])).is_err());
    }

    #[test]
    fn tls_connector_builds_with_roots_and_rejects_bad_ca() {
        assert!(build_rustls_connector(&HashMap::new()).is_ok());
        assert!(build_rustls_connector(&props(&[("sslrootcert", "/no/such/ca.pem")])).is_err());
    }

    #[test]
    fn quote_identifier_quotes_each_part_and_escapes_quotes() {
        assert_eq!(quote_identifier("users"), "\"users\"");
        assert_eq!(quote_identifier("public.users"), "\"public\".\"users\"");
        assert_eq!(quote_identifier("we\"ird"), "\"we\"\"ird\"");
    }

    #[test]
    fn conn_string_params_from_keyword_value_and_uri() {
        let kv = parse_conn_string_params("host=db sslmode=require sslrootcert='/etc/ssl/ca.pem'");
        assert_eq!(kv.get("host").map(String::as_str), Some("db"));
        assert_eq!(kv.get("sslmode").map(String::as_str), Some("require"));
        assert_eq!(kv.get("sslrootcert").map(String::as_str), Some("/etc/ssl/ca.pem"));

        let uri = parse_conn_string_params(
            "postgres://u@db/app?sslmode=verify-full&sslrootcert=%2Fetc%2Fca.pem",
        );
        assert_eq!(uri.get("sslmode").map(String::as_str), Some("verify-full"));
        assert_eq!(uri.get("sslrootcert").map(String::as_str), Some("/etc/ca.pem"));
    }
}