1use std::time::Duration;
9
10use serde::{Deserialize, Serialize};
11
12use super::source_config::{ReconnectConfig, WsAuthConfig};
13use crate::config::ConnectorConfig;
14use crate::error::ConnectorError;
15
/// Default cap on concurrent client connections for server mode (ten thousand).
///
/// Shared by serde (`#[serde(default = ...)]`), `SinkMode::default`, and `from_config`.
fn default_max_connections() -> usize {
    10 * 1_000
}
24
/// Default per-client buffer size for server mode: 256 Ki (262_144).
///
/// NOTE(review): the unit (bytes vs. messages) is not visible here — confirm against the sink.
fn default_per_client_buffer() -> usize {
    256 * 1024
}
29
/// Default interval between server-initiated pings: thirty seconds.
fn default_ping_interval() -> Duration {
    Duration::from_millis(30_000)
}
34
/// Default time allowed for a ping to be answered: ten seconds.
fn default_ping_timeout() -> Duration {
    Duration::from_millis(10_000)
}
39
/// Top-level configuration for the WebSocket sink connector.
///
/// Can be deserialized with serde or built from flat connector properties via
/// [`WebSocketSinkConfig::from_config`]. The derived `Default` yields `SinkMode::default()`
/// (a server bound to `0.0.0.0:8080`), `SinkFormat::Json`, and no authentication.
///
/// NOTE(review): the derived `Debug` includes `auth`; whether credentials are printed
/// depends on `WsAuthConfig`'s `Debug` impl — confirm it redacts secrets before logging.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WebSocketSinkConfig {
    /// Whether the sink hosts its own WebSocket server or connects out as a client.
    pub mode: SinkMode,
    /// Serialization format of the messages the sink emits.
    pub format: SinkFormat,
    /// Optional authentication settings; `None` when no `auth.type` is configured.
    pub auth: Option<WsAuthConfig>,
}
59
60impl WebSocketSinkConfig {
61 pub fn from_config(config: &ConnectorConfig) -> Result<Self, ConnectorError> {
68 let mode_str = config.get("mode").unwrap_or("server");
69 let mode = match mode_str.to_lowercase().as_str() {
70 "server" => {
71 let bind_address = config.require("bind.address").map(ToString::to_string)?;
72 let max_connections: usize = config
73 .get_parsed("max.connections")?
74 .unwrap_or(default_max_connections());
75 let per_client_buffer: usize = config
76 .get_parsed("per.client.buffer")?
77 .unwrap_or(default_per_client_buffer());
78 let ping_interval_ms: u64 =
79 config.get_parsed("ping.interval.ms")?.unwrap_or(30_000);
80 let ping_timeout_ms: u64 = config.get_parsed("ping.timeout.ms")?.unwrap_or(10_000);
81 let replay_buffer_size: Option<usize> = config.get_parsed("replay.buffer.size")?;
82 let path = config.get("path").map(ToString::to_string);
83
84 let slow_client_policy =
85 match config.get("slow.client.policy").map(str::to_lowercase) {
86 Some(ref s) if s == "drop_oldest" => SlowClientPolicy::DropOldest,
87 Some(ref s) if s == "drop_newest" => SlowClientPolicy::DropNewest,
88 Some(ref s) if s == "disconnect" => SlowClientPolicy::Disconnect {
89 threshold_pct: config
90 .get_parsed("slow.client.threshold.pct")?
91 .unwrap_or(90),
92 },
93 _ => SlowClientPolicy::default(),
94 };
95
96 SinkMode::Server {
97 bind_address,
98 path,
99 max_connections,
100 per_client_buffer,
101 slow_client_policy,
102 ping_interval: Duration::from_millis(ping_interval_ms),
103 ping_timeout: Duration::from_millis(ping_timeout_ms),
104 enable_subscription_filter: false,
105 replay_buffer_size,
106 }
107 }
108 "client" => {
109 let url = config.require("url").map(ToString::to_string)?;
110 let buffer_on_disconnect: Option<usize> =
111 config.get_parsed("buffer.on.disconnect")?;
112 let batch_max_size: Option<usize> = config.get_parsed("batch.max.size")?;
113 let batch_interval_ms: Option<u64> = config.get_parsed("batch.interval.ms")?;
114
115 SinkMode::Client {
116 url,
117 reconnect: ReconnectConfig::default(),
118 buffer_on_disconnect,
119 batch_interval: batch_interval_ms.map(Duration::from_millis),
120 batch_max_size,
121 }
122 }
123 other => {
124 return Err(ConnectorError::ConfigurationError(format!(
125 "invalid WebSocket sink mode '{other}': expected 'server' or 'client'"
126 )));
127 }
128 };
129
130 let format = match config.get("format").map(str::to_lowercase) {
131 Some(ref s) if s == "json" => SinkFormat::Json,
132 Some(ref s) if s == "jsonlines" || s == "json_lines" => SinkFormat::JsonLines,
133 Some(ref s) if s == "arrow_ipc" || s == "arowipc" => SinkFormat::ArrowIpc,
134 Some(ref s) if s == "binary" => SinkFormat::Binary,
135 Some(ref other) => {
136 return Err(ConnectorError::ConfigurationError(format!(
137 "invalid sink format '{other}': expected json, jsonlines, arrow_ipc, or binary"
138 )));
139 }
140 None => SinkFormat::Json,
141 };
142
143 let auth = match config.get("auth.type").map(str::to_lowercase) {
144 Some(ref s) if s == "bearer" => {
145 let token = config.require("auth.token").map(ToString::to_string)?;
146 Some(WsAuthConfig::Bearer { token })
147 }
148 Some(ref s) if s == "basic" => {
149 let username = config.require("auth.username").map(ToString::to_string)?;
150 let password = config.require("auth.password").map(ToString::to_string)?;
151 Some(WsAuthConfig::Basic { username, password })
152 }
153 Some(ref s) if s == "hmac" => {
154 let api_key = config.require("auth.api.key").map(ToString::to_string)?;
155 let secret = config.require("auth.secret").map(ToString::to_string)?;
156 Some(WsAuthConfig::Hmac { api_key, secret })
157 }
158 _ => None,
159 };
160
161 Ok(Self { mode, format, auth })
162 }
163}
164
/// Operating mode of the WebSocket sink.
///
/// Serialized as an internally tagged enum: the variant name is stored in a `"type"`
/// field (e.g. `{"type": "Server", "bind_address": "0.0.0.0:4000"}`). Fields marked
/// `#[serde(default)]` may be omitted; omitted `Option` fields deserialize as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SinkMode {
    /// The sink hosts a WebSocket server that clients connect to.
    Server {
        /// Address to bind, e.g. `0.0.0.0:8080`.
        bind_address: String,
        /// Optional URL path for the endpoint (presumably the path clients must request — confirm).
        path: Option<String>,
        /// Maximum concurrent clients; defaults to `default_max_connections()` (10_000).
        #[serde(default = "default_max_connections")]
        max_connections: usize,
        /// Per-client buffer size; defaults to `default_per_client_buffer()` (262_144).
        /// NOTE(review): unit (bytes vs. messages) not visible here — confirm.
        #[serde(default = "default_per_client_buffer")]
        per_client_buffer: usize,
        /// How to treat clients that fall behind; defaults to `SlowClientPolicy::DropOldest`.
        #[serde(default)]
        slow_client_policy: SlowClientPolicy,
        /// Interval between pings; defaults to 30 seconds.
        #[serde(default = "default_ping_interval")]
        ping_interval: Duration,
        /// Time allowed for a ping response; defaults to 10 seconds.
        #[serde(default = "default_ping_timeout")]
        ping_timeout: Duration,
        /// Enables per-client subscription filtering; defaults to `false`.
        /// `from_config` always sets this to `false`.
        #[serde(default)]
        enable_subscription_filter: bool,
        /// Optional size of a replay buffer; `None` when not configured.
        replay_buffer_size: Option<usize>,
    },
    /// The sink connects out to a remote WebSocket endpoint.
    Client {
        /// URL of the remote endpoint, e.g. `wss://example.com/feed`.
        url: String,
        /// Reconnection settings; defaults to `ReconnectConfig::default()`.
        #[serde(default)]
        reconnect: ReconnectConfig,
        /// Optional buffer capacity used while disconnected; `None` when not configured.
        buffer_on_disconnect: Option<usize>,
        /// Optional batching interval; `None` when not configured.
        batch_interval: Option<Duration>,
        /// Optional maximum batch size; `None` when not configured.
        batch_max_size: Option<usize>,
    },
}
217
218impl Default for SinkMode {
219 fn default() -> Self {
220 Self::Server {
221 bind_address: "0.0.0.0:8080".to_string(),
222 path: None,
223 max_connections: default_max_connections(),
224 per_client_buffer: default_per_client_buffer(),
225 slow_client_policy: SlowClientPolicy::default(),
226 ping_interval: default_ping_interval(),
227 ping_timeout: default_ping_timeout(),
228 enable_subscription_filter: false,
229 replay_buffer_size: None,
230 }
231 }
232}
233
234#[derive(Debug, Clone, Default, Serialize, Deserialize)]
240pub enum SlowClientPolicy {
241 #[default]
243 DropOldest,
244 DropNewest,
246 Disconnect {
249 threshold_pct: u8,
252 },
253 WarnThenDisconnect {
256 warn_pct: u8,
258 disconnect_pct: u8,
261 },
262}
263
/// Serialization format of messages emitted by the sink.
///
/// Serialized by serde as the bare variant name (e.g. `"JsonLines"`). Defaults to `Json`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum SinkFormat {
    /// JSON messages (default).
    #[default]
    Json,
    /// Newline-delimited JSON (`jsonlines` / `json_lines` in `from_config`).
    JsonLines,
    /// Apache Arrow IPC encoding (`arrow_ipc` in `from_config`).
    ArrowIpc,
    /// Binary payloads.
    Binary,
}
281
// Unit tests for the sink configuration types: defaults, serde round-trips, and
// `from_config` parsing of flat connector properties.
#[cfg(test)]
mod tests {
    use super::*;

