Skip to main content

laminar_connectors/cdc/postgres/
config.rs

1//! `PostgreSQL` CDC source connector configuration.
2//!
3//! Provides [`PostgresCdcConfig`] with all settings needed to connect to
4//! a `PostgreSQL` database and stream logical replication changes.
5
6use std::time::Duration;
7
8use crate::config::ConnectorConfig;
9use crate::error::ConnectorError;
10
11use super::lsn::Lsn;
12
13/// Configuration for the `PostgreSQL` CDC source connector.
14#[derive(Debug, Clone)]
15pub struct PostgresCdcConfig {
16    // ── Connection ──
17    /// `PostgreSQL` host address.
18    pub host: String,
19
20    /// `PostgreSQL` port.
21    pub port: u16,
22
23    /// Database name.
24    pub database: String,
25
26    /// Username for authentication.
27    pub username: String,
28
29    /// Password for authentication.
30    pub password: Option<String>,
31
32    /// SSL mode for the connection.
33    pub ssl_mode: SslMode,
34
35    /// Path to CA certificate PEM file (for `VerifyCa` / `VerifyFull`).
36    pub ca_cert_path: Option<String>,
37
38    /// Path to client certificate PEM file (for mTLS).
39    pub client_cert_path: Option<String>,
40
41    /// Path to client private key PEM file (for mTLS).
42    pub client_key_path: Option<String>,
43
44    /// SNI hostname override (for proxy/load-balancer scenarios).
45    pub sni_hostname: Option<String>,
46
47    // ── Replication ──
48    /// Name of the logical replication slot.
49    pub slot_name: String,
50
51    /// Name of the publication to subscribe to.
52    pub publication: String,
53
54    /// LSN to start replication from (None = slot's `confirmed_flush_lsn`).
55    pub start_lsn: Option<Lsn>,
56
57    /// Output plugin name (always `pgoutput` for logical replication).
58    pub output_plugin: String,
59
60    // ── Snapshot ──
61    /// How to handle the initial data snapshot.
62    pub snapshot_mode: SnapshotMode,
63
64    // ── Tuning ──
65    /// Timeout for each poll operation.
66    pub poll_timeout: Duration,
67
68    /// Maximum records to return per poll.
69    pub max_poll_records: usize,
70
71    /// Interval for sending keepalive/status updates to `PostgreSQL`.
72    pub keepalive_interval: Duration,
73
74    /// Maximum WAL sender timeout before the server drops the connection.
75    pub wal_sender_timeout: Duration,
76
77    // ── Schema ──
78    /// Tables to include (empty = all tables in publication).
79    pub table_include: Vec<String>,
80
81    /// Tables to exclude from replication.
82    pub table_exclude: Vec<String>,
83}
84
85impl Default for PostgresCdcConfig {
86    fn default() -> Self {
87        Self {
88            host: "localhost".to_string(),
89            port: 5432,
90            database: "postgres".to_string(),
91            username: "postgres".to_string(),
92            password: None,
93            ssl_mode: SslMode::Prefer,
94            ca_cert_path: None,
95            client_cert_path: None,
96            client_key_path: None,
97            sni_hostname: None,
98            slot_name: "laminar_slot".to_string(),
99            publication: "laminar_pub".to_string(),
100            start_lsn: None,
101            output_plugin: "pgoutput".to_string(),
102            snapshot_mode: SnapshotMode::Initial,
103            poll_timeout: Duration::from_millis(100),
104            max_poll_records: 1000,
105            keepalive_interval: Duration::from_secs(10),
106            wal_sender_timeout: Duration::from_secs(60),
107            table_include: Vec::new(),
108            table_exclude: Vec::new(),
109        }
110    }
111}
112
113impl PostgresCdcConfig {
114    /// Creates a new config with required fields.
115    #[must_use]
116    pub fn new(host: &str, database: &str, slot_name: &str, publication: &str) -> Self {
117        Self {
118            host: host.to_string(),
119            database: database.to_string(),
120            slot_name: slot_name.to_string(),
121            publication: publication.to_string(),
122            ..Self::default()
123        }
124    }
125
126    /// Builds a `PostgreSQL` connection string.
127    #[must_use]
128    pub fn connection_string(&self) -> String {
129        use std::fmt::Write;
130        let mut s = format!(
131            "host={} port={} dbname={} user={}",
132            self.host, self.port, self.database, self.username
133        );
134        if let Some(ref pw) = self.password {
135            // Escape for libpq: wrap in single quotes, escape \ and '
136            let escaped = pw.replace('\\', "\\\\").replace('\'', "\\'");
137            let _ = write!(s, " password='{escaped}'");
138        }
139        let _ = write!(s, " sslmode={}", self.ssl_mode);
140        s
141    }
142
143    /// Parses configuration from a generic [`ConnectorConfig`].
144    ///
145    /// # Errors
146    ///
147    /// Returns `ConnectorError` if required keys are missing or values are
148    /// invalid.
149    pub fn from_config(config: &ConnectorConfig) -> Result<Self, ConnectorError> {
150        let mut cfg = Self {
151            host: config.require("host")?.to_string(),
152            database: config.require("database")?.to_string(),
153            slot_name: config.require("slot.name")?.to_string(),
154            publication: config.require("publication")?.to_string(),
155            ..Self::default()
156        };
157
158        if let Some(port) = config.get("port") {
159            cfg.port = crate::config::parse_port(port)?;
160        }
161        if let Some(user) = config.get("username") {
162            cfg.username = user.to_string();
163        }
164        cfg.password = config.get("password").map(String::from);
165
166        if let Some(ssl) = config.get_parsed::<SslMode>("ssl.mode")? {
167            cfg.ssl_mode = ssl;
168        }
169        cfg.ca_cert_path = config.get("ssl.ca.cert.path").map(String::from);
170        cfg.client_cert_path = config.get("ssl.client.cert.path").map(String::from);
171        cfg.client_key_path = config.get("ssl.client.key.path").map(String::from);
172        cfg.sni_hostname = config.get("ssl.sni.hostname").map(String::from);
173
174        if let Some(lsn) = config.get_parsed::<Lsn>("start.lsn")? {
175            cfg.start_lsn = Some(lsn);
176        }
177        if let Some(mode) = config.get_parsed::<SnapshotMode>("snapshot.mode")? {
178            cfg.snapshot_mode = mode;
179        }
180        if let Some(timeout) = config.get_parsed::<u64>("poll.timeout.ms")? {
181            cfg.poll_timeout = Duration::from_millis(timeout);
182        }
183        if let Some(max) = config.get_parsed::<usize>("max.poll.records")? {
184            cfg.max_poll_records = max;
185        }
186        if let Some(interval) = config.get_parsed::<u64>("keepalive.interval.ms")? {
187            cfg.keepalive_interval = Duration::from_millis(interval);
188        }
189        if let Some(tables) = config.get("table.include") {
190            cfg.table_include = tables.split(',').map(|s| s.trim().to_string()).collect();
191        }
192        if let Some(tables) = config.get("table.exclude") {
193            cfg.table_exclude = tables.split(',').map(|s| s.trim().to_string()).collect();
194        }
195
196        cfg.validate()?;
197        Ok(cfg)
198    }
199
200    /// Validates the configuration.
201    ///
202    /// # Errors
203    ///
204    /// Returns `ConnectorError::ConfigurationError` for invalid settings.
205    pub fn validate(&self) -> Result<(), ConnectorError> {
206        crate::config::require_non_empty(&self.host, "host")?;
207        crate::config::require_non_empty(&self.database, "database")?;
208        crate::config::require_non_empty(&self.slot_name, "slot.name")?;
209        crate::config::require_non_empty(&self.publication, "publication")?;
210        if self.max_poll_records == 0 {
211            return Err(ConnectorError::ConfigurationError(
212                "max.poll.records must be > 0".to_string(),
213            ));
214        }
215        // VerifyCa/VerifyFull require a CA certificate path
216        if matches!(self.ssl_mode, SslMode::VerifyCa | SslMode::VerifyFull)
217            && self.ca_cert_path.is_none()
218        {
219            return Err(ConnectorError::ConfigurationError(format!(
220                "ssl.mode={} requires ssl.ca.cert.path",
221                self.ssl_mode
222            )));
223        }
224        // Client cert without key (or vice versa) is invalid
225        if self.client_cert_path.is_some() != self.client_key_path.is_some() {
226            return Err(ConnectorError::ConfigurationError(
227                "ssl.client.cert.path and ssl.client.key.path must both be set for mTLS"
228                    .to_string(),
229            ));
230        }
231        Ok(())
232    }
233
234    /// Returns whether a table should be included based on include/exclude lists.
235    #[must_use]
236    pub fn should_include_table(&self, table: &str) -> bool {
237        if self.table_exclude.iter().any(|t| t == table) {
238            return false;
239        }
240        if self.table_include.is_empty() {
241            return true;
242        }
243        self.table_include.iter().any(|t| t == table)
244    }
245}
246
/// TLS mode used for the replication connection.
///
/// Its string form is what the connection string emits as `sslmode`.
#[derive(Default, Debug, Clone, Copy, Eq, PartialEq)]
pub enum SslMode {
    /// Never use TLS.
    Disable,
    /// Attempt TLS first and fall back to an unencrypted connection.
    /// This is the default.
    #[default]
    Prefer,
    /// Insist on TLS.
    Require,
    /// Insist on TLS and verify the server's certificate against a CA.
    VerifyCa,
    /// Insist on TLS, verifying both the CA certificate and the server hostname.
    VerifyFull,
}
262
// String conversions for `SslMode`, generated by the crate's `str_enum!` macro.
// The first string listed for each variant is the form produced by `Display`
// (e.g. `VerifyFull` renders as "verify-full"). Additional strings such as
// "verify_ca" appear to be alternate spellings accepted when parsing —
// NOTE(review): confirm this, and the exact effect of `lowercase_nodash`,
// against the macro definition.
str_enum!(SslMode, lowercase_nodash, String, "unknown SSL mode",
    Disable => "disable";
    Prefer => "prefer";
    Require => "require";
    VerifyCa => "verify-ca", "verify_ca";
    VerifyFull => "verify-full", "verify_full"
);
270
/// Policy for taking an initial data snapshot when the connector starts.
#[derive(Default, Debug, Clone, Copy, Eq, PartialEq)]
pub enum SnapshotMode {
    /// Snapshot once on first start (no prior checkpoint), then stream.
    /// This is the default.
    #[default]
    Initial,
    /// Skip the snapshot entirely and stream from the replication slot's position.
    Never,
    /// Snapshot on every startup, even when a checkpoint already exists.
    Always,
}
282
// String conversions for `SnapshotMode`, generated by the crate's `str_enum!`
// macro. Each variant maps to a single lowercase name; any other input fails
// to parse (presumably with the "unknown snapshot mode" message —
// NOTE(review): confirm against the macro definition).
str_enum!(SnapshotMode, lowercase_nodash, String, "unknown snapshot mode",
    Initial => "initial";
    Never => "never";
    Always => "always"
);
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a `ConnectorConfig` holding only the four required keys.
    fn base_config() -> ConnectorConfig {
        let mut config = ConnectorConfig::new("postgres-cdc");
        config.set("host", "localhost");
        config.set("database", "db");
        config.set("slot.name", "s");
        config.set("publication", "p");
        config
    }

    #[test]
    fn test_default_config() {
        let cfg = PostgresCdcConfig::default();
        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.database, "postgres");
        assert_eq!(cfg.slot_name, "laminar_slot");
        assert_eq!(cfg.publication, "laminar_pub");
        assert_eq!(cfg.output_plugin, "pgoutput");
        assert_eq!(cfg.ssl_mode, SslMode::Prefer);
        assert_eq!(cfg.snapshot_mode, SnapshotMode::Initial);
        assert_eq!(cfg.max_poll_records, 1000);
    }

    #[test]
    fn test_new_config() {
        let cfg = PostgresCdcConfig::new("db.example.com", "mydb", "my_slot", "my_pub");
        assert_eq!(cfg.host, "db.example.com");
        assert_eq!(cfg.database, "mydb");
        assert_eq!(cfg.slot_name, "my_slot");
        assert_eq!(cfg.publication, "my_pub");
    }

    #[test]
    fn test_connection_string() {
        let mut cfg = PostgresCdcConfig::new("db.example.com", "mydb", "s", "p");
        cfg.password = Some("secret".to_string());
        let conn = cfg.connection_string();
        assert!(conn.contains("host=db.example.com"));
        assert!(conn.contains("dbname=mydb"));
        assert!(conn.contains("password='secret'"));
        assert!(conn.contains("sslmode=prefer"));
    }

    #[test]
    fn test_connection_string_password_with_spaces() {
        let mut cfg = PostgresCdcConfig::new("h", "d", "s", "p");
        cfg.password = Some("my secret pass".to_string());
        let conn = cfg.connection_string();
        assert!(conn.contains("password='my secret pass'"));
    }

    #[test]
    fn test_connection_string_password_with_quotes() {
        let mut cfg = PostgresCdcConfig::new("h", "d", "s", "p");
        cfg.password = Some("it's a p@ss'word".to_string());
        let conn = cfg.connection_string();
        assert!(conn.contains(r"password='it\'s a p@ss\'word'"));
    }

    #[test]
    fn test_connection_string_password_with_backslash() {
        let mut cfg = PostgresCdcConfig::new("h", "d", "s", "p");
        cfg.password = Some(r"pass\word".to_string());
        let conn = cfg.connection_string();
        assert!(conn.contains(r"password='pass\\word'"));
    }

    #[test]
    fn test_from_connector_config() {
        let mut config = ConnectorConfig::new("postgres-cdc");
        config.set("host", "pg.local");
        config.set("database", "testdb");
        config.set("slot.name", "test_slot");
        config.set("publication", "test_pub");
        config.set("port", "5433");
        config.set("ssl.mode", "require");
        config.set("snapshot.mode", "never");
        config.set("max.poll.records", "500");

        let cfg = PostgresCdcConfig::from_config(&config).unwrap();
        assert_eq!(cfg.host, "pg.local");
        assert_eq!(cfg.port, 5433);
        assert_eq!(cfg.database, "testdb");
        assert_eq!(cfg.ssl_mode, SslMode::Require);
        assert_eq!(cfg.snapshot_mode, SnapshotMode::Never);
        assert_eq!(cfg.max_poll_records, 500);
    }

    #[test]
    fn test_from_config_missing_required() {
        let config = ConnectorConfig::new("postgres-cdc");
        assert!(PostgresCdcConfig::from_config(&config).is_err());
    }

    #[test]
    fn test_from_config_invalid_port() {
        let mut config = base_config();
        config.set("port", "not_a_number");
        assert!(PostgresCdcConfig::from_config(&config).is_err());
    }

    #[test]
    fn test_from_config_zero_max_poll_rejected() {
        let mut config = base_config();
        config.set("max.poll.records", "0");
        assert!(PostgresCdcConfig::from_config(&config).is_err());
    }

    #[test]
    fn test_from_config_credentials_and_timings() {
        let mut config = base_config();
        config.set("username", "replicator");
        config.set("password", "pw");
        config.set("poll.timeout.ms", "250");
        config.set("keepalive.interval.ms", "5000");

        let cfg = PostgresCdcConfig::from_config(&config).unwrap();
        assert_eq!(cfg.username, "replicator");
        assert_eq!(cfg.password.as_deref(), Some("pw"));
        assert_eq!(cfg.poll_timeout, Duration::from_millis(250));
        assert_eq!(cfg.keepalive_interval, Duration::from_millis(5000));
    }

    #[test]
    fn test_validate_empty_host() {
        let cfg = PostgresCdcConfig {
            host: String::new(),
            ..PostgresCdcConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_validate_zero_max_poll() {
        let cfg = PostgresCdcConfig {
            max_poll_records: 0,
            ..PostgresCdcConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_ssl_mode_fromstr() {
        assert_eq!("disable".parse::<SslMode>().unwrap(), SslMode::Disable);
        assert_eq!("prefer".parse::<SslMode>().unwrap(), SslMode::Prefer);
        assert_eq!("require".parse::<SslMode>().unwrap(), SslMode::Require);
        assert_eq!("verify-ca".parse::<SslMode>().unwrap(), SslMode::VerifyCa);
        assert_eq!(
            "verify-full".parse::<SslMode>().unwrap(),
            SslMode::VerifyFull
        );
        assert!("invalid".parse::<SslMode>().is_err());
    }

    #[test]
    fn test_snapshot_mode_fromstr() {
        assert_eq!(
            "initial".parse::<SnapshotMode>().unwrap(),
            SnapshotMode::Initial
        );
        assert_eq!(
            "never".parse::<SnapshotMode>().unwrap(),
            SnapshotMode::Never
        );
        assert_eq!(
            "always".parse::<SnapshotMode>().unwrap(),
            SnapshotMode::Always
        );
        assert!("bad".parse::<SnapshotMode>().is_err());
    }

    #[test]
    fn test_ssl_mode_display() {
        assert_eq!(SslMode::Disable.to_string(), "disable");
        assert_eq!(SslMode::VerifyFull.to_string(), "verify-full");
    }

    #[test]
    fn test_table_filtering() {
        let mut cfg = PostgresCdcConfig::default();
        // No filters → include all
        assert!(cfg.should_include_table("public.users"));

        // Include list
        cfg.table_include = vec!["public.users".to_string(), "public.orders".to_string()];
        assert!(cfg.should_include_table("public.users"));
        assert!(!cfg.should_include_table("public.logs"));

        // Exclude overrides include
        cfg.table_exclude = vec!["public.users".to_string()];
        assert!(!cfg.should_include_table("public.users"));
    }

    #[test]
    fn test_from_config_with_start_lsn() {
        let mut config = base_config();
        config.set("start.lsn", "0/1234ABCD");

        let cfg = PostgresCdcConfig::from_config(&config).unwrap();
        assert!(cfg.start_lsn.is_some());
        assert_eq!(cfg.start_lsn.unwrap().as_u64(), 0x1234_ABCD);
    }

    #[test]
    fn test_from_config_table_include() {
        let mut config = base_config();
        config.set("table.include", "public.users, public.orders");

        let cfg = PostgresCdcConfig::from_config(&config).unwrap();
        assert_eq!(cfg.table_include, vec!["public.users", "public.orders"]);
    }

    #[test]
    fn test_from_config_table_exclude() {
        let mut config = base_config();
        config.set("table.exclude", "public.audit , public.tmp");

        let cfg = PostgresCdcConfig::from_config(&config).unwrap();
        assert_eq!(cfg.table_exclude, vec!["public.audit", "public.tmp"]);
        assert!(!cfg.should_include_table("public.audit"));
        assert!(cfg.should_include_table("public.users"));
    }

    // ── TLS cert path fields ──

    #[test]
    fn test_default_tls_fields_are_none() {
        let cfg = PostgresCdcConfig::default();
        assert!(cfg.ca_cert_path.is_none());
        assert!(cfg.client_cert_path.is_none());
        assert!(cfg.client_key_path.is_none());
        assert!(cfg.sni_hostname.is_none());
    }

    #[test]
    fn test_from_config_tls_cert_paths() {
        let mut config = base_config();
        config.set("ssl.mode", "verify-full");
        config.set("ssl.ca.cert.path", "/certs/ca.pem");
        config.set("ssl.client.cert.path", "/certs/client.pem");
        config.set("ssl.client.key.path", "/certs/client-key.pem");
        config.set("ssl.sni.hostname", "db.example.com");

        let cfg = PostgresCdcConfig::from_config(&config).unwrap();
        assert_eq!(cfg.ssl_mode, SslMode::VerifyFull);
        assert_eq!(cfg.ca_cert_path.as_deref(), Some("/certs/ca.pem"));
        assert_eq!(cfg.client_cert_path.as_deref(), Some("/certs/client.pem"));
        assert_eq!(
            cfg.client_key_path.as_deref(),
            Some("/certs/client-key.pem")
        );
        assert_eq!(cfg.sni_hostname.as_deref(), Some("db.example.com"));
    }

    #[test]
    fn test_validate_verify_ca_requires_ca_path() {
        let cfg = PostgresCdcConfig {
            ssl_mode: SslMode::VerifyCa,
            ..PostgresCdcConfig::default()
        };
        let err = cfg.validate().unwrap_err();
        assert!(err.to_string().contains("ssl.ca.cert.path"));
    }

    #[test]
    fn test_validate_verify_full_requires_ca_path() {
        let cfg = PostgresCdcConfig {
            ssl_mode: SslMode::VerifyFull,
            ..PostgresCdcConfig::default()
        };
        let err = cfg.validate().unwrap_err();
        assert!(err.to_string().contains("ssl.ca.cert.path"));
    }

    #[test]
    fn test_validate_verify_ca_with_ca_path_ok() {
        let cfg = PostgresCdcConfig {
            ssl_mode: SslMode::VerifyCa,
            ca_cert_path: Some("/certs/ca.pem".to_string()),
            ..PostgresCdcConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_validate_client_cert_without_key() {
        let cfg = PostgresCdcConfig {
            client_cert_path: Some("/certs/client.pem".to_string()),
            ..PostgresCdcConfig::default()
        };
        let err = cfg.validate().unwrap_err();
        assert!(err.to_string().contains("mTLS"));
    }

    #[test]
    fn test_validate_client_key_without_cert() {
        let cfg = PostgresCdcConfig {
            client_key_path: Some("/certs/client-key.pem".to_string()),
            ..PostgresCdcConfig::default()
        };
        let err = cfg.validate().unwrap_err();
        assert!(err.to_string().contains("mTLS"));
    }

    #[test]
    fn test_validate_client_cert_and_key_ok() {
        let cfg = PostgresCdcConfig {
            client_cert_path: Some("/certs/client.pem".to_string()),
            client_key_path: Some("/certs/client-key.pem".to_string()),
            ..PostgresCdcConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_require_mode_no_ca_path_ok() {
        let cfg = PostgresCdcConfig {
            ssl_mode: SslMode::Require,
            ..PostgresCdcConfig::default()
        };
        assert!(cfg.validate().is_ok());
    }
}