1use super::lsn::Lsn;
4use crate::error::ConnectorError;
5
/// A message decoded from the server side of a replication stream by
/// [`parse_replication_message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplicationMessage {
    /// WAL data message (tag byte `b'w'`).
    XLogData {
        /// Big-endian u64 decoded from bytes 1..9 of the message.
        wal_start: Lsn,
        /// Big-endian u64 decoded from bytes 9..17 of the message.
        wal_end: Lsn,
        /// Big-endian i64 decoded from bytes 17..25. The `_us` suffix suggests
        /// microseconds; the unit is not checked here.
        server_time_us: i64,
        /// Every byte after the 25-byte header (may be empty).
        data: Vec<u8>,
    },

    /// Primary keepalive message (tag byte `b'k'`).
    PrimaryKeepalive {
        /// Big-endian u64 decoded from bytes 1..9 of the message.
        wal_end: Lsn,
        /// Big-endian i64 decoded from bytes 9..17 of the message.
        server_time_us: i64,
        /// `true` when the flag byte at offset 17 is non-zero. Presumably the
        /// server is asking for a prompt standby status reply.
        reply_requested: bool,
    },
}
43
44#[allow(clippy::missing_panics_doc)] pub fn parse_replication_message(data: &[u8]) -> Result<ReplicationMessage, ConnectorError> {
60 if data.is_empty() {
61 return Err(ConnectorError::ReadError(
62 "empty replication message".to_string(),
63 ));
64 }
65
66 match data[0] {
67 b'w' => {
68 const HEADER_LEN: usize = 1 + 8 + 8 + 8; if data.len() < HEADER_LEN {
71 return Err(ConnectorError::ReadError(format!(
72 "truncated XLogData: {} bytes (need at least {HEADER_LEN})",
73 data.len()
74 )));
75 }
76
77 let wal_start = Lsn::new(u64::from_be_bytes(data[1..9].try_into().unwrap()));
79 let wal_end = Lsn::new(u64::from_be_bytes(data[9..17].try_into().unwrap()));
80 let server_time_us = i64::from_be_bytes(data[17..25].try_into().unwrap());
81 let payload = data[HEADER_LEN..].to_vec();
82
83 Ok(ReplicationMessage::XLogData {
84 wal_start,
85 wal_end,
86 server_time_us,
87 data: payload,
88 })
89 }
90 b'k' => {
91 const KEEPALIVE_LEN: usize = 1 + 8 + 8 + 1; if data.len() < KEEPALIVE_LEN {
94 return Err(ConnectorError::ReadError(format!(
95 "truncated PrimaryKeepalive: {} bytes (need {KEEPALIVE_LEN})",
96 data.len()
97 )));
98 }
99
100 let wal_end = Lsn::new(u64::from_be_bytes(data[1..9].try_into().unwrap()));
102 let server_time_us = i64::from_be_bytes(data[9..17].try_into().unwrap());
103 let reply_requested = data[17] != 0;
104
105 Ok(ReplicationMessage::PrimaryKeepalive {
106 wal_end,
107 server_time_us,
108 reply_requested,
109 })
110 }
111 tag => Err(ConnectorError::ReadError(format!(
112 "unknown replication message tag: 0x{tag:02X}"
113 ))),
114 }
115}
116
117#[must_use]
132pub fn encode_standby_status(write_lsn: Lsn, flush_lsn: Lsn, apply_lsn: Lsn) -> Vec<u8> {
133 let mut buf = Vec::with_capacity(34);
134 buf.push(b'r');
135 buf.extend_from_slice(&write_lsn.as_u64().to_be_bytes());
136 buf.extend_from_slice(&flush_lsn.as_u64().to_be_bytes());
137 buf.extend_from_slice(&apply_lsn.as_u64().to_be_bytes());
138 buf.extend_from_slice(&0_i64.to_be_bytes());
140 buf.push(0);
142 buf
143}
144
145fn validate_pg_identifier(value: &str, field: &str) -> Result<(), ConnectorError> {
147 if value.is_empty() {
148 return Err(ConnectorError::ConfigurationError(format!(
149 "{field} must not be empty"
150 )));
151 }
152 if !value
153 .bytes()
154 .all(|b| b.is_ascii_alphanumeric() || b == b'_')
155 {
156 return Err(ConnectorError::ConfigurationError(format!(
157 "{field} contains unsafe characters (only [a-zA-Z0-9_] allowed): {value:?}"
158 )));
159 }
160 Ok(())
161}
162
163pub fn build_start_replication_query(
172 slot_name: &str,
173 start_lsn: Lsn,
174 publication: &str,
175) -> Result<String, ConnectorError> {
176 validate_pg_identifier(slot_name, "slot_name")?;
177 validate_pg_identifier(publication, "publication")?;
178 Ok(format!(
179 "START_REPLICATION SLOT {slot_name} LOGICAL {start_lsn} \
180 (proto_version '1', publication_names '{publication}')"
181 ))
182}
183
#[cfg(feature = "postgres-cdc")]
/// Opens a plain (non-replication) control-plane connection to PostgreSQL.
///
/// TLS is not implemented for this connection yet:
///
/// * `ssl_mode = Disable` connects without TLS.
/// * `ssl_mode = Prefer` also connects without TLS, after logging that
///   fact at info level.
/// * Any other mode is rejected before a connection is attempted.
///
/// The connection driver runs on a spawned tokio task and logs any
/// connection error. Its `JoinHandle` is returned alongside the client.
///
/// # Errors
///
/// * [`ConnectorError::ConfigurationError`] if `ssl_mode` requires TLS.
/// * [`ConnectorError::ConnectionFailed`] if `tokio_postgres::connect` fails.
pub async fn connect(
    config: &super::config::PostgresCdcConfig,
) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), ConnectorError> {
    use super::config::SslMode;

    let conn_str = config.connection_string();

    // `Some(mode)` means the configured mode cannot be honoured without TLS.
    let unsupported_mode = match config.ssl_mode {
        SslMode::Disable => None,
        SslMode::Prefer => {
            tracing::info!(
                ssl_mode = %config.ssl_mode,
                "TLS not yet implemented for control-plane connections; using NoTls (Prefer mode)"
            );
            None
        }
        other => Some(other),
    };
    if let Some(mode) = unsupported_mode {
        return Err(ConnectorError::ConfigurationError(format!(
            "ssl.mode={mode} requires TLS support which is not yet implemented \
             for control-plane connections"
        )));
    }

    let connected = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await;
    let (client, connection) = connected
        .map_err(|err| ConnectorError::ConnectionFailed(format!("PostgreSQL connect: {err}")))?;

    let driver = tokio::spawn(async move {
        match connection.await {
            Ok(()) => {}
            Err(err) => {
                tracing::error!(error = %err, "PostgreSQL control-plane connection error");
            }
        }
    });

    Ok((client, driver))
}
240
#[cfg(feature = "postgres-cdc")]
/// Ensures the logical replication slot `slot_name` exists. If it does not,
/// the slot is created with output plugin `plugin`.
///
/// Return value:
///
/// * `Some(lsn)` — the slot already existed; `lsn` is its `confirmed_flush_lsn`.
/// * `None` — the slot was just created, or it exists but reports no
///   `confirmed_flush_lsn`.
///
/// # Errors
///
/// * [`ConnectorError::ConnectionFailed`] if looking up or creating the slot
///   fails.
/// * [`ConnectorError::ConfigurationError`] if a slot with this name already
///   exists but its `plugin` is not `plugin`. This includes a `NULL` plugin,
///   which presumably means a physical slot. Reusing such a slot would not
///   give the caller the output plugin it asked for.
/// * [`ConnectorError::ReadError`] if the reported `confirmed_flush_lsn`
///   cannot be parsed as an [`Lsn`].
pub async fn ensure_replication_slot(
    client: &tokio_postgres::Client,
    slot_name: &str,
    plugin: &str,
) -> Result<Option<Lsn>, ConnectorError> {
    let rows = client
        .query(
            "SELECT confirmed_flush_lsn::text, plugin::text \
             FROM pg_replication_slots WHERE slot_name = $1",
            &[&slot_name],
        )
        .await
        .map_err(|e| ConnectorError::ConnectionFailed(format!("query replication slots: {e}")))?;

    if let Some(row) = rows.first() {
        // Reject a same-named slot built for a different decoder here, instead of
        // failing later and less clearly when streaming starts.
        let existing_plugin: Option<&str> = row.get(1);
        if existing_plugin != Some(plugin) {
            return Err(ConnectorError::ConfigurationError(format!(
                "replication slot '{slot_name}' already exists with plugin '{}', \
                 expected '{plugin}'",
                existing_plugin.unwrap_or("<none>")
            )));
        }

        let lsn_str: Option<&str> = row.get(0);
        if let Some(lsn_str) = lsn_str {
            let lsn: Lsn = lsn_str.parse().map_err(|e| {
                ConnectorError::ReadError(format!("invalid confirmed_flush_lsn: {e}"))
            })?;
            tracing::info!(slot = slot_name, lsn = %lsn, "replication slot exists");
            return Ok(Some(lsn));
        }
        tracing::info!(slot = slot_name, "replication slot exists (no flush LSN)");
        return Ok(None);
    }

    client
        .execute(
            "SELECT pg_create_logical_replication_slot($1, $2)",
            &[&slot_name, &plugin],
        )
        .await
        .map_err(|e| ConnectorError::ConnectionFailed(format!("create replication slot: {e}")))?;

    tracing::info!(
        slot = slot_name,
        plugin = plugin,
        "created replication slot"
    );
    Ok(None)
}
298
#[cfg(feature = "postgres-cdc")]
/// Drops the replication slot `slot_name` with `pg_drop_replication_slot`.
///
/// On success an info-level log line is emitted.
///
/// # Errors
///
/// Returns [`ConnectorError::ConnectionFailed`] (naming the slot) if the
/// statement fails for any reason.
pub async fn drop_replication_slot(
    client: &tokio_postgres::Client,
    slot_name: &str,
) -> Result<(), ConnectorError> {
    let outcome = client
        .execute("SELECT pg_drop_replication_slot($1)", &[&slot_name])
        .await;
    if let Err(e) = outcome {
        return Err(ConnectorError::ConnectionFailed(format!(
            "drop replication slot '{slot_name}': {e}"
        )));
    }

    tracing::info!(slot = slot_name, "dropped replication slot");
    Ok(())
}
326
#[cfg(feature = "postgres-cdc")]
/// Builds the `pgwire_replication` configuration for the replication stream.
///
/// TLS is derived from `config.ssl_mode`:
///
/// | `ssl_mode`   | `pgwire_replication::TlsConfig`       |
/// |--------------|---------------------------------------|
/// | `Disable`    | `disabled()`                          |
/// | `Prefer`     | `require()`                           |
/// | `Require`    | `require()`                           |
/// | `VerifyCa`   | `verify_ca(ca_cert_path)`             |
/// | `VerifyFull` | `verify_full(ca_cert_path)`           |
///
/// Other mappings:
///
/// * `sni_hostname` is applied whenever it is set.
/// * A client certificate is attached only when both `client_cert_path` and
///   `client_key_path` are set. If exactly one is set, it is ignored and a
///   warning is logged.
/// * A missing password becomes an empty string.
/// * A missing `start_lsn` becomes `Lsn::ZERO`.
pub fn build_replication_config(
    config: &super::config::PostgresCdcConfig,
) -> pgwire_replication::ReplicationConfig {
    use std::path::PathBuf;

    use super::config::SslMode;

    let ca_path = config.ca_cert_path.as_ref().map(PathBuf::from);

    let tls = match config.ssl_mode {
        SslMode::Disable => pgwire_replication::TlsConfig::disabled(),
        // Prefer is treated exactly like Require for the replication stream.
        SslMode::Prefer | SslMode::Require => pgwire_replication::TlsConfig::require(),
        SslMode::VerifyCa => pgwire_replication::TlsConfig::verify_ca(ca_path),
        SslMode::VerifyFull => pgwire_replication::TlsConfig::verify_full(ca_path),
    };

    let tls = if let Some(ref hostname) = config.sni_hostname {
        tls.with_sni_hostname(hostname)
    } else {
        tls
    };

    let tls = match (&config.client_cert_path, &config.client_key_path) {
        (Some(cert), Some(key)) => tls.with_client_cert(PathBuf::from(cert), PathBuf::from(key)),
        (None, None) => tls,
        (cert, key) => {
            // Half an mTLS configuration cannot be used. Make that visible
            // instead of silently connecting without a client certificate.
            tracing::warn!(
                client_cert_path_set = cert.is_some(),
                client_key_path_set = key.is_some(),
                "mTLS requires both client_cert_path and client_key_path; \
                 ignoring the partial client certificate configuration"
            );
            tls
        }
    };

    let start_lsn = config
        .start_lsn
        .map_or(pgwire_replication::Lsn::ZERO, |lsn| {
            pgwire_replication::Lsn::from_u64(lsn.as_u64())
        });

    pgwire_replication::ReplicationConfig {
        host: config.host.clone(),
        port: config.port,
        user: config.username.clone(),
        password: config.password.clone().unwrap_or_default(),
        database: config.database.clone(),
        tls,
        slot: config.slot_name.clone(),
        publication: config.publication.clone(),
        start_lsn,
        stop_at_lsn: None,
        status_interval: config.keepalive_interval,
        idle_wakeup_interval: config.poll_timeout,
        buffer_events: 8192,
    }
}
402
#[cfg(test)]
mod tests {
    use super::*;