    // The slow-client policy default is DropOldest.
    #[test]
    fn test_slow_client_policy_default() {
        let policy = SlowClientPolicy::default();
        assert!(matches!(policy, SlowClientPolicy::DropOldest));
    }

    // The sink format default is Json.
    #[test]
    fn test_sink_format_default() {
        let format = SinkFormat::default();
        assert_eq!(format, SinkFormat::Json);
    }

    // SinkMode::default() is a Server populated from the default_* helpers.
    #[test]
    fn test_sink_mode_default_is_server() {
        let mode = SinkMode::default();
        match mode {
            SinkMode::Server {
                bind_address,
                path,
                max_connections,
                per_client_buffer,
                slow_client_policy,
                ping_interval,
                ping_timeout,
                enable_subscription_filter,
                replay_buffer_size,
            } => {
                assert_eq!(bind_address, "0.0.0.0:8080");
                assert!(path.is_none());
                assert_eq!(max_connections, 10_000);
                assert_eq!(per_client_buffer, 262_144);
                assert!(matches!(slow_client_policy, SlowClientPolicy::DropOldest));
                assert_eq!(ping_interval, Duration::from_secs(30));
                assert_eq!(ping_timeout, Duration::from_secs(10));
                assert!(!enable_subscription_filter);
                assert!(replay_buffer_size.is_none());
            }
            SinkMode::Client { .. } => panic!("expected Server, got Client"),
        }
    }

