1use super::lsn::Lsn;
27use crate::error::ConnectorError;
28
/// A message received on the logical replication stream, as decoded by
/// [`parse_replication_message`].
///
/// Each variant corresponds to one tag byte of the streaming-replication
/// `CopyData` sub-protocol (`'w'` and `'k'`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplicationMessage {
    /// Tag `'w'`: a chunk of WAL data together with its positional header.
    XLogData {
        /// WAL position of the first byte of `data` (message bytes 1..9).
        wal_start: Lsn,
        /// The server's current end of WAL (message bytes 9..17).
        wal_end: Lsn,
        /// Server clock at send time (message bytes 17..25), passed through
        /// unconverted. Presumably microseconds since 2000-01-01 per the
        /// PostgreSQL protocol docs — confirm before interpreting.
        server_time_us: i64,
        /// Everything after the 25-byte header; may be empty. Typically a
        /// pgoutput-encoded record, since the stream is started with pgoutput
        /// options.
        data: Vec<u8>,
    },

    /// Tag `'k'`: a primary keepalive carrying no WAL payload.
    PrimaryKeepalive {
        /// The server's current end of WAL (message bytes 1..9).
        wal_end: Lsn,
        /// Server clock at send time (message bytes 9..17), unconverted.
        server_time_us: i64,
        /// `true` when the flag byte (message byte 17) is non-zero, i.e. the
        /// server asks for a standby status update soon.
        reply_requested: bool,
    },
}
66
67#[allow(clippy::missing_panics_doc)] pub fn parse_replication_message(data: &[u8]) -> Result<ReplicationMessage, ConnectorError> {
83 if data.is_empty() {
84 return Err(ConnectorError::ReadError(
85 "empty replication message".to_string(),
86 ));
87 }
88
89 match data[0] {
90 b'w' => {
91 const HEADER_LEN: usize = 1 + 8 + 8 + 8; if data.len() < HEADER_LEN {
94 return Err(ConnectorError::ReadError(format!(
95 "truncated XLogData: {} bytes (need at least {HEADER_LEN})",
96 data.len()
97 )));
98 }
99
100 let wal_start = Lsn::new(u64::from_be_bytes(data[1..9].try_into().unwrap()));
102 let wal_end = Lsn::new(u64::from_be_bytes(data[9..17].try_into().unwrap()));
103 let server_time_us = i64::from_be_bytes(data[17..25].try_into().unwrap());
104 let payload = data[HEADER_LEN..].to_vec();
105
106 Ok(ReplicationMessage::XLogData {
107 wal_start,
108 wal_end,
109 server_time_us,
110 data: payload,
111 })
112 }
113 b'k' => {
114 const KEEPALIVE_LEN: usize = 1 + 8 + 8 + 1; if data.len() < KEEPALIVE_LEN {
117 return Err(ConnectorError::ReadError(format!(
118 "truncated PrimaryKeepalive: {} bytes (need {KEEPALIVE_LEN})",
119 data.len()
120 )));
121 }
122
123 let wal_end = Lsn::new(u64::from_be_bytes(data[1..9].try_into().unwrap()));
125 let server_time_us = i64::from_be_bytes(data[9..17].try_into().unwrap());
126 let reply_requested = data[17] != 0;
127
128 Ok(ReplicationMessage::PrimaryKeepalive {
129 wal_end,
130 server_time_us,
131 reply_requested,
132 })
133 }
134 tag => Err(ConnectorError::ReadError(format!(
135 "unknown replication message tag: 0x{tag:02X}"
136 ))),
137 }
138}
139
140#[must_use]
155pub fn encode_standby_status(write_lsn: Lsn, flush_lsn: Lsn, apply_lsn: Lsn) -> Vec<u8> {
156 let mut buf = Vec::with_capacity(34);
157 buf.push(b'r');
158 buf.extend_from_slice(&write_lsn.as_u64().to_be_bytes());
159 buf.extend_from_slice(&flush_lsn.as_u64().to_be_bytes());
160 buf.extend_from_slice(&apply_lsn.as_u64().to_be_bytes());
161 buf.extend_from_slice(&0_i64.to_be_bytes());
163 buf.push(0);
165 buf
166}
167
168fn validate_pg_identifier(value: &str, field: &str) -> Result<(), ConnectorError> {
170 if value.is_empty() {
171 return Err(ConnectorError::ConfigurationError(format!(
172 "{field} must not be empty"
173 )));
174 }
175 if !value
176 .bytes()
177 .all(|b| b.is_ascii_alphanumeric() || b == b'_')
178 {
179 return Err(ConnectorError::ConfigurationError(format!(
180 "{field} contains unsafe characters (only [a-zA-Z0-9_] allowed): {value:?}"
181 )));
182 }
183 Ok(())
184}
185
186pub fn build_start_replication_query(
197 slot_name: &str,
198 start_lsn: Lsn,
199 publication: &str,
200) -> Result<String, ConnectorError> {
201 validate_pg_identifier(slot_name, "slot_name")?;
202 validate_pg_identifier(publication, "publication")?;
203 Ok(format!(
204 "START_REPLICATION SLOT {slot_name} LOGICAL {start_lsn} \
205 (proto_version '1', publication_names '{publication}')"
206 ))
207}
208
#[cfg(feature = "postgres-cdc")]
/// Opens a plaintext control-plane connection using `tokio_postgres`.
///
/// TLS is not supported on this path yet:
///
/// * `SslMode::Disable` connects directly.
/// * `SslMode::Prefer` logs that it is falling back to `NoTls`.
/// * Every other mode is rejected before any network activity.
///
/// The connection driver is spawned as a Tokio task. The returned
/// `JoinHandle` completes when that connection future finishes; a connection
/// error is logged rather than returned.
///
/// # Errors
///
/// * [`ConnectorError::ConfigurationError`] for SSL modes that require TLS.
/// * [`ConnectorError::ConnectionFailed`] if `tokio_postgres::connect` fails.
pub async fn connect(
    config: &super::config::PostgresCdcConfig,
) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), ConnectorError> {
    use super::config::SslMode;

    let conn_str = config.connection_string();
    let mode = config.ssl_mode;

    if matches!(mode, SslMode::Prefer) {
        tracing::info!(
            ssl_mode = %config.ssl_mode,
            "TLS not yet implemented for control-plane connections; using NoTls (Prefer mode)"
        );
    } else if !matches!(mode, SslMode::Disable) {
        return Err(ConnectorError::ConfigurationError(format!(
            "ssl.mode={mode} requires TLS support which is not yet implemented \
             for control-plane connections"
        )));
    }

    let (client, connection) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls)
        .await
        .map_err(|e| ConnectorError::ConnectionFailed(format!("PostgreSQL connect: {e}")))?;

    // `client` only makes progress while the connection future is being polled.
    let driver = tokio::spawn(async move {
        if let Err(err) = connection.await {
            tracing::error!(error = %err, "PostgreSQL control-plane connection error");
        }
    });

    Ok((client, driver))
}
266
#[cfg(feature = "postgres-cdc")]
/// Ensures a replication slot named `slot_name` exists, creating a logical
/// slot with output plugin `plugin` when none is found.
///
/// Returns:
///
/// * `Ok(Some(lsn))` when the slot already exists and reports a
///   `confirmed_flush_lsn`.
/// * `Ok(None)` when the slot exists without one, or was just created.
///
/// Note: the lookup and the creation are separate statements, and the lookup
/// matches on `slot_name` only. An existing slot's plugin and type are not
/// compared against `plugin`.
///
/// # Errors
///
/// * [`ConnectorError::ConnectionFailed`] if either statement fails.
/// * [`ConnectorError::ReadError`] if the reported LSN text cannot be parsed.
pub async fn ensure_replication_slot(
    client: &tokio_postgres::Client,
    slot_name: &str,
    plugin: &str,
) -> Result<Option<Lsn>, ConnectorError> {
    const LOOKUP_SQL: &str =
        "SELECT confirmed_flush_lsn::text FROM pg_replication_slots WHERE slot_name = $1";
    const CREATE_SQL: &str = "SELECT pg_create_logical_replication_slot($1, $2)";

    let existing = client
        .query(LOOKUP_SQL, &[&slot_name])
        .await
        .map_err(|e| ConnectorError::ConnectionFailed(format!("query replication slots: {e}")))?;

    if let Some(row) = existing.first() {
        let confirmed: Option<&str> = row.get(0);
        return match confirmed {
            Some(text) => {
                let lsn: Lsn = text.parse().map_err(|e| {
                    ConnectorError::ReadError(format!("invalid confirmed_flush_lsn: {e}"))
                })?;
                tracing::info!(slot = slot_name, lsn = %lsn, "replication slot exists");
                Ok(Some(lsn))
            }
            None => {
                tracing::info!(slot = slot_name, "replication slot exists (no flush LSN)");
                Ok(None)
            }
        };
    }

    client
        .execute(CREATE_SQL, &[&slot_name, &plugin])
        .await
        .map_err(|e| ConnectorError::ConnectionFailed(format!("create replication slot: {e}")))?;

    tracing::info!(slot = slot_name, plugin = plugin, "created replication slot");
    Ok(None)
}
324
#[cfg(feature = "postgres-cdc")]
/// Drops the replication slot `slot_name` by running `pg_drop_replication_slot`.
///
/// # Errors
///
/// Returns [`ConnectorError::ConnectionFailed`], naming the slot, if the
/// statement fails for any reason.
pub async fn drop_replication_slot(
    client: &tokio_postgres::Client,
    slot_name: &str,
) -> Result<(), ConnectorError> {
    let outcome = client
        .execute("SELECT pg_drop_replication_slot($1)", &[&slot_name])
        .await;

    if let Err(e) = outcome {
        return Err(ConnectorError::ConnectionFailed(format!(
            "drop replication slot '{slot_name}': {e}"
        )));
    }

    tracing::info!(slot = slot_name, "dropped replication slot");
    Ok(())
}
352
#[cfg(feature = "postgres-cdc")]
/// Translates a `PostgresCdcConfig` into the `pgwire_replication` streaming
/// configuration.
///
/// SSL mode mapping:
///
/// * `Disable` → `TlsConfig::disabled()`
/// * `Prefer` and `Require` → `TlsConfig::require()`, so `Prefer` is upgraded
///   to a hard requirement on this path.
/// * `VerifyCa` / `VerifyFull` → the matching verifying config, using
///   `ca_cert_path` when set.
///
/// Other settings:
///
/// * `sni_hostname` is applied when set.
/// * A client certificate is applied only when both `client_cert_path` and
///   `client_key_path` are set. If only one is configured, a warning is logged
///   and no client certificate is sent.
/// * A missing `start_lsn` becomes `Lsn::ZERO`.
/// * A missing password becomes an empty string.
pub fn build_replication_config(
    config: &super::config::PostgresCdcConfig,
) -> pgwire_replication::ReplicationConfig {
    use std::path::PathBuf;

    use super::config::SslMode;

    let ca_path = config.ca_cert_path.as_ref().map(PathBuf::from);

    let tls = match config.ssl_mode {
        SslMode::Disable => pgwire_replication::TlsConfig::disabled(),
        SslMode::Prefer | SslMode::Require => pgwire_replication::TlsConfig::require(),
        SslMode::VerifyCa => pgwire_replication::TlsConfig::verify_ca(ca_path),
        SslMode::VerifyFull => pgwire_replication::TlsConfig::verify_full(ca_path),
    };

    let tls = if let Some(hostname) = &config.sni_hostname {
        tls.with_sni_hostname(hostname)
    } else {
        tls
    };

    let tls = match (&config.client_cert_path, &config.client_key_path) {
        (Some(cert), Some(key)) => tls.with_client_cert(PathBuf::from(cert), PathBuf::from(key)),
        (Some(_), None) | (None, Some(_)) => {
            // A lone cert or key cannot be used for mTLS. Say so instead of
            // silently connecting without a client certificate.
            tracing::warn!(
                has_client_cert = config.client_cert_path.is_some(),
                has_client_key = config.client_key_path.is_some(),
                "client certificate and key must both be set for mutual TLS; \
                 ignoring the partial client-certificate configuration"
            );
            tls
        }
        (None, None) => tls,
    };

    let start_lsn = config
        .start_lsn
        .map_or(pgwire_replication::Lsn::ZERO, |lsn| {
            pgwire_replication::Lsn::from_u64(lsn.as_u64())
        });

    pgwire_replication::ReplicationConfig {
        host: config.host.clone(),
        port: config.port,
        user: config.username.clone(),
        password: config.password.clone().unwrap_or_default(),
        database: config.database.clone(),
        tls,
        slot: config.slot_name.clone(),
        publication: config.publication.clone(),
        start_lsn,
        stop_at_lsn: None,
        status_interval: config.keepalive_interval,
        idle_wakeup_interval: config.poll_timeout,
        // Fixed event-buffer size handed to pgwire_replication; presumably a
        // channel capacity — confirm against that crate before tuning.
        buffer_events: 8192,
    }
}
428
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_xlog_data() {
        let mut msg = vec![b'w'];
        msg.extend_from_slice(&0x0000_0001_0000_0100_u64.to_be_bytes());
        msg.extend_from_slice(&0x0000_0001_0000_0200_u64.to_be_bytes());
        msg.extend_from_slice(&1_234_567_890_i64.to_be_bytes());
        msg.extend_from_slice(b"hello pgoutput");

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::XLogData {
                wal_start,
                wal_end,
                server_time_us,
                data,
            } => {
                assert_eq!(wal_start, Lsn::new(0x0000_0001_0000_0100));
                assert_eq!(wal_end, Lsn::new(0x0000_0001_0000_0200));
                assert_eq!(server_time_us, 1_234_567_890);
                assert_eq!(data, b"hello pgoutput");
            }
            ReplicationMessage::PrimaryKeepalive { .. } => panic!("expected XLogData"),
        }
    }

    #[test]
    fn test_parse_xlog_data_empty_payload() {
        let mut msg = vec![b'w'];
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::XLogData { data, .. } => {
                assert!(data.is_empty());
            }
            ReplicationMessage::PrimaryKeepalive { .. } => panic!("expected XLogData"),
        }
    }

    #[test]
    fn test_parse_xlog_data_header_one_byte_short() {
        // 24 bytes: one short of the 25-byte XLogData header.
        let mut msg = vec![b'w'];
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&[0_u8; 7]);
        assert_eq!(msg.len(), 24);

        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    #[test]
    fn test_parse_keepalive_reply_requested() {
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0x0000_0002_0000_0500_u64.to_be_bytes());
        msg.extend_from_slice(&9_876_543_210_i64.to_be_bytes());
        msg.push(1);

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::PrimaryKeepalive {
                wal_end,
                server_time_us,
                reply_requested,
            } => {
                assert_eq!(wal_end, Lsn::new(0x0000_0002_0000_0500));
                assert_eq!(server_time_us, 9_876_543_210);
                assert!(reply_requested);
            }
            ReplicationMessage::XLogData { .. } => panic!("expected PrimaryKeepalive"),
        }
    }

    #[test]
    fn test_parse_keepalive_no_reply() {
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0x100_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());
        msg.push(0);

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::PrimaryKeepalive {
                reply_requested, ..
            } => {
                assert!(!reply_requested);
            }
            ReplicationMessage::XLogData { .. } => panic!("expected PrimaryKeepalive"),
        }
    }

    #[test]
    fn test_parse_keepalive_nonzero_flag_means_reply() {
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0x100_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());
        msg.push(0x02);

        match parse_replication_message(&msg).unwrap() {
            ReplicationMessage::PrimaryKeepalive {
                reply_requested, ..
            } => assert!(reply_requested),
            ReplicationMessage::XLogData { .. } => panic!("expected PrimaryKeepalive"),
        }
    }

    #[test]
    fn test_parse_keepalive_missing_flag_byte() {
        // 17 bytes: tag + wal_end + server_time, but no reply flag.
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0x100_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());
        assert_eq!(msg.len(), 17);

        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    #[test]
    fn test_parse_empty_message() {
        let err = parse_replication_message(&[]).unwrap_err();
        assert!(err.to_string().contains("empty"));
    }

    #[test]
    fn test_parse_unknown_tag() {
        let err = parse_replication_message(&[0xFF]).unwrap_err();
        assert!(err.to_string().contains("unknown"));
        assert!(err.to_string().contains("0xFF"));
    }

    #[test]
    fn test_parse_truncated_xlog_data() {
        let msg = vec![b'w', 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    #[test]
    fn test_parse_truncated_keepalive() {
        let msg = vec![b'k', 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    #[test]
    fn test_encode_standby_status_layout() {
        let write_lsn = Lsn::new(0x0000_0001_0000_0100);
        let flush_lsn = Lsn::new(0x0000_0001_0000_0080);
        let apply_lsn = Lsn::new(0x0000_0001_0000_0080);

        let buf = encode_standby_status(write_lsn, flush_lsn, apply_lsn);

        assert_eq!(buf.len(), 34, "standby status must be exactly 34 bytes");
        assert_eq!(buf[0], b'r', "tag must be 'r'");

        let w = u64::from_be_bytes(buf[1..9].try_into().unwrap());
        assert_eq!(w, 0x0000_0001_0000_0100);

        let f = u64::from_be_bytes(buf[9..17].try_into().unwrap());
        assert_eq!(f, 0x0000_0001_0000_0080);

        let a = u64::from_be_bytes(buf[17..25].try_into().unwrap());
        assert_eq!(a, 0x0000_0001_0000_0080);

        let ts = i64::from_be_bytes(buf[25..33].try_into().unwrap());
        assert_eq!(ts, 0);

        assert_eq!(buf[33], 0);
    }

    #[test]
    fn test_build_start_replication_query() {
        let query =
            build_start_replication_query("my_slot", "0/1234ABCD".parse().unwrap(), "my_pub")
                .unwrap();
        assert!(query.contains("START_REPLICATION SLOT my_slot LOGICAL 0/1234ABCD"));
        assert!(query.contains("proto_version '1'"));
        assert!(query.contains("publication_names 'my_pub'"));
    }

    #[test]
    fn test_build_start_replication_query_rejects_injection() {
        let result = build_start_replication_query(
            "slot'; DROP TABLE users; --",
            "0/0".parse().unwrap(),
            "pub",
        );
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unsafe characters"));
    }

    #[test]
    fn test_build_start_replication_query_rejects_publication_injection() {
        let result = build_start_replication_query(
            "my_slot",
            "0/0".parse().unwrap(),
            "pub', 'other_pub",
        );
        let message = result.unwrap_err().to_string();
        assert!(message.contains("publication"));
        assert!(message.contains("unsafe characters"));
    }

    #[test]
    fn test_build_start_replication_query_rejects_empty() {
        let result = build_start_replication_query("", "0/0".parse().unwrap(), "pub");
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_pg_identifier_accepts_valid() {
        assert!(validate_pg_identifier("my_slot_123", "test").is_ok());
    }

    #[test]
    fn test_validate_pg_identifier_rejects_punctuation_and_non_ascii() {
        for bad in ["slot-1", "slot name", "slot.name", "sl\u{f6}t"] {
            assert!(
                validate_pg_identifier(bad, "test").is_err(),
                "{bad:?} should be rejected"
            );
        }
    }

    #[cfg(feature = "postgres-cdc")]
    mod tls_mapping_tests {
        use super::super::build_replication_config;
        use crate::cdc::postgres::config::{PostgresCdcConfig, SslMode};

        #[test]
        fn test_disable_maps_to_disabled() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Disable;
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::Disable);
        }

        #[test]
        fn test_prefer_maps_to_require() {
            // Set the mode explicitly so the test does not depend on the default.
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Prefer;
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::Require);
        }

        #[test]
        fn test_require_maps_to_require() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Require;
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::Require);
        }

        #[test]
        fn test_verify_ca_maps_with_ca_path() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::VerifyCa;
            cfg.ca_cert_path = Some("/certs/ca.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::VerifyCa);
            assert_eq!(
                repl.tls.ca_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/ca.pem"))
            );
        }

        #[test]
        fn test_verify_full_maps_with_ca_path() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::VerifyFull;
            cfg.ca_cert_path = Some("/certs/ca.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::VerifyFull);
            assert_eq!(
                repl.tls.ca_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/ca.pem"))
            );
        }

        #[test]
        fn test_sni_hostname_applied() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.sni_hostname = Some("db.example.com".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.sni_hostname.as_deref(), Some("db.example.com"));
        }

        #[test]
        fn test_mtls_client_cert_applied() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Require;
            cfg.client_cert_path = Some("/certs/client.pem".to_string());
            cfg.client_key_path = Some("/certs/client-key.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(
                repl.tls.client_cert_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/client.pem"))
            );
            assert_eq!(
                repl.tls.client_key_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/client-key.pem"))
            );
        }

        #[test]
        fn test_partial_mtls_config_is_not_applied() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Require;
            cfg.client_cert_path = Some("/certs/client.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert!(repl.tls.client_cert_pem_path.is_none());
            assert!(repl.tls.client_key_pem_path.is_none());
        }

        #[test]
        fn test_no_client_cert_when_not_set() {
            let cfg = PostgresCdcConfig::default();
            let repl = build_replication_config(&cfg);
            assert!(repl.tls.client_cert_pem_path.is_none());
            assert!(repl.tls.client_key_pem_path.is_none());
        }

        #[test]
        fn test_connection_fields_mapped() {
            let mut cfg = PostgresCdcConfig::new("pg.example.com", "mydb", "my_slot", "my_pub");
            cfg.port = 5433;
            cfg.username = "replicator".to_string();
            cfg.password = Some("secret".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.host, "pg.example.com");
            assert_eq!(repl.port, 5433);
            assert_eq!(repl.user, "replicator");
            assert_eq!(repl.password, "secret");
            assert_eq!(repl.database, "mydb");
            assert_eq!(repl.slot, "my_slot");
            assert_eq!(repl.publication, "my_pub");
        }
    }
}