1use super::lsn::Lsn;
4use crate::error::ConnectorError;
5
/// A server-to-client message decoded from the replication stream by
/// [`parse_replication_message`].
///
/// Only the two message kinds that parser recognises are represented: the
/// `'w'` tag maps to [`ReplicationMessage::XLogData`] and the `'k'` tag maps
/// to [`ReplicationMessage::PrimaryKeepalive`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplicationMessage {
    /// A `'w'` message: a fixed 25-byte header followed by a payload.
    XLogData {
        /// Decoded from bytes `1..9` of the message (big-endian `u64`).
        wal_start: Lsn,
        /// Decoded from bytes `9..17` of the message (big-endian `u64`).
        wal_end: Lsn,
        /// Big-endian `i64` from bytes `17..25`. The `_us` suffix suggests
        /// microseconds (the PostgreSQL protocol docs describe it as the
        /// server clock, in microseconds since 2000-01-01) — confirm.
        server_time_us: i64,
        /// Every byte after the 25-byte header, copied verbatim (may be empty).
        data: Vec<u8>,
    },

    /// A `'k'` message: a keepalive sent by the server.
    PrimaryKeepalive {
        /// Decoded from bytes `1..9` of the message (big-endian `u64`).
        wal_end: Lsn,
        /// Big-endian `i64` from bytes `9..17`; presumably the same clock
        /// representation as `XLogData::server_time_us` — confirm.
        server_time_us: i64,
        /// `true` when byte 17 is non-zero. Per the PostgreSQL protocol docs
        /// the server then expects a prompt standby status update.
        reply_requested: bool,
    },
}
43
44#[allow(clippy::missing_panics_doc)] pub fn parse_replication_message(data: &[u8]) -> Result<ReplicationMessage, ConnectorError> {
60 if data.is_empty() {
61 return Err(ConnectorError::ReadError(
62 "empty replication message".to_string(),
63 ));
64 }
65
66 match data[0] {
67 b'w' => {
68 const HEADER_LEN: usize = 1 + 8 + 8 + 8; if data.len() < HEADER_LEN {
71 return Err(ConnectorError::ReadError(format!(
72 "truncated XLogData: {} bytes (need at least {HEADER_LEN})",
73 data.len()
74 )));
75 }
76
77 let wal_start = Lsn::new(u64::from_be_bytes(data[1..9].try_into().unwrap()));
79 let wal_end = Lsn::new(u64::from_be_bytes(data[9..17].try_into().unwrap()));
80 let server_time_us = i64::from_be_bytes(data[17..25].try_into().unwrap());
81 let payload = data[HEADER_LEN..].to_vec();
82
83 Ok(ReplicationMessage::XLogData {
84 wal_start,
85 wal_end,
86 server_time_us,
87 data: payload,
88 })
89 }
90 b'k' => {
91 const KEEPALIVE_LEN: usize = 1 + 8 + 8 + 1; if data.len() < KEEPALIVE_LEN {
94 return Err(ConnectorError::ReadError(format!(
95 "truncated PrimaryKeepalive: {} bytes (need {KEEPALIVE_LEN})",
96 data.len()
97 )));
98 }
99
100 let wal_end = Lsn::new(u64::from_be_bytes(data[1..9].try_into().unwrap()));
102 let server_time_us = i64::from_be_bytes(data[9..17].try_into().unwrap());
103 let reply_requested = data[17] != 0;
104
105 Ok(ReplicationMessage::PrimaryKeepalive {
106 wal_end,
107 server_time_us,
108 reply_requested,
109 })
110 }
111 tag => Err(ConnectorError::ReadError(format!(
112 "unknown replication message tag: 0x{tag:02X}"
113 ))),
114 }
115}
116
117#[must_use]
132pub fn encode_standby_status(write_lsn: Lsn, flush_lsn: Lsn, apply_lsn: Lsn) -> Vec<u8> {
133 let mut buf = Vec::with_capacity(34);
134 buf.push(b'r');
135 buf.extend_from_slice(&write_lsn.as_u64().to_be_bytes());
136 buf.extend_from_slice(&flush_lsn.as_u64().to_be_bytes());
137 buf.extend_from_slice(&apply_lsn.as_u64().to_be_bytes());
138 buf.extend_from_slice(&0_i64.to_be_bytes());
140 buf.push(0);
142 buf
143}
144
145fn validate_pg_identifier(value: &str, field: &str) -> Result<(), ConnectorError> {
147 if value.is_empty() {
148 return Err(ConnectorError::ConfigurationError(format!(
149 "{field} must not be empty"
150 )));
151 }
152 if !value
153 .bytes()
154 .all(|b| b.is_ascii_alphanumeric() || b == b'_')
155 {
156 return Err(ConnectorError::ConfigurationError(format!(
157 "{field} contains unsafe characters (only [a-zA-Z0-9_] allowed): {value:?}"
158 )));
159 }
160 Ok(())
161}
162
163pub fn build_start_replication_query(
172 slot_name: &str,
173 start_lsn: Lsn,
174 publication: &str,
175) -> Result<String, ConnectorError> {
176 validate_pg_identifier(slot_name, "slot_name")?;
177 validate_pg_identifier(publication, "publication")?;
178 Ok(format!(
179 "START_REPLICATION SLOT {slot_name} LOGICAL {start_lsn} \
180 (proto_version '1', publication_names '{publication}')"
181 ))
182}
183
/// Opens the control-plane (non-replication) PostgreSQL connection.
///
/// The connection is made with `config.connection_string()` and
/// `tokio_postgres::NoTls`. The connection future is spawned onto its own
/// tokio task, which logs any error it ends with. The function returns the
/// client together with that task's handle.
///
/// # Errors
///
/// - [`ConnectorError::ConfigurationError`] if `config.ssl_mode` is anything
///   other than `Disable`.
/// - [`ConnectorError::ConnectionFailed`] if the connection attempt fails.
#[cfg(feature = "postgres-cdc")]
pub async fn connect(
    config: &super::config::PostgresCdcConfig,
) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), ConnectorError> {
    use super::config::SslMode;

    let conn_str = config.connection_string();

    // This path only knows how to connect in plaintext (`NoTls` below).
    let mode = config.ssl_mode;
    if !matches!(mode, SslMode::Disable) {
        return Err(ConnectorError::ConfigurationError(format!(
            "ssl.mode={mode}: TLS for the postgres-cdc control-plane \
             connection is not implemented; only ssl.mode=disable is \
             currently supported. The replication stream still \
             honours ssl.mode."
        )));
    }

    let (client, connection) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls)
        .await
        .map_err(|e| ConnectorError::ConnectionFailed(format!("PostgreSQL connect: {e}")))?;

    // The client only makes progress while the connection future is polled.
    let driver = async move {
        if let Err(e) = connection.await {
            tracing::error!(error = %e, "PostgreSQL control-plane connection error");
        }
    };
    let handle = tokio::spawn(driver);

    Ok((client, handle))
}
245
/// Makes sure a replication slot named `slot_name` exists.
///
/// - If the slot is already listed in `pg_replication_slots`, its
///   `confirmed_flush_lsn` is returned (`None` when that column is NULL).
/// - Otherwise the slot is created with
///   `pg_create_logical_replication_slot(slot_name, plugin)` and `None` is
///   returned.
///
/// NOTE(review): an existing slot is accepted as-is; its type and plugin are
/// not compared with `plugin`. The lookup and the creation are two separate
/// statements, so they are not atomic.
///
/// # Errors
///
/// - [`ConnectorError::ConnectionFailed`] if the lookup or the creation query
///   fails.
/// - [`ConnectorError::ReadError`] if `confirmed_flush_lsn` cannot be parsed
///   as an [`Lsn`].
#[cfg(feature = "postgres-cdc")]
pub async fn ensure_replication_slot(
    client: &tokio_postgres::Client,
    slot_name: &str,
    plugin: &str,
) -> Result<Option<Lsn>, ConnectorError> {
    let existing = client
        .query(
            "SELECT confirmed_flush_lsn::text FROM pg_replication_slots WHERE slot_name = $1",
            &[&slot_name],
        )
        .await
        .map_err(|e| ConnectorError::ConnectionFailed(format!("query replication slots: {e}")))?;

    match existing.first() {
        Some(row) => {
            let flushed: Option<&str> = row.get(0);
            let resume_lsn = flushed
                .map(str::parse::<Lsn>)
                .transpose()
                .map_err(|e| {
                    ConnectorError::ReadError(format!("invalid confirmed_flush_lsn: {e}"))
                })?;

            if let Some(lsn) = resume_lsn {
                tracing::info!(slot = slot_name, lsn = %lsn, "replication slot exists");
            } else {
                tracing::info!(slot = slot_name, "replication slot exists (no flush LSN)");
            }
            Ok(resume_lsn)
        }
        None => {
            client
                .execute(
                    "SELECT pg_create_logical_replication_slot($1, $2)",
                    &[&slot_name, &plugin],
                )
                .await
                .map_err(|e| {
                    ConnectorError::ConnectionFailed(format!("create replication slot: {e}"))
                })?;

            tracing::info!(
                slot = slot_name,
                plugin = plugin,
                "created replication slot"
            );
            Ok(None)
        }
    }
}
303
/// Drops the replication slot `slot_name` via `pg_drop_replication_slot`.
///
/// # Errors
///
/// Returns [`ConnectorError::ConnectionFailed`], naming the slot, if the
/// statement fails.
#[cfg(feature = "postgres-cdc")]
pub async fn drop_replication_slot(
    client: &tokio_postgres::Client,
    slot_name: &str,
) -> Result<(), ConnectorError> {
    let outcome = client
        .execute("SELECT pg_drop_replication_slot($1)", &[&slot_name])
        .await;

    if let Err(e) = outcome {
        return Err(ConnectorError::ConnectionFailed(format!(
            "drop replication slot '{slot_name}': {e}"
        )));
    }

    tracing::info!(slot = slot_name, "dropped replication slot");
    Ok(())
}
331
/// Translates a [`super::config::PostgresCdcConfig`] into the
/// `pgwire_replication::ReplicationConfig` used by the replication stream.
///
/// TLS mapping:
/// - `Disable` → disabled.
/// - `Prefer` and `Require` → require. `Prefer` is upgraded, presumably
///   meaning there is no plaintext fallback; confirm against
///   `pgwire_replication`.
/// - `VerifyCa` and `VerifyFull` → the matching verify modes, with the
///   optional CA path.
///
/// An SNI hostname is applied when configured. A client certificate is
/// applied only when both `client_cert_path` and `client_key_path` are set.
/// A lone cert or key is ignored with a warning. `start_lsn` defaults to
/// `Lsn::ZERO`, and a missing password becomes an empty string.
#[cfg(feature = "postgres-cdc")]
pub fn build_replication_config(
    config: &super::config::PostgresCdcConfig,
) -> pgwire_replication::ReplicationConfig {
    use std::path::PathBuf;

    use super::config::SslMode;

    let ca_path = config.ca_cert_path.as_ref().map(PathBuf::from);

    let tls = match config.ssl_mode {
        SslMode::Disable => pgwire_replication::TlsConfig::disabled(),
        SslMode::Prefer | SslMode::Require => pgwire_replication::TlsConfig::require(),
        SslMode::VerifyCa => pgwire_replication::TlsConfig::verify_ca(ca_path),
        SslMode::VerifyFull => pgwire_replication::TlsConfig::verify_full(ca_path),
    };

    let tls = if let Some(ref hostname) = config.sni_hostname {
        tls.with_sni_hostname(hostname)
    } else {
        tls
    };

    // mTLS needs both halves. A lone cert or key used to be dropped silently,
    // which hid the misconfiguration until the server rejected the login.
    let tls = match (&config.client_cert_path, &config.client_key_path) {
        (Some(cert), Some(key)) => tls.with_client_cert(PathBuf::from(cert), PathBuf::from(key)),
        (Some(_), None) | (None, Some(_)) => {
            tracing::warn!(
                "only one of client_cert_path / client_key_path is set; both are \
                 required for client-certificate authentication, so no client \
                 certificate will be presented"
            );
            tls
        }
        (None, None) => tls,
    };

    let start_lsn = config
        .start_lsn
        .map_or(pgwire_replication::Lsn::ZERO, |lsn| {
            pgwire_replication::Lsn::from_u64(lsn.as_u64())
        });

    pgwire_replication::ReplicationConfig {
        host: config.host.clone(),
        port: config.port,
        user: config.username.clone(),
        password: config.password.clone().unwrap_or_default(),
        database: config.database.clone(),
        tls,
        slot: config.slot_name.clone(),
        publication: config.publication.clone(),
        start_lsn,
        stop_at_lsn: None,
        status_interval: config.keepalive_interval,
        idle_wakeup_interval: config.poll_timeout,
        buffer_events: 8192,
    }
}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_xlog_data() {
        let mut msg = vec![b'w'];
        msg.extend_from_slice(&0x0000_0001_0000_0100_u64.to_be_bytes());
        msg.extend_from_slice(&0x0000_0001_0000_0200_u64.to_be_bytes());
        msg.extend_from_slice(&1_234_567_890_i64.to_be_bytes());
        msg.extend_from_slice(b"hello pgoutput");

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::XLogData {
                wal_start,
                wal_end,
                server_time_us,
                data,
            } => {
                assert_eq!(wal_start, Lsn::new(0x0000_0001_0000_0100));
                assert_eq!(wal_end, Lsn::new(0x0000_0001_0000_0200));
                assert_eq!(server_time_us, 1_234_567_890);
                assert_eq!(data, b"hello pgoutput");
            }
            ReplicationMessage::PrimaryKeepalive { .. } => panic!("expected XLogData"),
        }
    }

    #[test]
    fn test_parse_xlog_data_empty_payload() {
        let mut msg = vec![b'w'];
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::XLogData { data, .. } => {
                assert!(data.is_empty());
            }
            ReplicationMessage::PrimaryKeepalive { .. } => panic!("expected XLogData"),
        }
    }

    #[test]
    fn test_parse_keepalive_reply_requested() {
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0x0000_0002_0000_0500_u64.to_be_bytes());
        msg.extend_from_slice(&9_876_543_210_i64.to_be_bytes());
        msg.push(1);

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::PrimaryKeepalive {
                wal_end,
                server_time_us,
                reply_requested,
            } => {
                assert_eq!(wal_end, Lsn::new(0x0000_0002_0000_0500));
                assert_eq!(server_time_us, 9_876_543_210);
                assert!(reply_requested);
            }
            ReplicationMessage::XLogData { .. } => panic!("expected PrimaryKeepalive"),
        }
    }

    #[test]
    fn test_parse_keepalive_no_reply() {
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0x100_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());
        msg.push(0);

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::PrimaryKeepalive {
                reply_requested, ..
            } => {
                assert!(!reply_requested);
            }
            ReplicationMessage::XLogData { .. } => panic!("expected PrimaryKeepalive"),
        }
    }

    #[test]
    fn test_parse_empty_message() {
        let err = parse_replication_message(&[]).unwrap_err();
        assert!(err.to_string().contains("empty"));
    }

    #[test]
    fn test_parse_unknown_tag() {
        let err = parse_replication_message(&[0xFF]).unwrap_err();
        assert!(err.to_string().contains("unknown"));
        assert!(err.to_string().contains("0xFF"));
    }

    #[test]
    fn test_parse_truncated_xlog_data() {
        let msg = vec![b'w', 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    #[test]
    fn test_parse_xlog_data_one_byte_short() {
        // 24 bytes: tag + both LSNs, but only 7 of the 8 timestamp bytes.
        let msg = vec![b'w'; 24];
        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated XLogData"));
        assert!(err.to_string().contains("24 bytes"));
    }

    #[test]
    fn test_parse_truncated_keepalive() {
        let msg = vec![b'k', 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    #[test]
    fn test_parse_keepalive_one_byte_short() {
        // 17 bytes: everything except the trailing reply_requested flag.
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());
        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated PrimaryKeepalive"));
    }

    #[test]
    fn test_encode_standby_status_layout() {
        let write_lsn = Lsn::new(0x0000_0001_0000_0100);
        let flush_lsn = Lsn::new(0x0000_0001_0000_0080);
        let apply_lsn = Lsn::new(0x0000_0001_0000_0080);

        let buf = encode_standby_status(write_lsn, flush_lsn, apply_lsn);

        assert_eq!(buf.len(), 34, "standby status must be exactly 34 bytes");
        assert_eq!(buf[0], b'r', "tag must be 'r'");

        let w = u64::from_be_bytes(buf[1..9].try_into().unwrap());
        assert_eq!(w, 0x0000_0001_0000_0100);

        let f = u64::from_be_bytes(buf[9..17].try_into().unwrap());
        assert_eq!(f, 0x0000_0001_0000_0080);

        let a = u64::from_be_bytes(buf[17..25].try_into().unwrap());
        assert_eq!(a, 0x0000_0001_0000_0080);

        let ts = i64::from_be_bytes(buf[25..33].try_into().unwrap());
        assert_eq!(ts, 0);

        assert_eq!(buf[33], 0);
    }

    #[test]
    fn test_build_start_replication_query() {
        let query =
            build_start_replication_query("my_slot", "0/1234ABCD".parse().unwrap(), "my_pub")
                .unwrap();
        assert!(query.contains("START_REPLICATION SLOT my_slot LOGICAL 0/1234ABCD"));
        assert!(query.contains("proto_version '1'"));
        assert!(query.contains("publication_names 'my_pub'"));
    }

    #[test]
    fn test_build_start_replication_query_rejects_injection() {
        let result = build_start_replication_query(
            "slot'; DROP TABLE users; --",
            "0/0".parse().unwrap(),
            "pub",
        );
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unsafe characters"));
    }

    #[test]
    fn test_build_start_replication_query_rejects_unsafe_publication() {
        // The publication name sits inside a quoted literal; a stray quote
        // must not be able to terminate it.
        let result =
            build_start_replication_query("my_slot", "0/0".parse().unwrap(), "pub'); --");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("publication"));
        assert!(err.contains("unsafe characters"));
    }

    #[test]
    fn test_build_start_replication_query_rejects_empty() {
        let result = build_start_replication_query("", "0/0".parse().unwrap(), "pub");
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_pg_identifier_accepts_valid() {
        assert!(validate_pg_identifier("my_slot_123", "test").is_ok());
    }

    #[test]
    fn test_validate_pg_identifier_rejects_non_ascii_and_punctuation() {
        assert!(validate_pg_identifier("sl\u{f6}t", "test").is_err());
        assert!(validate_pg_identifier("my-slot", "test").is_err());
        assert!(validate_pg_identifier("my slot", "test").is_err());
    }

    #[cfg(feature = "postgres-cdc")]
    mod tls_mapping_tests {
        use super::super::build_replication_config;
        use crate::cdc::postgres::config::{PostgresCdcConfig, SslMode};

        #[test]
        fn test_disable_maps_to_disabled() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Disable;
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::Disable);
        }

        #[test]
        fn test_prefer_maps_to_require() {
            // Relies on the default ssl_mode being Prefer.
            let cfg = PostgresCdcConfig::default();
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::Require);
        }

        #[test]
        fn test_require_maps_to_require() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Require;
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::Require);
        }

        #[test]
        fn test_verify_ca_maps_with_ca_path() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::VerifyCa;
            cfg.ca_cert_path = Some("/certs/ca.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::VerifyCa);
            assert_eq!(
                repl.tls.ca_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/ca.pem"))
            );
        }

        #[test]
        fn test_verify_full_maps_with_ca_path() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::VerifyFull;
            cfg.ca_cert_path = Some("/certs/ca.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::VerifyFull);
            assert_eq!(
                repl.tls.ca_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/ca.pem"))
            );
        }

        #[test]
        fn test_sni_hostname_applied() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.sni_hostname = Some("db.example.com".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.sni_hostname.as_deref(), Some("db.example.com"));
        }

        #[test]
        fn test_mtls_client_cert_applied() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Require;
            cfg.client_cert_path = Some("/certs/client.pem".to_string());
            cfg.client_key_path = Some("/certs/client-key.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(
                repl.tls.client_cert_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/client.pem"))
            );
            assert_eq!(
                repl.tls.client_key_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/client-key.pem"))
            );
        }

        #[test]
        fn test_no_client_cert_when_not_set() {
            let cfg = PostgresCdcConfig::default();
            let repl = build_replication_config(&cfg);
            assert!(repl.tls.client_cert_pem_path.is_none());
            assert!(repl.tls.client_key_pem_path.is_none());
        }

        #[test]
        fn test_connection_fields_mapped() {
            let mut cfg = PostgresCdcConfig::new("pg.example.com", "mydb", "my_slot", "my_pub");
            cfg.port = 5433;
            cfg.username = "replicator".to_string();
            cfg.password = Some("secret".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.host, "pg.example.com");
            assert_eq!(repl.port, 5433);
            assert_eq!(repl.user, "replicator");
            assert_eq!(repl.password, "secret");
            assert_eq!(repl.database, "mydb");
            assert_eq!(repl.slot, "my_slot");
            assert_eq!(repl.publication, "my_pub");
        }
    }
}