Skip to main content

laminar_connectors/websocket/
sink_config.rs

1//! WebSocket sink connector configuration.
2//!
3//! [`WebSocketSinkConfig`] controls how the sink writes Arrow data to
4//! WebSocket clients (server mode) or to an upstream WebSocket server
5//! (client mode). Parsed from a SQL `WITH (...)` clause or constructed
6//! programmatically.
7
8use std::time::Duration;
9
10use serde::{Deserialize, Serialize};
11
12use super::source_config::{ReconnectConfig, WsAuthConfig};
13use crate::config::ConnectorConfig;
14use crate::error::ConnectorError;
15
16// ---------------------------------------------------------------------------
17// Serde default helpers
18// ---------------------------------------------------------------------------
19
/// Cap on simultaneously connected WebSocket clients used when no explicit
/// `max_connections` value is supplied.
fn default_max_connections() -> usize {
    const LIMIT: usize = 10_000;
    LIMIT
}
24
/// Per-client outbound buffer size in bytes used when no explicit value is
/// supplied: 256 KiB.
fn default_per_client_buffer() -> usize {
    256 * 1024
}
29
/// Gap between consecutive WebSocket ping frames when none is configured
/// (thirty seconds).
fn default_ping_interval() -> Duration {
    const INTERVAL_MS: u64 = 30_000;
    Duration::from_millis(INTERVAL_MS)
}
34
/// How long to wait for a pong reply when no timeout is configured
/// (ten seconds).
fn default_ping_timeout() -> Duration {
    const TIMEOUT_MS: u64 = 10_000;
    Duration::from_millis(TIMEOUT_MS)
}
39
40// ---------------------------------------------------------------------------
41// Top-level config
42// ---------------------------------------------------------------------------
43
44/// Configuration for the WebSocket sink connector.
45///
46/// Supports two operating modes:
47/// - **Server**: binds a local address and fans out records to connected
48///   WebSocket clients.
49/// - **Client**: connects to a remote WebSocket server and pushes records.
50#[derive(Debug, Clone, Default, Serialize, Deserialize)]
51pub struct WebSocketSinkConfig {
52    /// The operating mode (server or client).
53    pub mode: SinkMode,
54    /// Serialization format for outgoing messages.
55    pub format: SinkFormat,
56    /// Optional authentication configuration shared with the source connector.
57    pub auth: Option<WsAuthConfig>,
58}
59
60impl WebSocketSinkConfig {
61    /// Builds a [`WebSocketSinkConfig`] from a flat [`ConnectorConfig`] property map.
62    ///
63    /// # Errors
64    ///
65    /// Returns `ConnectorError::ConfigurationError` if a required key is missing
66    /// or a value cannot be parsed.
67    pub fn from_config(config: &ConnectorConfig) -> Result<Self, ConnectorError> {
68        let mode_str = config.get("mode").unwrap_or("server");
69        let mode = match mode_str.to_lowercase().as_str() {
70            "server" => {
71                let bind_address = config.require("bind.address").map(ToString::to_string)?;
72                let max_connections: usize = config
73                    .get_parsed("max.connections")?
74                    .unwrap_or(default_max_connections());
75                let per_client_buffer: usize = config
76                    .get_parsed("per.client.buffer")?
77                    .unwrap_or(default_per_client_buffer());
78                let ping_interval_ms: u64 =
79                    config.get_parsed("ping.interval.ms")?.unwrap_or(30_000);
80                let ping_timeout_ms: u64 = config.get_parsed("ping.timeout.ms")?.unwrap_or(10_000);
81                let replay_buffer_size: Option<usize> = config.get_parsed("replay.buffer.size")?;
82                let path = config.get("path").map(ToString::to_string);
83
84                let slow_client_policy =
85                    match config.get("slow.client.policy").map(str::to_lowercase) {
86                        Some(ref s) if s == "drop_oldest" => SlowClientPolicy::DropOldest,
87                        Some(ref s) if s == "drop_newest" => SlowClientPolicy::DropNewest,
88                        Some(ref s) if s == "disconnect" => SlowClientPolicy::Disconnect {
89                            threshold_pct: config
90                                .get_parsed("slow.client.threshold.pct")?
91                                .unwrap_or(90),
92                        },
93                        _ => SlowClientPolicy::default(),
94                    };
95
96                SinkMode::Server {
97                    bind_address,
98                    path,
99                    max_connections,
100                    per_client_buffer,
101                    slow_client_policy,
102                    ping_interval: Duration::from_millis(ping_interval_ms),
103                    ping_timeout: Duration::from_millis(ping_timeout_ms),
104                    enable_subscription_filter: false,
105                    replay_buffer_size,
106                }
107            }
108            "client" => {
109                let url = config.require("url").map(ToString::to_string)?;
110                let buffer_on_disconnect: Option<usize> =
111                    config.get_parsed("buffer.on.disconnect")?;
112                let batch_max_size: Option<usize> = config.get_parsed("batch.max.size")?;
113                let batch_interval_ms: Option<u64> = config.get_parsed("batch.interval.ms")?;
114
115                SinkMode::Client {
116                    url,
117                    reconnect: ReconnectConfig::default(),
118                    buffer_on_disconnect,
119                    batch_interval: batch_interval_ms.map(Duration::from_millis),
120                    batch_max_size,
121                }
122            }
123            other => {
124                return Err(ConnectorError::ConfigurationError(format!(
125                    "invalid WebSocket sink mode '{other}': expected 'server' or 'client'"
126                )));
127            }
128        };
129
130        let format = match config.get("format").map(str::to_lowercase) {
131            Some(ref s) if s == "json" => SinkFormat::Json,
132            Some(ref s) if s == "jsonlines" || s == "json_lines" => SinkFormat::JsonLines,
133            Some(ref s) if s == "arrow_ipc" || s == "arowipc" => SinkFormat::ArrowIpc,
134            Some(ref s) if s == "binary" => SinkFormat::Binary,
135            Some(ref other) => {
136                return Err(ConnectorError::ConfigurationError(format!(
137                    "invalid sink format '{other}': expected json, jsonlines, arrow_ipc, or binary"
138                )));
139            }
140            None => SinkFormat::Json,
141        };
142
143        let auth = match config.get("auth.type").map(str::to_lowercase) {
144            Some(ref s) if s == "bearer" => {
145                let token = config.require("auth.token").map(ToString::to_string)?;
146                Some(WsAuthConfig::Bearer { token })
147            }
148            Some(ref s) if s == "basic" => {
149                let username = config.require("auth.username").map(ToString::to_string)?;
150                let password = config.require("auth.password").map(ToString::to_string)?;
151                Some(WsAuthConfig::Basic { username, password })
152            }
153            Some(ref s) if s == "hmac" => {
154                let api_key = config.require("auth.api.key").map(ToString::to_string)?;
155                let secret = config.require("auth.secret").map(ToString::to_string)?;
156                Some(WsAuthConfig::Hmac { api_key, secret })
157            }
158            _ => None,
159        };
160
161        Ok(Self { mode, format, auth })
162    }
163}
164
165// ---------------------------------------------------------------------------
166// SinkMode
167// ---------------------------------------------------------------------------
168
169/// Operating mode for the WebSocket sink.
170#[derive(Debug, Clone, Serialize, Deserialize)]
171#[serde(tag = "type")]
172pub enum SinkMode {
173    /// Bind a local address and fan out records to all connected clients.
174    Server {
175        /// Address to bind (e.g. `"0.0.0.0:8080"`).
176        bind_address: String,
177        /// Optional URL path filter (e.g. `"/stream"`).
178        path: Option<String>,
179        /// Maximum number of concurrent client connections.
180        #[serde(default = "default_max_connections")]
181        max_connections: usize,
182        /// Per-client outbound buffer size in bytes (default 256 KB).
183        #[serde(default = "default_per_client_buffer")]
184        per_client_buffer: usize,
185        /// Policy for handling clients whose send buffer is full.
186        #[serde(default)]
187        slow_client_policy: SlowClientPolicy,
188        /// Interval between WebSocket ping frames.
189        #[serde(default = "default_ping_interval")]
190        ping_interval: Duration,
191        /// Timeout waiting for a pong response before disconnecting.
192        #[serde(default = "default_ping_timeout")]
193        ping_timeout: Duration,
194        /// When `true`, clients may send subscription filter messages.
195        #[serde(default)]
196        enable_subscription_filter: bool,
197        /// Optional bounded replay buffer so late-joining clients can
198        /// catch up. `None` means no replay.
199        replay_buffer_size: Option<usize>,
200    },
201    /// Connect to an external WebSocket server and push records.
202    Client {
203        /// WebSocket URL to connect to (e.g. `"wss://host/path"`).
204        url: String,
205        /// Reconnection strategy when the connection drops.
206        #[serde(default)]
207        reconnect: ReconnectConfig,
208        /// Number of records to buffer in memory while disconnected.
209        /// `None` means records are dropped on disconnect.
210        buffer_on_disconnect: Option<usize>,
211        /// Optional interval to batch outgoing messages.
212        batch_interval: Option<Duration>,
213        /// Maximum number of records per batch.
214        batch_max_size: Option<usize>,
215    },
216}
217
218impl Default for SinkMode {
219    fn default() -> Self {
220        Self::Server {
221            bind_address: "0.0.0.0:8080".to_string(),
222            path: None,
223            max_connections: default_max_connections(),
224            per_client_buffer: default_per_client_buffer(),
225            slow_client_policy: SlowClientPolicy::default(),
226            ping_interval: default_ping_interval(),
227            ping_timeout: default_ping_timeout(),
228            enable_subscription_filter: false,
229            replay_buffer_size: None,
230        }
231    }
232}
233
234// ---------------------------------------------------------------------------
235// SlowClientPolicy
236// ---------------------------------------------------------------------------
237
238/// Policy applied when a client's outbound buffer is full.
239#[derive(Debug, Clone, Default, Serialize, Deserialize)]
240pub enum SlowClientPolicy {
241    /// Drop the oldest buffered message to make room for the new one.
242    #[default]
243    DropOldest,
244    /// Drop the newest (incoming) message when the buffer is full.
245    DropNewest,
246    /// Disconnect the client once its buffer reaches the given threshold
247    /// percentage.
248    Disconnect {
249        /// Buffer fullness percentage (0-100) at which the client is
250        /// disconnected.
251        threshold_pct: u8,
252    },
253    /// Emit a warning when the buffer reaches `warn_pct`, then disconnect
254    /// at `disconnect_pct`.
255    WarnThenDisconnect {
256        /// Buffer fullness percentage (0-100) at which a warning is emitted.
257        warn_pct: u8,
258        /// Buffer fullness percentage (0-100) at which the client is
259        /// disconnected.
260        disconnect_pct: u8,
261    },
262}
263
264// ---------------------------------------------------------------------------
265// SinkFormat
266// ---------------------------------------------------------------------------
267
268/// Serialization format for outgoing WebSocket messages.
269#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
270pub enum SinkFormat {
271    /// One JSON object per message.
272    #[default]
273    Json,
274    /// Newline-delimited JSON (one JSON object per line).
275    JsonLines,
276    /// Arrow IPC (streaming) format.
277    ArrowIpc,
278    /// Raw binary (application-defined).
279    Binary,
280}
281
282// ---------------------------------------------------------------------------
283// Tests
284// ---------------------------------------------------------------------------
285
#[cfg(test)]
mod tests {
    use super::*;

