1use std::time::Duration;
9
10use serde::{Deserialize, Serialize};
11
12use super::backpressure::BackpressureStrategy;
13use crate::config::ConnectorConfig;
14use crate::error::ConnectorError;
15
16mod duration_millis {
22 use std::time::Duration;
23
24 use serde::{self, Deserialize, Deserializer, Serializer};
25
26 #[allow(clippy::cast_possible_truncation)]
27 pub fn serialize<S>(d: &Duration, serializer: S) -> Result<S::Ok, S::Error>
28 where
29 S: Serializer,
30 {
31 serializer.serialize_u64(d.as_millis() as u64)
32 }
33
34 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
35 where
36 D: Deserializer<'de>,
37 {
38 let millis = u64::deserialize(deserializer)?;
39 Ok(Duration::from_millis(millis))
40 }
41}
42
43fn default_backpressure() -> BackpressureStrategy {
49 BackpressureStrategy::Block
50}
51
/// Default value for `max_message_size`: 64 * 1024 * 1024 (67,108,864).
const fn default_max_message_size() -> usize {
    const MIB: usize = 1 << 20;
    64 * MIB
}
56
/// Default value for `ping_interval`: 30 seconds.
const fn default_ping_interval() -> Duration {
    Duration::from_millis(30_000)
}
61
/// Default value for `ping_timeout`: 10 seconds.
const fn default_ping_timeout() -> Duration {
    Duration::from_millis(10_000)
}
66
/// Default value for the server-mode `max_connections` setting: 1024.
const fn default_max_connections() -> usize {
    1 << 10
}
71
/// Default value for the reconnect `initial_delay`: 100 milliseconds.
const fn default_initial_delay() -> Duration {
    const INITIAL_DELAY_MS: u64 = 100;
    Duration::from_millis(INITIAL_DELAY_MS)
}
76
/// Default value for the reconnect `max_delay`: 30 seconds.
const fn default_max_delay() -> Duration {
    Duration::from_millis(30_000)
}
81
/// Default value for the reconnect `backoff_multiplier`: 2.0.
const fn default_backoff_multiplier() -> f64 {
    2.0_f64
}
86
/// Always returns `true`.
///
/// Exists because `#[serde(default = "...")]` takes a function path rather
/// than a literal value.
const fn default_true() -> bool {
    true
}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct WebSocketSourceConfig {
105 pub mode: SourceMode,
107
108 pub format: MessageFormat,
110
111 #[serde(default = "default_backpressure")]
113 pub on_backpressure: BackpressureStrategy,
114
115 pub event_time_field: Option<String>,
119
120 pub event_time_format: Option<EventTimeFormat>,
122
123 #[serde(default = "default_max_message_size")]
127 pub max_message_size: usize,
128
129 pub auth: Option<WsAuthConfig>,
131}
132
133impl Default for WebSocketSourceConfig {
134 fn default() -> Self {
135 Self {
136 mode: SourceMode::default(),
137 format: MessageFormat::default(),
138 on_backpressure: default_backpressure(),
139 event_time_field: None,
140 event_time_format: None,
141 max_message_size: default_max_message_size(),
142 auth: None,
143 }
144 }
145}
146
147impl WebSocketSourceConfig {
148 pub fn from_config(config: &ConnectorConfig) -> Result<Self, ConnectorError> {
158 let mode = Self::parse_mode(config)?;
159 let format = Self::parse_format(config)?;
160
161 let on_backpressure = match config.get("on.backpressure").map(str::to_lowercase) {
162 Some(ref s) if s == "block" => BackpressureStrategy::Block,
163 Some(ref s) if s == "drop" || s == "drop_newest" => BackpressureStrategy::DropNewest,
164 Some(ref other) => {
165 return Err(ConnectorError::ConfigurationError(format!(
166 "invalid backpressure strategy '{other}': expected 'block' or 'drop'"
167 )));
168 }
169 None => default_backpressure(),
170 };
171
172 let max_message_size: usize = config
173 .get_parsed("max.message.size")?
174 .unwrap_or(default_max_message_size());
175
176 let event_time_field = config.get("event.time.field").map(ToString::to_string);
177 let event_time_format = Self::parse_event_time_format(config);
178 let auth = Self::parse_auth(config)?;
179
180 Ok(Self {
181 mode,
182 format,
183 on_backpressure,
184 event_time_field,
185 event_time_format,
186 max_message_size,
187 auth,
188 })
189 }
190
191 fn parse_mode(config: &ConnectorConfig) -> Result<SourceMode, ConnectorError> {
193 let mode_str = config.get("mode").unwrap_or("client");
194 match mode_str.to_lowercase().as_str() {
195 "client" => {
196 let urls = if let Some(url) = config.get("url") {
197 url.split(',').map(|s| s.trim().to_string()).collect()
198 } else {
199 return Err(ConnectorError::ConfigurationError(
200 "WebSocket client mode requires 'url'. \
201 Set url='wss://...' in the WITH clause."
202 .into(),
203 ));
204 };
205
206 let subscribe_message = config.get("subscribe.message").map(ToString::to_string);
207
208 let reconnect_enabled: bool =
209 config.get_parsed("reconnect.enabled")?.unwrap_or(true);
210 let initial_delay_ms: u64 = config
211 .get_parsed("reconnect.initial.delay.ms")?
212 .unwrap_or(100);
213 let max_delay_ms: u64 = config
214 .get_parsed("reconnect.max.delay.ms")?
215 .unwrap_or(30_000);
216 let max_retries: Option<u32> = config.get_parsed("reconnect.max.retries")?;
217
218 let ping_interval_ms: u64 =
219 config.get_parsed("ping.interval.ms")?.unwrap_or(30_000);
220 let ping_timeout_ms: u64 = config.get_parsed("ping.timeout.ms")?.unwrap_or(10_000);
221
222 Ok(SourceMode::Client {
223 urls,
224 subscribe_message,
225 reconnect: ReconnectConfig {
226 enabled: reconnect_enabled,
227 initial_delay: Duration::from_millis(initial_delay_ms),
228 max_delay: Duration::from_millis(max_delay_ms),
229 backoff_multiplier: default_backoff_multiplier(),
230 max_retries,
231 jitter: true,
232 },
233 ping_interval: Duration::from_millis(ping_interval_ms),
234 ping_timeout: Duration::from_millis(ping_timeout_ms),
235 })
236 }
237 "server" => {
238 let bind_address = config.require("bind.address").map(ToString::to_string)?;
239 let max_connections: usize = config
240 .get_parsed("max.connections")?
241 .unwrap_or(default_max_connections());
242 let path = config.get("path").map(ToString::to_string);
243
244 Ok(SourceMode::Server {
245 bind_address,
246 max_connections,
247 path,
248 })
249 }
250 other => Err(ConnectorError::ConfigurationError(format!(
251 "invalid WebSocket mode '{other}': expected 'client' or 'server'"
252 ))),
253 }
254 }
255
256 fn parse_format(config: &ConnectorConfig) -> Result<MessageFormat, ConnectorError> {
258 match config.get("format").map(str::to_lowercase) {
259 Some(ref s) if s == "json" => Ok(MessageFormat::Json),
260 Some(ref s) if s == "jsonlines" || s == "json_lines" => Ok(MessageFormat::JsonLines),
261 Some(ref s) if s == "binary" => Ok(MessageFormat::Binary),
262 Some(ref s) if s == "csv" => Ok(MessageFormat::Csv {
263 delimiter: ',',
264 has_header: false,
265 }),
266 Some(ref other) => Err(ConnectorError::ConfigurationError(format!(
267 "invalid WebSocket format '{other}': expected json, jsonlines, binary, or csv"
268 ))),
269 None => Ok(MessageFormat::Json),
270 }
271 }
272
273 fn parse_event_time_format(config: &ConnectorConfig) -> Option<EventTimeFormat> {
275 match config.get("event.time.format").map(str::to_lowercase) {
276 Some(ref s) if s == "epoch_millis" => Some(EventTimeFormat::EpochMillis),
277 Some(ref s) if s == "epoch_micros" => Some(EventTimeFormat::EpochMicros),
278 Some(ref s) if s == "epoch_nanos" => Some(EventTimeFormat::EpochNanos),
279 Some(ref s) if s == "epoch_seconds" => Some(EventTimeFormat::EpochSeconds),
280 Some(ref s) if s == "iso8601" => Some(EventTimeFormat::Iso8601),
281 Some(other) => Some(EventTimeFormat::Custom(other.clone())),
282 None => None,
283 }
284 }
285
286 fn parse_auth(config: &ConnectorConfig) -> Result<Option<WsAuthConfig>, ConnectorError> {
288 match config.get("auth.type").map(str::to_lowercase) {
289 Some(ref s) if s == "bearer" => {
290 let token = config.require("auth.token").map(ToString::to_string)?;
291 Ok(Some(WsAuthConfig::Bearer { token }))
292 }
293 Some(ref s) if s == "basic" => {
294 let username = config.require("auth.username").map(ToString::to_string)?;
295 let password = config.require("auth.password").map(ToString::to_string)?;
296 Ok(Some(WsAuthConfig::Basic { username, password }))
297 }
298 Some(ref s) if s == "hmac" => {
299 let api_key = config.require("auth.api.key").map(ToString::to_string)?;
300 let secret = config.require("auth.secret").map(ToString::to_string)?;
301 Ok(Some(WsAuthConfig::Hmac { api_key, secret }))
302 }
303 Some(ref other) => Err(ConnectorError::ConfigurationError(format!(
304 "unsupported auth type '{other}': expected bearer, basic, or hmac"
305 ))),
306 None => Ok(None),
307 }
308 }
309}
310
311#[derive(Debug, Clone, Serialize, Deserialize)]
317#[serde(tag = "type")]
318pub enum SourceMode {
319 Client {
321 urls: Vec<String>,
323
324 subscribe_message: Option<String>,
327
328 #[serde(default)]
330 reconnect: ReconnectConfig,
331
332 #[serde(default = "default_ping_interval", with = "duration_millis")]
334 ping_interval: Duration,
335
336 #[serde(default = "default_ping_timeout", with = "duration_millis")]
338 ping_timeout: Duration,
339 },
340
341 Server {
343 bind_address: String,
345
346 #[serde(default = "default_max_connections")]
348 max_connections: usize,
349
350 path: Option<String>,
354 },
355}
356
357impl Default for SourceMode {
358 fn default() -> Self {
359 Self::Client {
360 urls: vec![String::new()],
361 subscribe_message: None,
362 reconnect: ReconnectConfig::default(),
363 ping_interval: default_ping_interval(),
364 ping_timeout: default_ping_timeout(),
365 }
366 }
367}
368
369#[derive(Debug, Clone, Default, Serialize, Deserialize)]
375pub enum MessageFormat {
376 #[default]
378 Json,
379
380 JsonLines,
382
383 Binary,
385
386 Csv {
388 delimiter: char,
390 has_header: bool,
392 },
393}
394
395#[derive(Debug, Clone, Serialize, Deserialize)]
401pub enum EventTimeFormat {
402 EpochMillis,
404
405 EpochMicros,
407
408 EpochNanos,
410
411 EpochSeconds,
413
414 Iso8601,
416
417 Custom(String),
419}
420
421#[derive(Debug, Clone, Serialize, Deserialize)]
431pub struct ReconnectConfig {
432 pub enabled: bool,
434
435 #[serde(default = "default_initial_delay", with = "duration_millis")]
437 pub initial_delay: Duration,
438
439 #[serde(default = "default_max_delay", with = "duration_millis")]
441 pub max_delay: Duration,
442
443 #[serde(default = "default_backoff_multiplier")]
445 pub backoff_multiplier: f64,
446
447 pub max_retries: Option<u32>,
451
452 #[serde(default = "default_true")]
454 pub jitter: bool,
455}
456
457impl Default for ReconnectConfig {
458 fn default() -> Self {
459 Self {
460 enabled: true,
461 initial_delay: default_initial_delay(),
462 max_delay: default_max_delay(),
463 backoff_multiplier: default_backoff_multiplier(),
464 max_retries: None,
465 jitter: true,
466 }
467 }
468}
469
470#[derive(Debug, Clone, Serialize, Deserialize)]
479#[serde(tag = "type")]
480pub enum WsAuthConfig {
481 Bearer {
483 token: String,
485 },
486
487 Basic {
489 username: String,
491 password: String,
493 },
494
495 Headers {
497 headers: Vec<(String, String)>,
499 },
500
501 QueryParam {
503 key: String,
505 value: String,
507 },
508
509 Hmac {
511 api_key: String,
513 secret: String,
515 },
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Subscribe payload shared by the client-mode round-trip test.
    const SUBSCRIBE_TRADES: &str = r#"{"op":"subscribe","channel":"trades"}"#;

    /// Serializes `value` to compact JSON and parses it straight back.
    fn json_round_trip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let encoded = serde_json::to_string(value).expect("serialize");
        serde_json::from_str(&encoded).expect("deserialize")
    }

    // ---- Default values -------------------------------------------------

    /// The top-level default is a JSON client with no event time and no auth.
    #[test]
    fn test_default_websocket_source_config() {
        let WebSocketSourceConfig {
            mode,
            format,
            on_backpressure,
            event_time_field,
            event_time_format,
            max_message_size,
            auth,
        } = WebSocketSourceConfig::default();

        assert!(matches!(mode, SourceMode::Client { .. }));
        assert!(matches!(format, MessageFormat::Json));
        assert!(matches!(on_backpressure, BackpressureStrategy::Block));
        assert_eq!(max_message_size, 64 * 1024 * 1024);
        assert!(event_time_field.is_none());
        assert!(event_time_format.is_none());
        assert!(auth.is_none());
    }

    /// The default mode is a client with one empty placeholder URL.
    #[test]
    fn test_default_source_mode() {
        if let SourceMode::Client {
            urls,
            subscribe_message,
            reconnect,
            ping_interval,
            ping_timeout,
        } = SourceMode::default()
        {
            assert_eq!(urls.len(), 1);
            assert_eq!(urls[0], "");
            assert!(subscribe_message.is_none());
            assert!(reconnect.enabled);
            assert_eq!(ping_interval, Duration::from_secs(30));
            assert_eq!(ping_timeout, Duration::from_secs(10));
        } else {
            panic!("expected Client mode");
        }
    }

    /// JSON is the default message format.
    #[test]
    fn test_default_message_format() {
        assert!(matches!(MessageFormat::default(), MessageFormat::Json));
    }

    /// Reconnect defaults: enabled, 100 ms to 30 s, x2.0, unlimited, jitter.
    #[test]
    fn test_default_reconnect_config() {
        let ReconnectConfig {
            enabled,
            initial_delay,
            max_delay,
            backoff_multiplier,
            max_retries,
            jitter,
        } = ReconnectConfig::default();

        assert!(enabled);
        assert_eq!(initial_delay, Duration::from_millis(100));
        assert_eq!(max_delay, Duration::from_secs(30));
        assert!((backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert!(max_retries.is_none());
        assert!(jitter);
    }

    // ---- Serde round trips ----------------------------------------------

    /// A fully populated client config survives a pretty-JSON round trip.
    #[test]
    fn test_serde_round_trip_client_mode() {
        let original = WebSocketSourceConfig {
            mode: SourceMode::Client {
                urls: vec![String::from("wss://feed.example.com/v1")],
                subscribe_message: Some(SUBSCRIBE_TRADES.to_string()),
                reconnect: ReconnectConfig::default(),
                ping_interval: Duration::from_secs(15),
                ping_timeout: Duration::from_secs(5),
            },
            format: MessageFormat::Json,
            on_backpressure: BackpressureStrategy::Block,
            event_time_field: Some(String::from("timestamp")),
            event_time_format: Some(EventTimeFormat::EpochMillis),
            max_message_size: 1024 * 1024,
            auth: Some(WsAuthConfig::Bearer {
                token: String::from("tok_abc123"),
            }),
        };

        let pretty = serde_json::to_string_pretty(&original).expect("serialize");
        let restored: WebSocketSourceConfig = serde_json::from_str(&pretty).expect("deserialize");

        let WebSocketSourceConfig {
            mode,
            event_time_field,
            event_time_format,
            max_message_size,
            auth,
            ..
        } = restored;

        if let SourceMode::Client {
            urls,
            subscribe_message,
            ping_interval,
            ping_timeout,
            ..
        } = mode
        {
            assert_eq!(urls, ["wss://feed.example.com/v1"]);
            assert_eq!(subscribe_message.as_deref(), Some(SUBSCRIBE_TRADES));
            assert_eq!(ping_interval, Duration::from_secs(15));
            assert_eq!(ping_timeout, Duration::from_secs(5));
        } else {
            panic!("expected Client");
        }

        assert_eq!(event_time_field.as_deref(), Some("timestamp"));
        assert!(matches!(
            event_time_format,
            Some(EventTimeFormat::EpochMillis)
        ));
        assert_eq!(max_message_size, 1024 * 1024);
        assert!(matches!(
            auth,
            Some(WsAuthConfig::Bearer { ref token }) if token == "tok_abc123"
        ));
    }

    /// A server config survives a JSON round trip.
    #[test]
    fn test_serde_round_trip_server_mode() {
        let original = WebSocketSourceConfig {
            mode: SourceMode::Server {
                bind_address: String::from("0.0.0.0:9443"),
                max_connections: 512,
                path: Some(String::from("/ingest")),
            },
            format: MessageFormat::JsonLines,
            on_backpressure: BackpressureStrategy::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: default_max_message_size(),
            auth: None,
        };

        let restored = json_round_trip(&original);

        if let SourceMode::Server {
            bind_address,
            max_connections,
            path,
        } = restored.mode
        {
            assert_eq!(bind_address, "0.0.0.0:9443");
            assert_eq!(max_connections, 512);
            assert_eq!(path.as_deref(), Some("/ingest"));
        } else {
            panic!("expected Server");
        }
    }

    /// Non-default reconnect settings survive a JSON round trip.
    #[test]
    fn test_serde_round_trip_reconnect_config() {
        let original = ReconnectConfig {
            enabled: false,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(60),
            backoff_multiplier: 1.5,
            max_retries: Some(10),
            jitter: false,
        };

        let restored = json_round_trip(&original);

        assert!(!restored.enabled);
        assert_eq!(restored.initial_delay, Duration::from_millis(500));
        assert_eq!(restored.max_delay, Duration::from_secs(60));
        assert!((restored.backoff_multiplier - 1.5).abs() < f64::EPSILON);
        assert_eq!(restored.max_retries, Some(10));
        assert!(!restored.jitter);
    }

    /// CSV options survive a JSON round trip.
    #[test]
    fn test_serde_round_trip_csv_format() {
        let original = WebSocketSourceConfig {
            format: MessageFormat::Csv {
                delimiter: '|',
                has_header: true,
            },
            ..WebSocketSourceConfig::default()
        };

        let restored = json_round_trip(&original);

        if let MessageFormat::Csv {
            delimiter,
            has_header,
        } = restored.format
        {
            assert_eq!(delimiter, '|');
            assert!(has_header);
        } else {
            panic!("expected Csv format");
        }
    }

    /// Basic, Headers, QueryParam and Hmac auth survive a JSON round trip.
    #[test]
    fn test_serde_round_trip_auth_variants() {
        let basic = json_round_trip(&WsAuthConfig::Basic {
            username: String::from("user"),
            password: String::from("pass"),
        });
        assert!(matches!(
            basic,
            WsAuthConfig::Basic { ref username, ref password }
                if username == "user" && password == "pass"
        ));

        let headers = json_round_trip(&WsAuthConfig::Headers {
            headers: vec![(String::from("X-Api-Key"), String::from("key123"))],
        });
        assert!(matches!(headers, WsAuthConfig::Headers { ref headers } if headers.len() == 1));

        let query_param = json_round_trip(&WsAuthConfig::QueryParam {
            key: String::from("token"),
            value: String::from("abc"),
        });
        assert!(matches!(
            query_param,
            WsAuthConfig::QueryParam { ref key, ref value }
                if key == "token" && value == "abc"
        ));

        let hmac = json_round_trip(&WsAuthConfig::Hmac {
            api_key: String::from("ak"),
            secret: String::from("sk"),
        });
        assert!(matches!(
            hmac,
            WsAuthConfig::Hmac { ref api_key, ref secret }
                if api_key == "ak" && secret == "sk"
        ));
    }

    /// Every event-time format keeps its variant (and payload) after a
    /// JSON round trip.
    #[test]
    fn test_serde_round_trip_event_time_formats() {
        let formats = [
            EventTimeFormat::EpochMillis,
            EventTimeFormat::EpochMicros,
            EventTimeFormat::EpochNanos,
            EventTimeFormat::EpochSeconds,
            EventTimeFormat::Iso8601,
            EventTimeFormat::Custom(String::from("%Y-%m-%dT%H:%M:%S")),
        ];

        for expected in formats {
            let actual = json_round_trip(&expected);

            assert!(
                std::mem::discriminant(&expected) == std::mem::discriminant(&actual),
                "event time format mismatch after round-trip"
            );
            if let (EventTimeFormat::Custom(want), EventTimeFormat::Custom(got)) =
                (&expected, &actual)
            {
                assert_eq!(want, got);
            }
        }
    }

    /// Fields omitted from the JSON take their serde defaults.
    #[test]
    fn test_serde_defaults_applied() {
        let minimal = r#"{
            "mode": {
                "type": "Server",
                "bind_address": "127.0.0.1:8080"
            },
            "format": "Json"
        }"#;

        let parsed: WebSocketSourceConfig =
            serde_json::from_str(minimal).expect("deserialize with defaults");

        assert!(matches!(parsed.on_backpressure, BackpressureStrategy::Block));
        assert_eq!(parsed.max_message_size, 64 * 1024 * 1024);
        assert!(parsed.event_time_field.is_none());
        assert!(parsed.auth.is_none());

        if let SourceMode::Server {
            max_connections, ..
        } = parsed.mode
        {
            assert_eq!(max_connections, 1024);
        } else {
            panic!("expected Server");
        }
    }

    /// Pins the literal values returned by the default helper functions.
    #[test]
    fn test_default_helper_values() {
        assert_eq!(default_max_message_size(), 64 * 1024 * 1024);
        assert_eq!(default_max_connections(), 1024);

        assert_eq!(default_ping_interval(), Duration::from_secs(30));
        assert_eq!(default_ping_timeout(), Duration::from_secs(10));
        assert_eq!(default_initial_delay(), Duration::from_millis(100));
        assert_eq!(default_max_delay(), Duration::from_secs(30));

        assert!((default_backoff_multiplier() - 2.0).abs() < f64::EPSILON);
        assert!(default_true());
    }

    // ---- from_config ------------------------------------------------------

    /// Client options are picked up from flat connector keys.
    #[test]
    fn test_from_config_client_mode() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("url", "wss://feed.example.com/v1");
        config.set("format", "json");
        config.set("subscribe.message", r#"{"op":"subscribe"}"#);
        config.set("reconnect.enabled", "true");
        config.set("reconnect.initial.delay.ms", "200");
        config.set("reconnect.max.delay.ms", "60000");
        config.set("ping.interval.ms", "15000");
        config.set("ping.timeout.ms", "5000");

        let parsed = WebSocketSourceConfig::from_config(&config).unwrap();

        if let SourceMode::Client {
            urls,
            subscribe_message,
            reconnect,
            ping_interval,
            ping_timeout,
        } = &parsed.mode
        {
            assert_eq!(urls, &["wss://feed.example.com/v1"]);
            assert_eq!(subscribe_message.as_deref(), Some(r#"{"op":"subscribe"}"#));
            assert!(reconnect.enabled);
            assert_eq!(reconnect.initial_delay, Duration::from_millis(200));
            assert_eq!(reconnect.max_delay, Duration::from_millis(60_000));
            assert_eq!(*ping_interval, Duration::from_millis(15_000));
            assert_eq!(*ping_timeout, Duration::from_millis(5_000));
        } else {
            panic!("expected Client mode");
        }
        assert!(matches!(parsed.format, MessageFormat::Json));
    }

    /// Server options are picked up from flat connector keys.
    #[test]
    fn test_from_config_server_mode() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("mode", "server");
        config.set("bind.address", "0.0.0.0:9443");
        config.set("max.connections", "512");
        config.set("path", "/ingest");

        let parsed = WebSocketSourceConfig::from_config(&config).unwrap();

        if let SourceMode::Server {
            bind_address,
            max_connections,
            path,
        } = &parsed.mode
        {
            assert_eq!(bind_address, "0.0.0.0:9443");
            assert_eq!(*max_connections, 512);
            assert_eq!(path.as_deref(), Some("/ingest"));
        } else {
            panic!("expected Server mode");
        }
    }

    /// Client mode without a `url` is rejected with a message naming it.
    #[test]
    fn test_from_config_missing_url_errors() {
        let config = ConnectorConfig::new("websocket");

        let outcome = WebSocketSourceConfig::from_config(&config);

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("url"));
    }

    /// `auth.type = bearer` yields a bearer config carrying the token.
    #[test]
    fn test_from_config_bearer_auth() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("url", "wss://api.example.com");
        config.set("auth.type", "bearer");
        config.set("auth.token", "tok_abc123");

        let parsed = WebSocketSourceConfig::from_config(&config).unwrap();

        assert!(matches!(
            parsed.auth,
            Some(WsAuthConfig::Bearer { ref token }) if token == "tok_abc123"
        ));
    }

    /// A comma-separated `url` is split into trimmed entries.
    #[test]
    fn test_from_config_multiple_urls() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("url", "wss://a.example.com, wss://b.example.com");

        let parsed = WebSocketSourceConfig::from_config(&config).unwrap();

        if let SourceMode::Client { urls, .. } = &parsed.mode {
            assert_eq!(urls.len(), 2);
            assert_eq!(urls[0], "wss://a.example.com");
            assert_eq!(urls[1], "wss://b.example.com");
        } else {
            panic!("expected Client mode");
        }
    }

    /// With only a `url` set, every other option takes its default.
    #[test]
    fn test_from_config_defaults() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("url", "wss://feed.example.com");

        let parsed = WebSocketSourceConfig::from_config(&config).unwrap();

        assert!(matches!(parsed.format, MessageFormat::Json));
        assert!(matches!(parsed.on_backpressure, BackpressureStrategy::Block));
        assert_eq!(parsed.max_message_size, 64 * 1024 * 1024);
        assert!(parsed.event_time_field.is_none());
        assert!(parsed.auth.is_none());
    }
}