Skip to main content

laminar_connectors/websocket/
source_config.rs

1//! WebSocket source connector configuration.
2//!
3//! Provides [`WebSocketSourceConfig`] for configuring a WebSocket source
4//! connector in either client mode (connecting to an upstream server) or
5//! server mode (accepting incoming connections). Includes reconnection,
6//! authentication, message format, and event-time extraction options.
7
8use std::time::Duration;
9
10use serde::{Deserialize, Serialize};
11
12use super::backpressure::BackpressureStrategy;
13use crate::config::ConnectorConfig;
14use crate::error::ConnectorError;
15
16// ---------------------------------------------------------------------------
17// Serde helper: Duration as milliseconds
18// ---------------------------------------------------------------------------
19
20/// Serde helper that encodes a [`Duration`] as a `u64` millisecond count.
21mod duration_millis {
22    use std::time::Duration;
23
24    use serde::{self, Deserialize, Deserializer, Serializer};
25
26    #[allow(clippy::cast_possible_truncation)]
27    pub fn serialize<S>(d: &Duration, serializer: S) -> Result<S::Ok, S::Error>
28    where
29        S: Serializer,
30    {
31        serializer.serialize_u64(d.as_millis() as u64)
32    }
33
34    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
35    where
36        D: Deserializer<'de>,
37    {
38        let millis = u64::deserialize(deserializer)?;
39        Ok(Duration::from_millis(millis))
40    }
41}
42
43// ---------------------------------------------------------------------------
44// Default helpers
45// ---------------------------------------------------------------------------
46
47/// Default backpressure strategy: block the WebSocket read loop.
48fn default_backpressure() -> BackpressureStrategy {
49    BackpressureStrategy::Block
50}
51
/// Fallback cap on a single WebSocket message, in bytes (64 MiB = 64 << 20).
const fn default_max_message_size() -> usize {
    64 << 20
}
56
/// Fallback period between client-mode ping frames (30 000 ms).
const fn default_ping_interval() -> Duration {
    Duration::from_millis(30_000)
}
61
/// Fallback wait for a pong reply in client mode (10 000 ms).
const fn default_ping_timeout() -> Duration {
    Duration::from_millis(10_000)
}
66
/// Fallback limit on simultaneous server-mode connections (1024 = 2^10).
const fn default_max_connections() -> usize {
    1 << 10
}
71
/// Fallback delay before the first reconnection attempt (100 ms).
const fn default_initial_delay() -> Duration {
    Duration::from_micros(100_000)
}
76
/// Fallback ceiling on the delay between reconnection attempts (30 000 ms).
const fn default_max_delay() -> Duration {
    Duration::from_millis(30_000)
}
81
/// Fallback growth factor for reconnect delays: each failed attempt doubles it.
const fn default_backoff_multiplier() -> f64 {
    2_f64
}
86
/// Constant `true`, referenced by `#[serde(default = "default_true")]` on
/// boolean fields that should be enabled when omitted.
const fn default_true() -> bool {
    true
}
91
92// ---------------------------------------------------------------------------
93// Top-level config
94// ---------------------------------------------------------------------------
95
96/// WebSocket source connector configuration.
97///
98/// Supports two operating modes:
99/// - **Client**: connects to one or more upstream WebSocket servers and
100///   optionally sends a subscribe message after the handshake.
101/// - **Server**: binds a local address and accepts incoming WebSocket
102///   connections (e.g., from `IoT` devices or browser clients).
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct WebSocketSourceConfig {
105    /// Operating mode (client or server).
106    pub mode: SourceMode,
107
108    /// Message format used for deserialization.
109    pub format: MessageFormat,
110
111    /// Backpressure strategy when the Ring 0 channel is full.
112    #[serde(default = "default_backpressure")]
113    pub on_backpressure: BackpressureStrategy,
114
115    /// JSON field path used to extract event time from each message.
116    ///
117    /// When `None`, processing time is used as the event timestamp.
118    pub event_time_field: Option<String>,
119
120    /// Format of the event time value extracted from `event_time_field`.
121    pub event_time_format: Option<EventTimeFormat>,
122
123    /// Maximum accepted WebSocket message size in bytes.
124    ///
125    /// Messages exceeding this limit are rejected. Defaults to 64 MiB.
126    #[serde(default = "default_max_message_size")]
127    pub max_message_size: usize,
128
129    /// Optional authentication configuration for the WebSocket connection.
130    pub auth: Option<WsAuthConfig>,
131}
132
133impl Default for WebSocketSourceConfig {
134    fn default() -> Self {
135        Self {
136            mode: SourceMode::default(),
137            format: MessageFormat::default(),
138            on_backpressure: default_backpressure(),
139            event_time_field: None,
140            event_time_format: None,
141            max_message_size: default_max_message_size(),
142            auth: None,
143        }
144    }
145}
146
147impl WebSocketSourceConfig {
148    /// Builds a [`WebSocketSourceConfig`] from a flat [`ConnectorConfig`] property map.
149    ///
150    /// Maps well-known keys from `WITH (...)` clauses to the structured config.
151    /// Unknown keys are silently ignored (forward compatibility).
152    ///
153    /// # Errors
154    ///
155    /// Returns `ConnectorError::ConfigurationError` if a required key is missing
156    /// or a value cannot be parsed.
157    pub fn from_config(config: &ConnectorConfig) -> Result<Self, ConnectorError> {
158        let mode = Self::parse_mode(config)?;
159        let format = Self::parse_format(config)?;
160
161        let on_backpressure = match config.get("on.backpressure").map(str::to_lowercase) {
162            Some(ref s) if s == "block" => BackpressureStrategy::Block,
163            Some(ref s) if s == "drop" || s == "drop_newest" => BackpressureStrategy::DropNewest,
164            Some(ref other) => {
165                return Err(ConnectorError::ConfigurationError(format!(
166                    "invalid backpressure strategy '{other}': expected 'block' or 'drop'"
167                )));
168            }
169            None => default_backpressure(),
170        };
171
172        let max_message_size: usize = config
173            .get_parsed("max.message.size")?
174            .unwrap_or(default_max_message_size());
175
176        let event_time_field = config.get("event.time.field").map(ToString::to_string);
177        let event_time_format = Self::parse_event_time_format(config);
178        let auth = Self::parse_auth(config)?;
179
180        Ok(Self {
181            mode,
182            format,
183            on_backpressure,
184            event_time_field,
185            event_time_format,
186            max_message_size,
187            auth,
188        })
189    }
190
191    /// Parses the `mode` property into a [`SourceMode`].
192    fn parse_mode(config: &ConnectorConfig) -> Result<SourceMode, ConnectorError> {
193        let mode_str = config.get("mode").unwrap_or("client");
194        match mode_str.to_lowercase().as_str() {
195            "client" => {
196                let urls = if let Some(url) = config.get("url") {
197                    url.split(',').map(|s| s.trim().to_string()).collect()
198                } else {
199                    return Err(ConnectorError::ConfigurationError(
200                        "WebSocket client mode requires 'url'. \
201                         Set url='wss://...' in the WITH clause."
202                            .into(),
203                    ));
204                };
205
206                let subscribe_message = config.get("subscribe.message").map(ToString::to_string);
207
208                let reconnect_enabled: bool =
209                    config.get_parsed("reconnect.enabled")?.unwrap_or(true);
210                let initial_delay_ms: u64 = config
211                    .get_parsed("reconnect.initial.delay.ms")?
212                    .unwrap_or(100);
213                let max_delay_ms: u64 = config
214                    .get_parsed("reconnect.max.delay.ms")?
215                    .unwrap_or(30_000);
216                let max_retries: Option<u32> = config.get_parsed("reconnect.max.retries")?;
217
218                let ping_interval_ms: u64 =
219                    config.get_parsed("ping.interval.ms")?.unwrap_or(30_000);
220                let ping_timeout_ms: u64 = config.get_parsed("ping.timeout.ms")?.unwrap_or(10_000);
221
222                Ok(SourceMode::Client {
223                    urls,
224                    subscribe_message,
225                    reconnect: ReconnectConfig {
226                        enabled: reconnect_enabled,
227                        initial_delay: Duration::from_millis(initial_delay_ms),
228                        max_delay: Duration::from_millis(max_delay_ms),
229                        backoff_multiplier: default_backoff_multiplier(),
230                        max_retries,
231                        jitter: true,
232                    },
233                    ping_interval: Duration::from_millis(ping_interval_ms),
234                    ping_timeout: Duration::from_millis(ping_timeout_ms),
235                })
236            }
237            "server" => {
238                let bind_address = config.require("bind.address").map(ToString::to_string)?;
239                let max_connections: usize = config
240                    .get_parsed("max.connections")?
241                    .unwrap_or(default_max_connections());
242                let path = config.get("path").map(ToString::to_string);
243
244                Ok(SourceMode::Server {
245                    bind_address,
246                    max_connections,
247                    path,
248                })
249            }
250            other => Err(ConnectorError::ConfigurationError(format!(
251                "invalid WebSocket mode '{other}': expected 'client' or 'server'"
252            ))),
253        }
254    }
255
256    /// Parses the `format` property into a [`MessageFormat`].
257    fn parse_format(config: &ConnectorConfig) -> Result<MessageFormat, ConnectorError> {
258        match config.get("format").map(str::to_lowercase) {
259            Some(ref s) if s == "json" => Ok(MessageFormat::Json),
260            Some(ref s) if s == "jsonlines" || s == "json_lines" => Ok(MessageFormat::JsonLines),
261            Some(ref s) if s == "binary" => Ok(MessageFormat::Binary),
262            Some(ref s) if s == "csv" => Ok(MessageFormat::Csv {
263                delimiter: ',',
264                has_header: false,
265            }),
266            Some(ref other) => Err(ConnectorError::ConfigurationError(format!(
267                "invalid WebSocket format '{other}': expected json, jsonlines, binary, or csv"
268            ))),
269            None => Ok(MessageFormat::Json),
270        }
271    }
272
273    /// Parses the `event.time.format` property into an [`EventTimeFormat`].
274    fn parse_event_time_format(config: &ConnectorConfig) -> Option<EventTimeFormat> {
275        match config.get("event.time.format").map(str::to_lowercase) {
276            Some(ref s) if s == "epoch_millis" => Some(EventTimeFormat::EpochMillis),
277            Some(ref s) if s == "epoch_micros" => Some(EventTimeFormat::EpochMicros),
278            Some(ref s) if s == "epoch_nanos" => Some(EventTimeFormat::EpochNanos),
279            Some(ref s) if s == "epoch_seconds" => Some(EventTimeFormat::EpochSeconds),
280            Some(ref s) if s == "iso8601" => Some(EventTimeFormat::Iso8601),
281            Some(other) => Some(EventTimeFormat::Custom(other.clone())),
282            None => None,
283        }
284    }
285
286    /// Parses the `auth.*` properties into an optional [`WsAuthConfig`].
287    fn parse_auth(config: &ConnectorConfig) -> Result<Option<WsAuthConfig>, ConnectorError> {
288        match config.get("auth.type").map(str::to_lowercase) {
289            Some(ref s) if s == "bearer" => {
290                let token = config.require("auth.token").map(ToString::to_string)?;
291                Ok(Some(WsAuthConfig::Bearer { token }))
292            }
293            Some(ref s) if s == "basic" => {
294                let username = config.require("auth.username").map(ToString::to_string)?;
295                let password = config.require("auth.password").map(ToString::to_string)?;
296                Ok(Some(WsAuthConfig::Basic { username, password }))
297            }
298            Some(ref s) if s == "hmac" => {
299                let api_key = config.require("auth.api.key").map(ToString::to_string)?;
300                let secret = config.require("auth.secret").map(ToString::to_string)?;
301                Ok(Some(WsAuthConfig::Hmac { api_key, secret }))
302            }
303            Some(ref other) => Err(ConnectorError::ConfigurationError(format!(
304                "unsupported auth type '{other}': expected bearer, basic, or hmac"
305            ))),
306            None => Ok(None),
307        }
308    }
309}
310
311// ---------------------------------------------------------------------------
312// SourceMode
313// ---------------------------------------------------------------------------
314
315/// Operating mode for the WebSocket source connector.
316#[derive(Debug, Clone, Serialize, Deserialize)]
317#[serde(tag = "type")]
318pub enum SourceMode {
319    /// Client mode: connect to upstream WebSocket server(s).
320    Client {
321        /// One or more WebSocket URLs to connect to (e.g., `wss://feed.example.com/v1`).
322        urls: Vec<String>,
323
324        /// Optional message to send after the WebSocket handshake completes
325        /// (e.g., a JSON subscribe payload).
326        subscribe_message: Option<String>,
327
328        /// Reconnection policy applied when the connection drops.
329        #[serde(default)]
330        reconnect: ReconnectConfig,
331
332        /// Interval between WebSocket ping frames.
333        #[serde(default = "default_ping_interval", with = "duration_millis")]
334        ping_interval: Duration,
335
336        /// Time to wait for a pong reply before considering the connection dead.
337        #[serde(default = "default_ping_timeout", with = "duration_millis")]
338        ping_timeout: Duration,
339    },
340
341    /// Server mode: listen for incoming WebSocket connections.
342    Server {
343        /// Socket address to bind (e.g., `0.0.0.0:9443`).
344        bind_address: String,
345
346        /// Maximum number of concurrent WebSocket connections.
347        #[serde(default = "default_max_connections")]
348        max_connections: usize,
349
350        /// Optional URL path to accept connections on (e.g., `/ingest`).
351        ///
352        /// When `None`, connections are accepted on any path.
353        path: Option<String>,
354    },
355}
356
357impl Default for SourceMode {
358    fn default() -> Self {
359        Self::Client {
360            urls: vec![String::new()],
361            subscribe_message: None,
362            reconnect: ReconnectConfig::default(),
363            ping_interval: default_ping_interval(),
364            ping_timeout: default_ping_timeout(),
365        }
366    }
367}
368
369// ---------------------------------------------------------------------------
370// MessageFormat
371// ---------------------------------------------------------------------------
372
373/// Deserialization format for incoming WebSocket messages.
374#[derive(Debug, Clone, Default, Serialize, Deserialize)]
375pub enum MessageFormat {
376    /// Each message is a single JSON object.
377    #[default]
378    Json,
379
380    /// Each message contains one or more newline-delimited JSON objects.
381    JsonLines,
382
383    /// Raw binary payload (passed through as-is).
384    Binary,
385
386    /// CSV-formatted payload.
387    Csv {
388        /// Field delimiter character (defaults to `,`).
389        delimiter: char,
390        /// Whether the first row is a header row.
391        has_header: bool,
392    },
393}
394
395// ---------------------------------------------------------------------------
396// EventTimeFormat
397// ---------------------------------------------------------------------------
398
399/// Format of the event timestamp extracted from messages.
400#[derive(Debug, Clone, Serialize, Deserialize)]
401pub enum EventTimeFormat {
402    /// Milliseconds since the Unix epoch.
403    EpochMillis,
404
405    /// Microseconds since the Unix epoch.
406    EpochMicros,
407
408    /// Nanoseconds since the Unix epoch.
409    EpochNanos,
410
411    /// Seconds since the Unix epoch (integer or floating-point).
412    EpochSeconds,
413
414    /// ISO 8601 datetime string (e.g., `2026-02-21T12:00:00Z`).
415    Iso8601,
416
417    /// Custom `strftime`-compatible format string.
418    Custom(String),
419}
420
421// ---------------------------------------------------------------------------
422// ReconnectConfig
423// ---------------------------------------------------------------------------
424
425/// Exponential-backoff reconnection policy for client mode.
426///
427/// When the WebSocket connection is lost, the connector will attempt to
428/// reconnect with exponentially increasing delays between attempts,
429/// optionally capped at `max_retries`.
430#[derive(Debug, Clone, Serialize, Deserialize)]
431pub struct ReconnectConfig {
432    /// Whether automatic reconnection is enabled.
433    pub enabled: bool,
434
435    /// Initial delay before the first reconnection attempt.
436    #[serde(default = "default_initial_delay", with = "duration_millis")]
437    pub initial_delay: Duration,
438
439    /// Maximum delay between reconnection attempts.
440    #[serde(default = "default_max_delay", with = "duration_millis")]
441    pub max_delay: Duration,
442
443    /// Multiplier applied to the delay after each failed attempt.
444    #[serde(default = "default_backoff_multiplier")]
445    pub backoff_multiplier: f64,
446
447    /// Optional upper bound on reconnection attempts.
448    ///
449    /// `None` means retry indefinitely.
450    pub max_retries: Option<u32>,
451
452    /// Whether to apply random jitter to backoff delays to avoid thundering-herd.
453    #[serde(default = "default_true")]
454    pub jitter: bool,
455}
456
457impl Default for ReconnectConfig {
458    fn default() -> Self {
459        Self {
460            enabled: true,
461            initial_delay: default_initial_delay(),
462            max_delay: default_max_delay(),
463            backoff_multiplier: default_backoff_multiplier(),
464            max_retries: None,
465            jitter: true,
466        }
467    }
468}
469
470// ---------------------------------------------------------------------------
471// WsAuthConfig
472// ---------------------------------------------------------------------------
473
474/// Authentication configuration for WebSocket connections.
475///
476/// Applied during the HTTP upgrade handshake as headers, query parameters,
477/// or used to compute a signature.
478#[derive(Debug, Clone, Serialize, Deserialize)]
479#[serde(tag = "type")]
480pub enum WsAuthConfig {
481    /// Bearer token authentication (sent as `Authorization: Bearer <token>`).
482    Bearer {
483        /// The bearer token value.
484        token: String,
485    },
486
487    /// HTTP Basic authentication (sent as `Authorization: Basic <base64>`).
488    Basic {
489        /// Username for basic auth.
490        username: String,
491        /// Password for basic auth.
492        password: String,
493    },
494
495    /// Arbitrary HTTP headers added to the upgrade request.
496    Headers {
497        /// Key-value pairs added as HTTP headers.
498        headers: Vec<(String, String)>,
499    },
500
501    /// Single query parameter appended to the WebSocket URL.
502    QueryParam {
503        /// Query parameter name.
504        key: String,
505        /// Query parameter value.
506        value: String,
507    },
508
509    /// HMAC signature authentication (e.g., for exchange APIs).
510    Hmac {
511        /// API key (sent as a header or query parameter).
512        api_key: String,
513        /// HMAC secret used to sign requests.
514        secret: String,
515    },
516}
517
518// ---------------------------------------------------------------------------
519// Tests
520// ---------------------------------------------------------------------------
521
522#[cfg(test)]
523mod tests {
524    use super::*;
525
526    // -- Default impls -------------------------------------------------------
527
528    #[test]
529    fn test_default_websocket_source_config() {
530        let cfg = WebSocketSourceConfig::default();
531
532        assert!(matches!(cfg.mode, SourceMode::Client { .. }));
533        assert!(matches!(cfg.format, MessageFormat::Json));
534        assert!(matches!(cfg.on_backpressure, BackpressureStrategy::Block));
535        assert_eq!(cfg.max_message_size, 64 * 1024 * 1024);
536        assert!(cfg.event_time_field.is_none());
537        assert!(cfg.event_time_format.is_none());
538        assert!(cfg.auth.is_none());
539    }
540
541    #[test]
542    fn test_default_source_mode() {
543        let mode = SourceMode::default();
544        match mode {
545            SourceMode::Client {
546                urls,
547                subscribe_message,
548                reconnect,
549                ping_interval,
550                ping_timeout,
551            } => {
552                assert_eq!(urls.len(), 1);
553                assert_eq!(urls[0], "");
554                assert!(subscribe_message.is_none());
555                assert!(reconnect.enabled);
556                assert_eq!(ping_interval, Duration::from_secs(30));
557                assert_eq!(ping_timeout, Duration::from_secs(10));
558            }
559            SourceMode::Server { .. } => panic!("expected Client mode"),
560        }
561    }
562
563    #[test]
564    fn test_default_message_format() {
565        let fmt = MessageFormat::default();
566        assert!(matches!(fmt, MessageFormat::Json));
567    }
568
569    #[test]
570    fn test_default_reconnect_config() {
571        let rc = ReconnectConfig::default();
572        assert!(rc.enabled);
573        assert_eq!(rc.initial_delay, Duration::from_millis(100));
574        assert_eq!(rc.max_delay, Duration::from_secs(30));
575        assert!((rc.backoff_multiplier - 2.0).abs() < f64::EPSILON);
576        assert!(rc.max_retries.is_none());
577        assert!(rc.jitter);
578    }
579
580    // -- Serde round-trip -----------------------------------------------------
581
582    #[test]
583    fn test_serde_round_trip_client_mode() {
584        let cfg = WebSocketSourceConfig {
585            mode: SourceMode::Client {
586                urls: vec!["wss://feed.example.com/v1".into()],
587                subscribe_message: Some(r#"{"op":"subscribe","channel":"trades"}"#.into()),
588                reconnect: ReconnectConfig::default(),
589                ping_interval: Duration::from_secs(15),
590                ping_timeout: Duration::from_secs(5),
591            },
592            format: MessageFormat::Json,
593            on_backpressure: BackpressureStrategy::Block,
594            event_time_field: Some("timestamp".into()),
595            event_time_format: Some(EventTimeFormat::EpochMillis),
596            max_message_size: 1024 * 1024,
597            auth: Some(WsAuthConfig::Bearer {
598                token: "tok_abc123".into(),
599            }),
600        };
601
602        let json = serde_json::to_string_pretty(&cfg).expect("serialize");
603        let deser: WebSocketSourceConfig = serde_json::from_str(&json).expect("deserialize");
604
605        // Verify key fields survived the round-trip.
606        match &deser.mode {
607            SourceMode::Client {
608                urls,
609                subscribe_message,
610                ping_interval,
611                ping_timeout,
612                ..
613            } => {
614                assert_eq!(urls, &["wss://feed.example.com/v1"]);
615                assert_eq!(
616                    subscribe_message.as_deref(),
617                    Some(r#"{"op":"subscribe","channel":"trades"}"#)
618                );
619                assert_eq!(*ping_interval, Duration::from_secs(15));
620                assert_eq!(*ping_timeout, Duration::from_secs(5));
621            }
622            SourceMode::Server { .. } => panic!("expected Client"),
623        }
624        assert_eq!(deser.event_time_field.as_deref(), Some("timestamp"));
625        assert!(matches!(
626            deser.event_time_format,
627            Some(EventTimeFormat::EpochMillis)
628        ));
629        assert_eq!(deser.max_message_size, 1024 * 1024);
630        assert!(matches!(
631            deser.auth,
632            Some(WsAuthConfig::Bearer { ref token }) if token == "tok_abc123"
633        ));
634    }
635
636    #[test]
637    fn test_serde_round_trip_server_mode() {
638        let cfg = WebSocketSourceConfig {
639            mode: SourceMode::Server {
640                bind_address: "0.0.0.0:9443".into(),
641                max_connections: 512,
642                path: Some("/ingest".into()),
643            },
644            format: MessageFormat::JsonLines,
645            on_backpressure: BackpressureStrategy::Block,
646            event_time_field: None,
647            event_time_format: None,
648            max_message_size: default_max_message_size(),
649            auth: None,
650        };
651
652        let json = serde_json::to_string(&cfg).expect("serialize");
653        let deser: WebSocketSourceConfig = serde_json::from_str(&json).expect("deserialize");
654
655        match &deser.mode {
656            SourceMode::Server {
657                bind_address,
658                max_connections,
659                path,
660            } => {
661                assert_eq!(bind_address, "0.0.0.0:9443");
662                assert_eq!(*max_connections, 512);
663                assert_eq!(path.as_deref(), Some("/ingest"));
664            }
665            SourceMode::Client { .. } => panic!("expected Server"),
666        }
667    }
668
669    #[test]
670    fn test_serde_round_trip_reconnect_config() {
671        let rc = ReconnectConfig {
672            enabled: false,
673            initial_delay: Duration::from_millis(500),
674            max_delay: Duration::from_secs(60),
675            backoff_multiplier: 1.5,
676            max_retries: Some(10),
677            jitter: false,
678        };
679
680        let json = serde_json::to_string(&rc).expect("serialize");
681        let deser: ReconnectConfig = serde_json::from_str(&json).expect("deserialize");
682
683        assert!(!deser.enabled);
684        assert_eq!(deser.initial_delay, Duration::from_millis(500));
685        assert_eq!(deser.max_delay, Duration::from_secs(60));
686        assert!((deser.backoff_multiplier - 1.5).abs() < f64::EPSILON);
687        assert_eq!(deser.max_retries, Some(10));
688        assert!(!deser.jitter);
689    }
690
691    #[test]
692    fn test_serde_round_trip_csv_format() {
693        let cfg = WebSocketSourceConfig {
694            format: MessageFormat::Csv {
695                delimiter: '|',
696                has_header: true,
697            },
698            ..WebSocketSourceConfig::default()
699        };
700
701        let json = serde_json::to_string(&cfg).expect("serialize");
702        let deser: WebSocketSourceConfig = serde_json::from_str(&json).expect("deserialize");
703
704        match deser.format {
705            MessageFormat::Csv {
706                delimiter,
707                has_header,
708            } => {
709                assert_eq!(delimiter, '|');
710                assert!(has_header);
711            }
712            _ => panic!("expected Csv format"),
713        }
714    }
715
716    #[test]
717    fn test_serde_round_trip_auth_variants() {
718        // Basic auth
719        let basic = WsAuthConfig::Basic {
720            username: "user".into(),
721            password: "pass".into(),
722        };
723        let json = serde_json::to_string(&basic).expect("serialize");
724        let deser: WsAuthConfig = serde_json::from_str(&json).expect("deserialize");
725        assert!(matches!(
726            deser,
727            WsAuthConfig::Basic { ref username, ref password }
728                if username == "user" && password == "pass"
729        ));
730
731        // Headers auth
732        let headers = WsAuthConfig::Headers {
733            headers: vec![("X-Api-Key".into(), "key123".into())],
734        };
735        let json = serde_json::to_string(&headers).expect("serialize");
736        let deser: WsAuthConfig = serde_json::from_str(&json).expect("deserialize");
737        assert!(matches!(deser, WsAuthConfig::Headers { ref headers } if headers.len() == 1));
738
739        // QueryParam auth
740        let qp = WsAuthConfig::QueryParam {
741            key: "token".into(),
742            value: "abc".into(),
743        };
744        let json = serde_json::to_string(&qp).expect("serialize");
745        let deser: WsAuthConfig = serde_json::from_str(&json).expect("deserialize");
746        assert!(matches!(
747            deser,
748            WsAuthConfig::QueryParam { ref key, ref value }
749                if key == "token" && value == "abc"
750        ));
751
752        // Hmac auth
753        let hmac = WsAuthConfig::Hmac {
754            api_key: "ak".into(),
755            secret: "sk".into(),
756        };
757        let json = serde_json::to_string(&hmac).expect("serialize");
758        let deser: WsAuthConfig = serde_json::from_str(&json).expect("deserialize");
759        assert!(matches!(
760            deser,
761            WsAuthConfig::Hmac { ref api_key, ref secret }
762                if api_key == "ak" && secret == "sk"
763        ));
764    }
765
766    #[test]
767    fn test_serde_round_trip_event_time_formats() {
768        let formats = vec![
769            EventTimeFormat::EpochMillis,
770            EventTimeFormat::EpochMicros,
771            EventTimeFormat::EpochNanos,
772            EventTimeFormat::EpochSeconds,
773            EventTimeFormat::Iso8601,
774            EventTimeFormat::Custom("%Y-%m-%dT%H:%M:%S".into()),
775        ];
776
777        for fmt in formats {
778            let json = serde_json::to_string(&fmt).expect("serialize");
779            let deser: EventTimeFormat = serde_json::from_str(&json).expect("deserialize");
780
781            // Verify variant is preserved.
782            match (&fmt, &deser) {
783                (EventTimeFormat::EpochMillis, EventTimeFormat::EpochMillis)
784                | (EventTimeFormat::EpochMicros, EventTimeFormat::EpochMicros)
785                | (EventTimeFormat::EpochNanos, EventTimeFormat::EpochNanos)
786                | (EventTimeFormat::EpochSeconds, EventTimeFormat::EpochSeconds)
787                | (EventTimeFormat::Iso8601, EventTimeFormat::Iso8601) => {}
788                (EventTimeFormat::Custom(a), EventTimeFormat::Custom(b)) => {
789                    assert_eq!(a, b);
790                }
791                _ => panic!("event time format mismatch after round-trip"),
792            }
793        }
794    }
795
796    #[test]
797    fn test_serde_defaults_applied() {
798        // Minimal JSON: only required fields are set, all defaulted fields omitted.
799        let json = r#"{
800            "mode": {
801                "type": "Server",
802                "bind_address": "127.0.0.1:8080"
803            },
804            "format": "Json"
805        }"#;
806
807        let cfg: WebSocketSourceConfig =
808            serde_json::from_str(json).expect("deserialize with defaults");
809
810        assert!(matches!(cfg.on_backpressure, BackpressureStrategy::Block));
811        assert_eq!(cfg.max_message_size, 64 * 1024 * 1024);
812        assert!(cfg.event_time_field.is_none());
813        assert!(cfg.auth.is_none());
814
815        match cfg.mode {
816            SourceMode::Server {
817                max_connections, ..
818            } => {
819                assert_eq!(max_connections, 1024);
820            }
821            SourceMode::Client { .. } => panic!("expected Server"),
822        }
823    }
824
825    // -- Default helper functions -------------------------------------------
826
827    #[test]
828    fn test_default_helper_values() {
829        assert_eq!(default_max_message_size(), 64 * 1024 * 1024);
830        assert_eq!(default_ping_interval(), Duration::from_secs(30));
831        assert_eq!(default_ping_timeout(), Duration::from_secs(10));
832        assert_eq!(default_max_connections(), 1024);
833        assert_eq!(default_initial_delay(), Duration::from_millis(100));
834        assert_eq!(default_max_delay(), Duration::from_secs(30));
835        assert!((default_backoff_multiplier() - 2.0).abs() < f64::EPSILON);
836        assert!(default_true());
837    }
838
839    // -- from_config tests ---------------------------------------------------
840
841    #[test]
842    fn test_from_config_client_mode() {
843        let mut config = ConnectorConfig::new("websocket");
844        config.set("url", "wss://feed.example.com/v1");
845        config.set("format", "json");
846        config.set("subscribe.message", r#"{"op":"subscribe"}"#);
847        config.set("reconnect.enabled", "true");
848        config.set("reconnect.initial.delay.ms", "200");
849        config.set("reconnect.max.delay.ms", "60000");
850        config.set("ping.interval.ms", "15000");
851        config.set("ping.timeout.ms", "5000");
852
853        let cfg = WebSocketSourceConfig::from_config(&config).unwrap();
854
855        match &cfg.mode {
856            SourceMode::Client {
857                urls,
858                subscribe_message,
859                reconnect,
860                ping_interval,
861                ping_timeout,
862            } => {
863                assert_eq!(urls, &["wss://feed.example.com/v1"]);
864                assert_eq!(subscribe_message.as_deref(), Some(r#"{"op":"subscribe"}"#));
865                assert!(reconnect.enabled);
866                assert_eq!(reconnect.initial_delay, Duration::from_millis(200));
867                assert_eq!(reconnect.max_delay, Duration::from_millis(60_000));
868                assert_eq!(*ping_interval, Duration::from_millis(15_000));
869                assert_eq!(*ping_timeout, Duration::from_millis(5_000));
870            }
871            SourceMode::Server { .. } => panic!("expected Client mode"),
872        }
873        assert!(matches!(cfg.format, MessageFormat::Json));
874    }
875
876    #[test]
877    fn test_from_config_server_mode() {
878        let mut config = ConnectorConfig::new("websocket");
879        config.set("mode", "server");
880        config.set("bind.address", "0.0.0.0:9443");
881        config.set("max.connections", "512");
882        config.set("path", "/ingest");
883
884        let cfg = WebSocketSourceConfig::from_config(&config).unwrap();
885
886        match &cfg.mode {
887            SourceMode::Server {
888                bind_address,
889                max_connections,
890                path,
891            } => {
892                assert_eq!(bind_address, "0.0.0.0:9443");
893                assert_eq!(*max_connections, 512);
894                assert_eq!(path.as_deref(), Some("/ingest"));
895            }
896            SourceMode::Client { .. } => panic!("expected Server mode"),
897        }
898    }
899
900    #[test]
901    fn test_from_config_missing_url_errors() {
902        let config = ConnectorConfig::new("websocket");
903        // Client mode is default — missing URL should error.
904        let result = WebSocketSourceConfig::from_config(&config);
905        assert!(result.is_err());
906        assert!(result.unwrap_err().to_string().contains("url"));
907    }
908
909    #[test]
910    fn test_from_config_bearer_auth() {
911        let mut config = ConnectorConfig::new("websocket");
912        config.set("url", "wss://api.example.com");
913        config.set("auth.type", "bearer");
914        config.set("auth.token", "tok_abc123");
915
916        let cfg = WebSocketSourceConfig::from_config(&config).unwrap();
917        assert!(matches!(
918            cfg.auth,
919            Some(WsAuthConfig::Bearer { ref token }) if token == "tok_abc123"
920        ));
921    }
922
923    #[test]
924    fn test_from_config_multiple_urls() {
925        let mut config = ConnectorConfig::new("websocket");
926        config.set("url", "wss://a.example.com, wss://b.example.com");
927
928        let cfg = WebSocketSourceConfig::from_config(&config).unwrap();
929        match &cfg.mode {
930            SourceMode::Client { urls, .. } => {
931                assert_eq!(urls.len(), 2);
932                assert_eq!(urls[0], "wss://a.example.com");
933                assert_eq!(urls[1], "wss://b.example.com");
934            }
935            SourceMode::Server { .. } => panic!("expected Client mode"),
936        }
937    }
938
939    #[test]
940    fn test_from_config_defaults() {
941        let mut config = ConnectorConfig::new("websocket");
942        config.set("url", "wss://feed.example.com");
943
944        let cfg = WebSocketSourceConfig::from_config(&config).unwrap();
945        assert!(matches!(cfg.format, MessageFormat::Json));
946        assert!(matches!(cfg.on_backpressure, BackpressureStrategy::Block));
947        assert_eq!(cfg.max_message_size, 64 * 1024 * 1024);
948        assert!(cfg.event_time_field.is_none());
949        assert!(cfg.auth.is_none());
950    }
951}