    // -- Default impls -------------------------------------------------------

    #[test]
    fn test_slow_client_policy_default() {
        let policy = SlowClientPolicy::default();
        assert!(matches!(policy, SlowClientPolicy::DropOldest));
    }

    #[test]
    fn test_sink_format_default() {
        let format = SinkFormat::default();
        assert_eq!(format, SinkFormat::Json);
    }

    #[test]
    fn test_sink_mode_default_is_server() {
        let mode = SinkMode::default();
        match mode {
            SinkMode::Server {
                bind_address,
                path,
                max_connections,
                per_client_buffer,
                slow_client_policy,
                ping_interval,
                ping_timeout,
                enable_subscription_filter,
                replay_buffer_size,
            } => {
                assert_eq!(bind_address, "0.0.0.0:8080");
                assert!(path.is_none());
                assert_eq!(max_connections, 10_000);
                assert_eq!(per_client_buffer, 262_144);
                assert!(matches!(slow_client_policy, SlowClientPolicy::DropOldest));
                assert_eq!(ping_interval, Duration::from_secs(30));
                assert_eq!(ping_timeout, Duration::from_secs(10));
                assert!(!enable_subscription_filter);
                assert!(replay_buffer_size.is_none());
            }
            SinkMode::Client { .. } => panic!("expected Server, got Client"),
        }
    }