    // The derived WebSocketSinkConfig default: Json, no auth, server mode.
    #[test]
    fn test_websocket_sink_config_default() {
        let config = WebSocketSinkConfig::default();
        assert_eq!(config.format, SinkFormat::Json);
        assert!(config.auth.is_none());
        assert!(matches!(config.mode, SinkMode::Server { .. }));
    }

    // Pin the values of the default_* helpers used by serde and from_config.
    #[test]
    fn test_default_max_connections() {
        assert_eq!(default_max_connections(), 10_000);
    }

    #[test]
    fn test_default_per_client_buffer() {
        assert_eq!(default_per_client_buffer(), 262_144);
    }

    #[test]
    fn test_default_ping_interval() {
        assert_eq!(default_ping_interval(), Duration::from_secs(30));
    }

    #[test]
    fn test_default_ping_timeout() {
        assert_eq!(default_ping_timeout(), Duration::from_secs(10));
    }

    // A fully specified Server mode survives a JSON round-trip.
    #[test]
    fn test_server_mode_serde_roundtrip() {
        let mode = SinkMode::Server {
            bind_address: "127.0.0.1:9090".to_string(),
            path: Some("/ws".to_string()),
            max_connections: 500,
            per_client_buffer: 1024,
            slow_client_policy: SlowClientPolicy::DropNewest,
            ping_interval: Duration::from_secs(15),
            ping_timeout: Duration::from_secs(5),
            enable_subscription_filter: true,
            replay_buffer_size: Some(1000),
        };
        let json = serde_json::to_string(&mode).unwrap();
        let deser: SinkMode = serde_json::from_str(&json).unwrap();
        match deser {
            SinkMode::Server {
                bind_address,
                path,
                max_connections,
                per_client_buffer,
                replay_buffer_size,
                enable_subscription_filter,
                ..
            } => {
                assert_eq!(bind_address, "127.0.0.1:9090");
                assert_eq!(path.as_deref(), Some("/ws"));
                assert_eq!(max_connections, 500);
                assert_eq!(per_client_buffer, 1024);
                assert!(enable_subscription_filter);
                assert_eq!(replay_buffer_size, Some(1000));
            }
            SinkMode::Client { .. } => panic!("expected Server"),
        }
    }