    // --- parse_replication_message: XLogData ---

    #[test]
    fn test_parse_xlog_data() {
        let mut msg = vec![b'w'];
        msg.extend_from_slice(&0x0000_0001_0000_0100_u64.to_be_bytes());
        msg.extend_from_slice(&0x0000_0001_0000_0200_u64.to_be_bytes());
        msg.extend_from_slice(&1_234_567_890_i64.to_be_bytes());
        msg.extend_from_slice(b"hello pgoutput");

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::XLogData {
                wal_start,
                wal_end,
                server_time_us,
                data,
            } => {
                assert_eq!(wal_start, Lsn::new(0x0000_0001_0000_0100));
                assert_eq!(wal_end, Lsn::new(0x0000_0001_0000_0200));
                assert_eq!(server_time_us, 1_234_567_890);
                assert_eq!(data, b"hello pgoutput");
            }
            ReplicationMessage::PrimaryKeepalive { .. } => panic!("expected XLogData"),
        }
    }

    #[test]
    fn test_parse_xlog_data_empty_payload() {
        let mut msg = vec![b'w'];
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::XLogData { data, .. } => {
                assert!(data.is_empty());
            }
            ReplicationMessage::PrimaryKeepalive { .. } => panic!("expected XLogData"),
        }
    }

    #[test]
    fn test_parse_xlog_data_negative_server_time() {
        // The timestamp field is signed and must round-trip negative values.
        let mut msg = vec![b'w'];
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&(-1_i64).to_be_bytes());

        let parsed = parse_replication_message(&msg).unwrap();
        assert!(matches!(
            parsed,
            ReplicationMessage::XLogData {
                server_time_us: -1,
                ..
            }
        ));
    }

    #[test]
    fn test_parse_truncated_xlog_data() {
        let msg = vec![b'w', 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    #[test]
    fn test_parse_truncated_xlog_data_one_byte_short() {
        // 24 bytes: the timestamp field is missing its last byte.
        let mut msg = vec![b'w'];
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&[0_u8; 7]);

        let text = parse_replication_message(&msg).unwrap_err().to_string();
        assert!(text.contains("truncated"));
        assert!(text.contains("24 bytes"));
    }

    // --- parse_replication_message: PrimaryKeepalive ---

    #[test]
    fn test_parse_keepalive_reply_requested() {
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0x0000_0002_0000_0500_u64.to_be_bytes());
        msg.extend_from_slice(&9_876_543_210_i64.to_be_bytes());
        msg.push(1);

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::PrimaryKeepalive {
                wal_end,
                server_time_us,
                reply_requested,
            } => {
                assert_eq!(wal_end, Lsn::new(0x0000_0002_0000_0500));
                assert_eq!(server_time_us, 9_876_543_210);
                assert!(reply_requested);
            }
            ReplicationMessage::XLogData { .. } => panic!("expected PrimaryKeepalive"),
        }
    }

    #[test]
    fn test_parse_keepalive_no_reply() {
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0x100_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());
        msg.push(0);

        let parsed = parse_replication_message(&msg).unwrap();
        match parsed {
            ReplicationMessage::PrimaryKeepalive {
                reply_requested, ..
            } => {
                assert!(!reply_requested);
            }
            ReplicationMessage::XLogData { .. } => panic!("expected PrimaryKeepalive"),
        }
    }

    #[test]
    fn test_parse_keepalive_any_nonzero_flag_requests_reply() {
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());
        msg.push(0x80);

        let parsed = parse_replication_message(&msg).unwrap();
        assert!(matches!(
            parsed,
            ReplicationMessage::PrimaryKeepalive {
                reply_requested: true,
                ..
            }
        ));
    }

    #[test]
    fn test_parse_truncated_keepalive() {
        let msg = vec![b'k', 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    #[test]
    fn test_parse_truncated_keepalive_missing_reply_flag() {
        // 17 bytes: everything except the trailing reply-requested flag.
        let mut msg = vec![b'k'];
        msg.extend_from_slice(&0_u64.to_be_bytes());
        msg.extend_from_slice(&0_i64.to_be_bytes());

        let err = parse_replication_message(&msg).unwrap_err();
        assert!(err.to_string().contains("truncated"));
    }

    // --- parse_replication_message: framing errors ---

    #[test]
    fn test_parse_empty_message() {
        let err = parse_replication_message(&[]).unwrap_err();
        assert!(err.to_string().contains("empty"));
    }

    #[test]
    fn test_parse_unknown_tag() {
        let err = parse_replication_message(&[0xFF]).unwrap_err();
        assert!(err.to_string().contains("unknown"));
        assert!(err.to_string().contains("0xFF"));
    }

    // --- encode_standby_status ---

    #[test]
    fn test_encode_standby_status_layout() {
        let write_lsn = Lsn::new(0x0000_0001_0000_0100);
        let flush_lsn = Lsn::new(0x0000_0001_0000_0080);
        let apply_lsn = Lsn::new(0x0000_0001_0000_0080);

        let buf = encode_standby_status(write_lsn, flush_lsn, apply_lsn);

        assert_eq!(buf.len(), 34, "standby status must be exactly 34 bytes");
        assert_eq!(buf[0], b'r', "tag must be 'r'");

        let w = u64::from_be_bytes(buf[1..9].try_into().unwrap());
        assert_eq!(w, 0x0000_0001_0000_0100);

        let f = u64::from_be_bytes(buf[9..17].try_into().unwrap());
        assert_eq!(f, 0x0000_0001_0000_0080);

        let a = u64::from_be_bytes(buf[17..25].try_into().unwrap());
        assert_eq!(a, 0x0000_0001_0000_0080);

        let ts = i64::from_be_bytes(buf[25..33].try_into().unwrap());
        assert_eq!(ts, 0);

        assert_eq!(buf[33], 0);
    }

    // --- build_start_replication_query / validate_pg_identifier ---

    #[test]
    fn test_build_start_replication_query() {
        let query =
            build_start_replication_query("my_slot", "0/1234ABCD".parse().unwrap(), "my_pub")
                .unwrap();
        assert!(query.contains("START_REPLICATION SLOT my_slot LOGICAL 0/1234ABCD"));
        assert!(query.contains("proto_version '1'"));
        assert!(query.contains("publication_names 'my_pub'"));
    }

    #[test]
    fn test_build_start_replication_query_exact_text() {
        let query =
            build_start_replication_query("my_slot", "0/1234ABCD".parse().unwrap(), "my_pub")
                .unwrap();
        assert_eq!(
            query,
            "START_REPLICATION SLOT my_slot LOGICAL 0/1234ABCD \
             (proto_version '1', publication_names 'my_pub')"
        );
    }

    #[test]
    fn test_build_start_replication_query_rejects_injection() {
        let result = build_start_replication_query(
            "slot'; DROP TABLE users; --",
            "0/0".parse().unwrap(),
            "pub",
        );
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("unsafe characters"));
    }

    #[test]
    fn test_build_start_replication_query_rejects_publication_injection() {
        let text = build_start_replication_query(
            "my_slot",
            "0/0".parse().unwrap(),
            "pub'); DROP TABLE users; --",
        )
        .unwrap_err()
        .to_string();
        assert!(text.contains("publication"));
        assert!(text.contains("unsafe characters"));
    }

    #[test]
    fn test_build_start_replication_query_rejects_empty() {
        let result = build_start_replication_query("", "0/0".parse().unwrap(), "pub");
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_pg_identifier_accepts_valid() {
        assert!(validate_pg_identifier("my_slot_123", "test").is_ok());
    }

    #[test]
    fn test_validate_pg_identifier_rejects_disallowed_bytes() {
        for value in ["my-slot", "my slot", "slot;", "sl\u{f6}t", "slot\n"] {
            assert!(
                validate_pg_identifier(value, "test").is_err(),
                "{value:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_validate_pg_identifier_empty_names_field() {
        let err = validate_pg_identifier("", "slot_name").unwrap_err();
        assert!(err.to_string().contains("slot_name must not be empty"));
    }

    #[cfg(feature = "postgres-cdc")]
    mod tls_mapping_tests {
        use super::super::build_replication_config;
        use crate::cdc::postgres::config::{PostgresCdcConfig, SslMode};

        #[test]
        fn test_disable_maps_to_disabled() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Disable;
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::Disable);
        }

        #[test]
        fn test_prefer_maps_to_require() {
            // Relies on the default ssl_mode being Prefer.
            let cfg = PostgresCdcConfig::default();
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::Require);
        }

        #[test]
        fn test_require_maps_to_require() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Require;
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::Require);
        }

        #[test]
        fn test_verify_ca_maps_with_ca_path() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::VerifyCa;
            cfg.ca_cert_path = Some("/certs/ca.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::VerifyCa);
            assert_eq!(
                repl.tls.ca_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/ca.pem"))
            );
        }

        #[test]
        fn test_verify_full_maps_with_ca_path() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::VerifyFull;
            cfg.ca_cert_path = Some("/certs/ca.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.mode, pgwire_replication::SslMode::VerifyFull);
            assert_eq!(
                repl.tls.ca_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/ca.pem"))
            );
        }

        #[test]
        fn test_sni_hostname_applied() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.sni_hostname = Some("db.example.com".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.tls.sni_hostname.as_deref(), Some("db.example.com"));
        }

        #[test]
        fn test_mtls_client_cert_applied() {
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Require;
            cfg.client_cert_path = Some("/certs/client.pem".to_string());
            cfg.client_key_path = Some("/certs/client-key.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(
                repl.tls.client_cert_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/client.pem"))
            );
            assert_eq!(
                repl.tls.client_key_pem_path.as_deref(),
                Some(std::path::Path::new("/certs/client-key.pem"))
            );
        }

        #[test]
        fn test_partial_client_cert_is_ignored() {
            // Only the certificate is configured, so no client cert is attached.
            let mut cfg = PostgresCdcConfig::default();
            cfg.ssl_mode = SslMode::Require;
            cfg.client_cert_path = Some("/certs/client.pem".to_string());
            let repl = build_replication_config(&cfg);
            assert!(repl.tls.client_cert_pem_path.is_none());
            assert!(repl.tls.client_key_pem_path.is_none());
        }

        #[test]
        fn test_no_client_cert_when_not_set() {
            let cfg = PostgresCdcConfig::default();
            let repl = build_replication_config(&cfg);
            assert!(repl.tls.client_cert_pem_path.is_none());
            assert!(repl.tls.client_key_pem_path.is_none());
        }

        #[test]
        fn test_connection_fields_mapped() {
            let mut cfg = PostgresCdcConfig::new("pg.example.com", "mydb", "my_slot", "my_pub");
            cfg.port = 5433;
            cfg.username = "replicator".to_string();
            cfg.password = Some("secret".to_string());
            let repl = build_replication_config(&cfg);
            assert_eq!(repl.host, "pg.example.com");
            assert_eq!(repl.port, 5433);
            assert_eq!(repl.user, "replicator");
            assert_eq!(repl.password, "secret");
            assert_eq!(repl.database, "mydb");
            assert_eq!(repl.slot, "my_slot");
            assert_eq!(repl.publication, "my_pub");
        }
    }
}