    #[test]
    fn test_websocket_sink_config_default() {
        let config = WebSocketSinkConfig::default();
        assert_eq!(config.format, SinkFormat::Json);
        assert!(config.auth.is_none());
        assert!(matches!(config.mode, SinkMode::Server { .. }));
    }

    // -- Serde default helpers -----------------------------------------------

    #[test]
    fn test_default_max_connections() {
        assert_eq!(default_max_connections(), 10_000);
    }

    #[test]
    fn test_default_per_client_buffer() {
        assert_eq!(default_per_client_buffer(), 262_144);
    }

    #[test]
    fn test_default_ping_interval() {
        assert_eq!(default_ping_interval(), Duration::from_secs(30));
    }

    #[test]
    fn test_default_ping_timeout() {
        assert_eq!(default_ping_timeout(), Duration::from_secs(10));
    }

    // -- Serialization round-trips -------------------------------------------

    #[test]
    fn test_server_mode_serde_roundtrip() {
        let mode = SinkMode::Server {
            bind_address: "127.0.0.1:9090".to_string(),
            path: Some("/ws".to_string()),
            max_connections: 500,
            per_client_buffer: 1024,
            slow_client_policy: SlowClientPolicy::DropNewest,
            ping_interval: Duration::from_secs(15),
            ping_timeout: Duration::from_secs(5),
            enable_subscription_filter: true,
            replay_buffer_size: Some(1000),
        };
        let json = serde_json::to_string(&mode).unwrap();
        let deser: SinkMode = serde_json::from_str(&json).unwrap();
        match deser {
            SinkMode::Server {
                bind_address,
                path,
                max_connections,
                per_client_buffer,
                slow_client_policy,
                ping_interval,
                ping_timeout,
                replay_buffer_size,
                enable_subscription_filter,
            } => {
                assert_eq!(bind_address, "127.0.0.1:9090");
                assert_eq!(path.as_deref(), Some("/ws"));
                assert_eq!(max_connections, 500);
                assert_eq!(per_client_buffer, 1024);
                assert!(matches!(slow_client_policy, SlowClientPolicy::DropNewest));
                assert_eq!(ping_interval, Duration::from_secs(15));
                assert_eq!(ping_timeout, Duration::from_secs(5));
                assert!(enable_subscription_filter);
                assert_eq!(replay_buffer_size, Some(1000));
            }
            SinkMode::Client { .. } => panic!("expected Server"),
        }
    }