    // A fully specified Client mode survives a JSON round-trip.
    #[test]
    fn test_client_mode_serde_roundtrip() {
        let mode = SinkMode::Client {
            url: "wss://example.com/feed".to_string(),
            reconnect: ReconnectConfig::default(),
            buffer_on_disconnect: Some(5000),
            batch_interval: Some(Duration::from_millis(100)),
            batch_max_size: Some(64),
        };
        let json = serde_json::to_string(&mode).unwrap();
        let deser: SinkMode = serde_json::from_str(&json).unwrap();
        match deser {
            SinkMode::Client {
                url,
                buffer_on_disconnect,
                batch_interval,
                batch_max_size,
                ..
            } => {
                assert_eq!(url, "wss://example.com/feed");
                assert_eq!(buffer_on_disconnect, Some(5000));
                assert_eq!(batch_interval, Some(Duration::from_millis(100)));
                assert_eq!(batch_max_size, Some(64));
            }
            SinkMode::Server { .. } => panic!("expected Client"),
        }
    }

    // SinkFormat serializes as its bare variant name and deserializes back.
    #[test]
    fn test_sink_format_serde_variants() {
        for (format, expected) in [
            (SinkFormat::Json, "\"Json\""),
            (SinkFormat::JsonLines, "\"JsonLines\""),
            (SinkFormat::ArrowIpc, "\"ArrowIpc\""),
            (SinkFormat::Binary, "\"Binary\""),
        ] {
            let json = serde_json::to_string(&format).unwrap();
            assert_eq!(json, expected);
            let deser: SinkFormat = serde_json::from_str(&json).unwrap();
            assert_eq!(deser, format);
        }
    }