    #[test]
    fn test_client_mode_serde_roundtrip() {
        let mode = SinkMode::Client {
            url: "wss://example.com/feed".to_string(),
            reconnect: ReconnectConfig::default(),
            buffer_on_disconnect: Some(5000),
            batch_interval: Some(Duration::from_millis(100)),
            batch_max_size: Some(64),
        };
        let json = serde_json::to_string(&mode).unwrap();
        let deser: SinkMode = serde_json::from_str(&json).unwrap();
        match deser {
            SinkMode::Client {
                url,
                buffer_on_disconnect,
                batch_interval,
                batch_max_size,
                ..
            } => {
                assert_eq!(url, "wss://example.com/feed");
                assert_eq!(buffer_on_disconnect, Some(5000));
                assert_eq!(batch_interval, Some(Duration::from_millis(100)));
                assert_eq!(batch_max_size, Some(64));
            }
            SinkMode::Server { .. } => panic!("expected Client"),
        }
    }

    #[test]
    fn test_sink_format_serde_variants() {
        for (format, expected) in [
            (SinkFormat::Json, "\"Json\""),
            (SinkFormat::JsonLines, "\"JsonLines\""),
            (SinkFormat::ArrowIpc, "\"ArrowIpc\""),
            (SinkFormat::Binary, "\"Binary\""),
        ] {
            let json = serde_json::to_string(&format).unwrap();
            assert_eq!(json, expected);
            let deser: SinkFormat = serde_json::from_str(&json).unwrap();
            assert_eq!(deser, format);
        }
    }