    // The struct-like SlowClientPolicy variants keep their fields across a round-trip.
    #[test]
    fn test_slow_client_policy_serde_variants() {
        let disconnect = SlowClientPolicy::Disconnect { threshold_pct: 90 };
        let json = serde_json::to_string(&disconnect).unwrap();
        let deser: SlowClientPolicy = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deser,
            SlowClientPolicy::Disconnect { threshold_pct: 90 }
        ));

        let warn = SlowClientPolicy::WarnThenDisconnect {
            warn_pct: 75,
            disconnect_pct: 95,
        };
        let json = serde_json::to_string(&warn).unwrap();
        let deser: SlowClientPolicy = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deser,
            SlowClientPolicy::WarnThenDisconnect {
                warn_pct: 75,
                disconnect_pct: 95,
            }
        ));
    }

    // A complete WebSocketSinkConfig survives a pretty-printed JSON round-trip.
    #[test]
    fn test_full_config_serde_roundtrip() {
        let config = WebSocketSinkConfig {
            mode: SinkMode::Server {
                bind_address: "0.0.0.0:3000".to_string(),
                path: None,
                max_connections: 2000,
                per_client_buffer: 131_072,
                slow_client_policy: SlowClientPolicy::WarnThenDisconnect {
                    warn_pct: 80,
                    disconnect_pct: 95,
                },
                ping_interval: Duration::from_secs(20),
                ping_timeout: Duration::from_secs(8),
                enable_subscription_filter: false,
                replay_buffer_size: Some(500),
            },
            format: SinkFormat::ArrowIpc,
            auth: None,
        };

        let json = serde_json::to_string_pretty(&config).unwrap();
        let deser: WebSocketSinkConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deser.format, SinkFormat::ArrowIpc);
        assert!(deser.auth.is_none());
        match deser.mode {
            SinkMode::Server {
                max_connections,
                replay_buffer_size,
                ..
            } => {
                assert_eq!(max_connections, 2000);
                assert_eq!(replay_buffer_size, Some(500));
            }
            SinkMode::Client { .. } => panic!("expected Server"),
        }
    }

    // Omitted Server fields pick up their serde defaults (and None for Options).
    #[test]
    fn test_server_mode_serde_defaults_applied() {
        let json = r#"{
            "type": "Server",
            "bind_address": "0.0.0.0:4000"
        }"#;
        let mode: SinkMode = serde_json::from_str(json).unwrap();
        match mode {
            SinkMode::Server {
                max_connections,
                per_client_buffer,
                slow_client_policy,
                ping_interval,
                ping_timeout,
                enable_subscription_filter,
                replay_buffer_size,
                ..
            } => {
                assert_eq!(max_connections, 10_000);
                assert_eq!(per_client_buffer, 262_144);
                assert!(matches!(slow_client_policy, SlowClientPolicy::DropOldest));
                assert_eq!(ping_interval, Duration::from_secs(30));
                assert_eq!(ping_timeout, Duration::from_secs(10));
                assert!(!enable_subscription_filter);
                assert!(replay_buffer_size.is_none());
            }
            SinkMode::Client { .. } => panic!("expected Server"),
        }
    }

    // from_config defaults to server mode and honours explicit keys.
    #[test]
    fn test_from_config_server_mode() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "0.0.0.0:3000");
        config.set("max.connections", "2000");
        config.set("format", "jsonlines");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        assert_eq!(cfg.format, SinkFormat::JsonLines);
        match cfg.mode {
            SinkMode::Server {
                bind_address,
                max_connections,
                ..
            } => {
                assert_eq!(bind_address, "0.0.0.0:3000");
                assert_eq!(max_connections, 2000);
            }
            SinkMode::Client { .. } => panic!("expected Server mode"),
        }
    }

    // from_config builds a Client mode when mode=client.
    #[test]
    fn test_from_config_client_mode() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("mode", "client");
        config.set("url", "wss://upstream.example.com/feed");
        config.set("format", "json");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        assert_eq!(cfg.format, SinkFormat::Json);
        match cfg.mode {
            SinkMode::Client { url, .. } => {
                assert_eq!(url, "wss://upstream.example.com/feed");
            }
            SinkMode::Server { .. } => panic!("expected Client mode"),
        }
    }

    // Server mode requires bind.address; the error names the missing key.
    #[test]
    fn test_from_config_missing_bind_address_errors() {
        let config = ConnectorConfig::new("websocket");
        let result = WebSocketSinkConfig::from_config(&config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("bind.address"));
    }

    // auth.type=bearer reads the token from auth.token.
    #[test]
    fn test_from_config_bearer_auth() {
        let mut config = ConnectorConfig::new("websocket");
        config.set("bind.address", "0.0.0.0:8080");
        config.set("auth.type", "bearer");
        config.set("auth.token", "my-secret");

        let cfg = WebSocketSinkConfig::from_config(&config).unwrap();
        assert!(matches!(
            cfg.auth,
            Some(WsAuthConfig::Bearer { ref token }) if token == "my-secret"
        ));
    }
}