    #[test]
    fn test_slow_client_policy_serde_variants() {
        let disconnect = SlowClientPolicy::Disconnect { threshold_pct: 90 };
        let json = serde_json::to_string(&disconnect).unwrap();
        let deser: SlowClientPolicy = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deser,
            SlowClientPolicy::Disconnect { threshold_pct: 90 }
        ));

        let warn = SlowClientPolicy::WarnThenDisconnect {
            warn_pct: 75,
            disconnect_pct: 95,
        };
        let json = serde_json::to_string(&warn).unwrap();
        let deser: SlowClientPolicy = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deser,
            SlowClientPolicy::WarnThenDisconnect {
                warn_pct: 75,
                disconnect_pct: 95,
            }
        ));
    }

    #[test]
    fn test_full_config_serde_roundtrip() {
        let config = WebSocketSinkConfig {
            mode: SinkMode::Server {
                bind_address: "0.0.0.0:3000".to_string(),
                path: None,
                max_connections: 2000,
                per_client_buffer: 131_072,
                slow_client_policy: SlowClientPolicy::WarnThenDisconnect {
                    warn_pct: 80,
                    disconnect_pct: 95,
                },
                ping_interval: Duration::from_secs(20),
                ping_timeout: Duration::from_secs(8),
                enable_subscription_filter: false,
                replay_buffer_size: Some(500),
            },
            format: SinkFormat::ArrowIpc,
            auth: None,
        };

        let json = serde_json::to_string_pretty(&config).unwrap();
        let deser: WebSocketSinkConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deser.format, SinkFormat::ArrowIpc);
        assert!(deser.auth.is_none());
        match deser.mode {
            SinkMode::Server {
                max_connections,
                replay_buffer_size,
                ..
            } => {
                assert_eq!(max_connections, 2000);
                assert_eq!(replay_buffer_size, Some(500));
            }
            SinkMode::Client { .. } => panic!("expected Server"),
        }
    }

    // -- Serde defaults are applied on deserialization -----------------------

    #[test]
    fn test_server_mode_serde_defaults_applied() {
        // Minimal JSON — only required fields and the tag.
        let json = r#"{
            "type": "Server",
            "bind_address": "0.0.0.0:4000"
        }"#;
        let mode: SinkMode = serde_json::from_str(json).unwrap();
        match mode {
            SinkMode::Server {
                max_connections,
                per_client_buffer,
                slow_client_policy,
                ping_interval,
                ping_timeout,
                enable_subscription_filter,
                replay_buffer_size,
                ..
            } => {
                assert_eq!(max_connections, 10_000);
                assert_eq!(per_client_buffer, 262_144);
                assert!(matches!(slow_client_policy, SlowClientPolicy::DropOldest));
                assert_eq!(ping_interval, Duration::from_secs(30));
                assert_eq!(ping_timeout, Duration::from_secs(10));
                assert!(!enable_subscription_filter);
                assert!(replay_buffer_size.is_none());
            }
            SinkMode::Client { .. } => panic!("expected Server"),
        }
    }

    // -- from_config tests ---------------------------------------------------

    #[test]
    fn test_from_config_server_mode() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "0.0.0.0:3000");
        config.set("max.connections", "2000");
        config.set("format", "jsonlines");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        assert_eq!(cfg.format, SinkFormat::JsonLines);
        match cfg.mode {
            SinkMode::Server {
                bind_address,
                max_connections,
                ..
            } => {
                assert_eq!(bind_address, "0.0.0.0:3000");
                assert_eq!(max_connections, 2000);
            }
            SinkMode::Client { .. } => panic!("expected Server mode"),
        }
    }

    #[test]
    fn test_from_config_server_defaults() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "127.0.0.1:9000");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        assert_eq!(cfg.format, SinkFormat::Json);
        assert!(cfg.auth.is_none());
        match cfg.mode {
            SinkMode::Server {
                path,
                max_connections,
                per_client_buffer,
                slow_client_policy,
                ping_interval,
                ping_timeout,
                enable_subscription_filter,
                replay_buffer_size,
                ..
            } => {
                assert!(path.is_none());
                assert_eq!(max_connections, default_max_connections());
                assert_eq!(per_client_buffer, default_per_client_buffer());
                assert!(matches!(slow_client_policy, SlowClientPolicy::DropOldest));
                assert_eq!(ping_interval, default_ping_interval());
                assert_eq!(ping_timeout, default_ping_timeout());
                assert!(!enable_subscription_filter);
                assert!(replay_buffer_size.is_none());
            }
            SinkMode::Client { .. } => panic!("expected Server mode"),
        }
    }

    #[test]
    fn test_from_config_server_overrides() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("mode", "SERVER");
        config.set("bind.address", "0.0.0.0:9001");
        config.set("path", "/feed");
        config.set("per.client.buffer", "4096");
        config.set("ping.interval.ms", "1500");
        config.set("ping.timeout.ms", "500");
        config.set("replay.buffer.size", "128");
        config.set("slow.client.policy", "disconnect");
        config.set("slow.client.threshold.pct", "75");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        match cfg.mode {
            SinkMode::Server {
                bind_address,
                path,
                per_client_buffer,
                slow_client_policy,
                ping_interval,
                ping_timeout,
                replay_buffer_size,
                ..
            } => {
                assert_eq!(bind_address, "0.0.0.0:9001");
                assert_eq!(path.as_deref(), Some("/feed"));
                assert_eq!(per_client_buffer, 4096);
                assert!(matches!(
                    slow_client_policy,
                    SlowClientPolicy::Disconnect { threshold_pct: 75 }
                ));
                assert_eq!(ping_interval, Duration::from_millis(1500));
                assert_eq!(ping_timeout, Duration::from_millis(500));
                assert_eq!(replay_buffer_size, Some(128));
            }
            SinkMode::Client { .. } => panic!("expected Server mode"),
        }
    }

    #[test]
    fn test_from_config_drop_newest_policy() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "0.0.0.0:8080");
        config.set("slow.client.policy", "DROP_NEWEST");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        match cfg.mode {
            SinkMode::Server {
                slow_client_policy, ..
            } => {
                assert!(matches!(slow_client_policy, SlowClientPolicy::DropNewest));
            }
            SinkMode::Client { .. } => panic!("expected Server mode"),
        }
    }

    #[test]
    fn test_from_config_client_mode() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("mode", "client");
        config.set("url", "wss://upstream.example.com/feed");
        config.set("format", "json");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        assert_eq!(cfg.format, SinkFormat::Json);
        match cfg.mode {
            SinkMode::Client { url, .. } => {
                assert_eq!(url, "wss://upstream.example.com/feed");
            }
            SinkMode::Server { .. } => panic!("expected Client mode"),
        }
    }

    #[test]
    fn test_from_config_client_batching() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("mode", "client");
        config.set("url", "wss://upstream.example.com/feed");
        config.set("buffer.on.disconnect", "100");
        config.set("batch.max.size", "32");
        config.set("batch.interval.ms", "250");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        match cfg.mode {
            SinkMode::Client {
                buffer_on_disconnect,
                batch_interval,
                batch_max_size,
                ..
            } => {
                assert_eq!(buffer_on_disconnect, Some(100));
                assert_eq!(batch_interval, Some(Duration::from_millis(250)));
                assert_eq!(batch_max_size, Some(32));
            }
            SinkMode::Server { .. } => panic!("expected Client mode"),
        }
    }

    #[test]
    fn test_from_config_client_missing_url_errors() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("mode", "client");

        let result = WebSocketSinkConfig::from_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("url"));
    }

    #[test]
    fn test_from_config_missing_bind_address_errors() {
        let config = ConnectorConfig::new("websocket");
        // Server mode is default — missing bind.address should error.
        let result = WebSocketSinkConfig::from_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("bind.address"));
    }

    #[test]
    fn test_from_config_invalid_mode_errors() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("mode", "peer");

        let result = WebSocketSinkConfig::from_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("peer"));
    }

    #[test]
    fn test_from_config_format_aliases() {
        for (raw, expected) in [
            ("JSON", SinkFormat::Json),
            ("json_lines", SinkFormat::JsonLines),
            ("arrow_ipc", SinkFormat::ArrowIpc),
            ("Binary", SinkFormat::Binary),
        ] {
            let mut config = ConnectorConfig::new("websocket");
            config.set("bind.address", "0.0.0.0:8080");
            config.set("format", raw);

            let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
            assert_eq!(cfg.format, expected, "format alias {raw}");
        }
    }

    #[test]
    fn test_from_config_invalid_format_errors() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "0.0.0.0:8080");
        config.set("format", "xml");

        let result = WebSocketSinkConfig::from_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("xml"));
    }

    #[test]
    fn test_from_config_bearer_auth() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "0.0.0.0:8080");
        config.set("auth.type", "bearer");
        config.set("auth.token", "my-secret");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        assert!(matches!(
            cfg.auth,
            Some(WsAuthConfig::Bearer { ref token }) if token == "my-secret"
        ));
    }

    #[test]
    fn test_from_config_bearer_auth_missing_token_errors() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "0.0.0.0:8080");
        config.set("auth.type", "bearer");

        let result = WebSocketSinkConfig::from_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("auth.token"));
    }

    #[test]
    fn test_from_config_basic_auth() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "0.0.0.0:8080");
        config.set("auth.type", "Basic");
        config.set("auth.username", "alice");
        config.set("auth.password", "hunter2");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        assert!(matches!(
            cfg.auth,
            Some(WsAuthConfig::Basic { ref username, ref password })
                if username == "alice" && password == "hunter2"
        ));
    }

    #[test]
    fn test_from_config_hmac_auth() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "0.0.0.0:8080");
        config.set("auth.type", "hmac");
        config.set("auth.api.key", "key-1");
        config.set("auth.secret", "s3cr3t");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        assert!(matches!(
            cfg.auth,
            Some(WsAuthConfig::Hmac { ref api_key, ref secret })
                if api_key == "key-1" && secret == "s3cr3t"
        ));
    }
}