Skip to main content

laminar_connectors/websocket/
sink.rs

1//! WebSocket sink connector — server mode.
2//!
3//! [`WebSocketSinkServer`] hosts a WebSocket endpoint that connected clients
4//! subscribe to for streaming query results. Implements per-client isolation
5//! via the [`FanoutManager`].
6
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use bytes::Bytes;
15use futures_util::{SinkExt, StreamExt};
16use tokio::net::TcpListener;
17use tracing::{debug, info, warn};
18
19use crate::config::{ConnectorConfig, ConnectorState};
20use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
21use crate::error::ConnectorError;
22use crate::health::HealthStatus;
23use crate::metrics::ConnectorMetrics;
24
25use super::fanout::FanoutManager;
26use super::protocol::{ClientMessage, ServerMessage};
27use super::serializer::BatchSerializer;
28use super::sink_config::{SinkMode, SlowClientPolicy, WebSocketSinkConfig};
29use super::sink_metrics::WebSocketSinkMetrics;
30
31/// WebSocket sink connector in server mode.
32///
33/// Hosts a WebSocket server. Connected clients subscribe and receive
34/// streaming query results via the fan-out manager.
35pub struct WebSocketSinkServer {
36    /// Configuration.
37    config: WebSocketSinkConfig,
38    /// Input Arrow schema.
39    schema: SchemaRef,
40    /// Serializer for `RecordBatch` → JSON/Binary.
41    serializer: BatchSerializer,
42    /// Fan-out manager for per-client message distribution.
43    fanout: Arc<FanoutManager>,
44    /// Connector state.
45    state: ConnectorState,
46    /// Metrics.
47    metrics: Arc<WebSocketSinkMetrics>,
48    /// Current epoch.
49    current_epoch: u64,
50    /// Shutdown signal sender.
51    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
52    /// Acceptor task handle.
53    acceptor_handle: Option<tokio::task::JoinHandle<()>>,
54    /// Global sequence counter (shared with fanout).
55    sequence: Arc<AtomicU64>,
56}
57
58impl WebSocketSinkServer {
59    /// Creates a new WebSocket sink server connector.
60    #[must_use]
61    pub fn new(
62        schema: SchemaRef,
63        config: WebSocketSinkConfig,
64        registry: Option<&prometheus::Registry>,
65    ) -> Self {
66        let serializer = BatchSerializer::new(config.format.clone());
67
68        let (buffer_capacity, policy, replay_size) = match &config.mode {
69            SinkMode::Server {
70                per_client_buffer,
71                slow_client_policy,
72                replay_buffer_size,
73                ..
74            } => {
75                // Convert bytes to approximate message count (assume ~256 bytes/msg).
76                let msg_capacity = (*per_client_buffer / 256).max(1);
77                (
78                    msg_capacity,
79                    slow_client_policy.clone(),
80                    *replay_buffer_size,
81                )
82            }
83            SinkMode::Client { .. } => (1024, SlowClientPolicy::DropOldest, None),
84        };
85
86        let fanout = Arc::new(FanoutManager::new(policy, buffer_capacity, replay_size));
87
88        Self {
89            config,
90            schema,
91            serializer,
92            fanout,
93            state: ConnectorState::Created,
94            metrics: Arc::new(WebSocketSinkMetrics::new(registry)),
95            current_epoch: 0,
96            shutdown_tx: None,
97            acceptor_handle: None,
98            sequence: Arc::new(AtomicU64::new(0)),
99        }
100    }
101
102    /// Returns the current connector state.
103    #[must_use]
104    pub fn state(&self) -> ConnectorState {
105        self.state
106    }
107
108    /// Returns the number of connected clients.
109    #[must_use]
110    pub fn connected_clients(&self) -> usize {
111        self.fanout.client_count()
112    }
113
114    /// Returns a reference to the fan-out manager.
115    #[must_use]
116    pub fn fanout(&self) -> &Arc<FanoutManager> {
117        &self.fanout
118    }
119}
120
121#[async_trait]
122#[allow(clippy::too_many_lines)]
123impl SinkConnector for WebSocketSinkServer {
124    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
125        self.state = ConnectorState::Initializing;
126
127        // If config has properties, re-parse (supports runtime config via SQL WITH).
128        if !config.properties().is_empty() {
129            self.config = WebSocketSinkConfig::from_config(config)?;
130        }
131
132        let (bind_address, max_connections, _path, ping_interval, ping_timeout) = match &self
133            .config
134            .mode
135        {
136            SinkMode::Server {
137                bind_address,
138                max_connections,
139                path,
140                ping_interval,
141                ping_timeout,
142                ..
143            } => (
144                bind_address.clone(),
145                *max_connections,
146                path.clone(),
147                *ping_interval,
148                *ping_timeout,
149            ),
150            SinkMode::Client { .. } => {
151                return Err(ConnectorError::ConfigurationError(
152                        "WebSocketSinkServer is for server mode; use WebSocketSinkClient for client mode".into(),
153                    ));
154            }
155        };
156
157        info!(
158            bind = %bind_address,
159            max_connections,
160            format = ?self.config.format,
161            "opening WebSocket sink server"
162        );
163
164        let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
165            ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
166        })?;
167
168        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
169        let fanout = Arc::clone(&self.fanout);
170        let metrics = Arc::clone(&self.metrics);
171
172        let handle = tokio::spawn(async move {
173            let mut shutdown_rx = shutdown_rx;
174
175            loop {
176                tokio::select! {
177                    accept_result = listener.accept() => {
178                        match accept_result {
179                            Ok((stream, addr)) => {
180                                if fanout.client_count() >= max_connections {
181                                    warn!(addr = %addr, "rejecting: max_connections exceeded");
182                                    drop(stream);
183                                    continue;
184                                }
185
186                                let _ = stream.set_nodelay(true);
187                                let fanout = Arc::clone(&fanout);
188                                let metrics = metrics.clone();
189                                let mut client_shutdown = shutdown_rx.clone();
190                                let client_ping_interval = ping_interval;
191                                let client_ping_timeout = ping_timeout;
192
193                                tokio::spawn(async move {
194                                    // Cap incoming frames at 1 MiB — clients only send
195                                    // small control messages (subscribe/unsubscribe/ping).
196                                    let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
197                                    ws_config.max_message_size = Some(1024 * 1024);
198                                    ws_config.max_frame_size = Some(1024 * 1024);
199                                    let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
200                                        Ok(ws) => ws,
201                                        Err(e) => {
202                                            warn!(addr = %addr, error = %e, "handshake failed");
203                                            return;
204                                        }
205                                    };
206
207                                    let (mut write, mut read) = ws_stream.split();
208
209                                    // Wait for a subscribe message or auto-subscribe.
210                                    let sub_id = format!("sub_{}", addr.port());
211                                    let filter = None;
212
213                                    // Check for initial subscribe message (with timeout).
214                                    let (filter, last_seq) = match tokio::time::timeout(
215                                        std::time::Duration::from_secs(5),
216                                        read.next(),
217                                    )
218                                    .await
219                                    {
220                                        Ok(Some(Ok(tungstenite::Message::Text(text)))) => {
221                                            match serde_json::from_str::<ClientMessage>(text.as_ref()) {
222                                                Ok(ClientMessage::Subscribe {
223                                                    filter,
224                                                    last_sequence,
225                                                    ..
226                                                }) => (filter, last_sequence),
227                                                _ => (None, None),
228                                            }
229                                        }
230                                        Ok(Some(Err(e))) => {
231                                            warn!(addr = %addr, error = %e, "client read error during subscribe, rejecting");
232                                            return;
233                                        }
234                                        _ => (filter, None),
235                                    };
236
237                                    // Register the client.
238                                    let (client_id, rx) =
239                                        fanout.add_client(sub_id.clone(), filter, None);
240
241                                    metrics.record_connect();
242
243                                    // Send subscription confirmation.
244                                    let confirm = ServerMessage::Subscribed {
245                                        subscription_id: sub_id.clone(),
246                                    };
247                                    if let Ok(json) = serde_json::to_string(&confirm) {
248                                        let _ = write
249                                            .send(tungstenite::Message::Text(json.into()))
250                                            .await;
251                                    }
252
253                                    // Replay if requested.
254                                    if let Some(seq) = last_seq {
255                                        let replay_msgs = fanout.replay_from(seq);
256                                        metrics.record_replay();
257                                        for (_seq, data) in replay_msgs {
258                                            if write
259                                                .send(tungstenite::Message::Text(
260                                                    String::from_utf8_lossy(&data).into_owned().into(),
261                                                ))
262                                                .await
263                                                .is_err()
264                                            {
265                                                break;
266                                            }
267                                        }
268                                    }
269
270                                    // Fan-out loop with ping/pong heartbeats.
271                                    let mut ping_ticker = tokio::time::interval(client_ping_interval);
272                                    ping_ticker.tick().await; // consume initial immediate tick
273                                    let mut awaiting_pong = false;
274                                    let mut last_ping_sent = tokio::time::Instant::now();
275
276                                    loop {
277                                        tokio::select! {
278                                            Some(data) = rx.recv() => {
279                                                let text = String::from_utf8_lossy(&data).into_owned();
280                                                if write.send(tungstenite::Message::Text(text.into())).await.is_err() {
281                                                    break;
282                                                }
283                                                metrics.record_send(data.len() as u64);
284                                            }
285                                            msg = read.next() => {
286                                                match msg {
287                                                    Some(Ok(tungstenite::Message::Close(_))) | None => break,
288                                                    Some(Ok(tungstenite::Message::Pong(_))) => {
289                                                        awaiting_pong = false;
290                                                    }
291                                                    Some(Ok(tungstenite::Message::Text(text))) => {
292                                                        if let Ok(ClientMessage::Unsubscribe { .. }) =
293                                                            serde_json::from_str::<ClientMessage>(text.as_ref())
294                                                        {
295                                                            break;
296                                                        }
297                                                    }
298                                                    Some(Err(e)) => {
299                                                        warn!(addr = %addr, error = %e, "client read error, disconnecting");
300                                                        break;
301                                                    }
302                                                    _ => {}
303                                                }
304                                            }
305                                            _ = ping_ticker.tick() => {
306                                                if awaiting_pong && last_ping_sent.elapsed() > client_ping_timeout {
307                                                    debug!(addr = %addr, "ping timeout — disconnecting");
308                                                    metrics.record_ping_timeout();
309                                                    break;
310                                                }
311                                                if write.send(tungstenite::Message::Ping(bytes::Bytes::new())).await.is_err() {
312                                                    break;
313                                                }
314                                                awaiting_pong = true;
315                                                last_ping_sent = tokio::time::Instant::now();
316                                            }
317                                            _ = client_shutdown.changed() => break,
318                                        }
319                                    }
320
321                                    fanout.remove_client(client_id);
322                                    metrics.record_disconnect();
323                                    debug!(addr = %addr, "sink client disconnected");
324                                });
325                            }
326                            Err(e) => {
327                                warn!(error = %e, "accept error");
328                            }
329                        }
330                    }
331                    _ = shutdown_rx.changed() => {
332                        info!("sink server acceptor shutting down");
333                        break;
334                    }
335                }
336            }
337        });
338
339        self.shutdown_tx = Some(shutdown_tx);
340        self.acceptor_handle = Some(handle);
341        self.state = ConnectorState::Running;
342
343        info!(bind = %bind_address, "WebSocket sink server started");
344        Ok(())
345    }
346
347    #[allow(clippy::cast_possible_truncation)]
348    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
349        if self.state != ConnectorState::Running {
350            return Err(ConnectorError::InvalidState {
351                expected: "Running".into(),
352                actual: self.state.to_string(),
353            });
354        }
355
356        if self.fanout.client_count() == 0 {
357            // No clients connected — discard.
358            return Ok(WriteResult::new(0, 0));
359        }
360
361        // Serialize the batch to JSON.
362        let json = self.serializer.serialize_to_json(batch)?;
363        let seq = self.sequence.fetch_add(1, Ordering::Relaxed) + 1;
364
365        let msg = ServerMessage::Data {
366            subscription_id: String::new(), // broadcast to all
367            data: json,
368            sequence: seq,
369            watermark: None,
370        };
371
372        let serialized = serde_json::to_vec(&msg)
373            .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;
374
375        let bytes_len = serialized.len() as u64;
376        let data = Bytes::from(serialized);
377
378        let result = self.fanout.broadcast(data);
379
380        self.metrics.record_send(bytes_len);
381        if result.dropped > 0 {
382            for _ in 0..result.dropped {
383                self.metrics.record_drop();
384            }
385        }
386
387        debug!(
388            records = batch.num_rows(),
389            sent = result.sent,
390            dropped = result.dropped,
391            sequence = result.sequence,
392            "broadcast batch to WebSocket clients"
393        );
394
395        Ok(WriteResult::new(batch.num_rows(), bytes_len))
396    }
397
398    fn schema(&self) -> SchemaRef {
399        self.schema.clone()
400    }
401
402    fn health_check(&self) -> HealthStatus {
403        match self.state {
404            ConnectorState::Running => HealthStatus::Healthy,
405            ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
406            ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
407            ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
408            ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
409            ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
410        }
411    }
412
413    fn metrics(&self) -> ConnectorMetrics {
414        self.metrics.to_connector_metrics()
415    }
416
417    fn capabilities(&self) -> SinkConnectorCapabilities {
418        // In-memory fanout — timeout is effectively unreachable.
419        SinkConnectorCapabilities::new(Duration::from_secs(10))
420    }
421
422    async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
423        self.current_epoch = epoch;
424        Ok(())
425    }
426
427    async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
428        Ok(())
429    }
430
431    async fn close(&mut self) -> Result<(), ConnectorError> {
432        info!("closing WebSocket sink server");
433
434        if let Some(tx) = self.shutdown_tx.take() {
435            let _ = tx.send(true);
436        }
437
438        if let Some(handle) = self.acceptor_handle.take() {
439            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
440        }
441
442        self.state = ConnectorState::Closed;
443        info!("WebSocket sink server closed");
444        Ok(())
445    }
446}
447
448impl std::fmt::Debug for WebSocketSinkServer {
449    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450        f.debug_struct("WebSocketSinkServer")
451            .field("state", &self.state)
452            .field("connected_clients", &self.connected_clients())
453            .field("format", &self.config.format)
454            .field("current_epoch", &self.current_epoch)
455            .finish_non_exhaustive()
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::super::sink_config::SinkFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Two non-nullable columns: `id: Int64`, `value: Utf8`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Server-mode JSON config on an ephemeral loopback port, no auth, no replay.
    fn test_config() -> WebSocketSinkConfig {
        let mode = SinkMode::Server {
            bind_address: "127.0.0.1:0".into(),
            path: None,
            max_connections: 100,
            per_client_buffer: 262_144,
            slow_client_policy: SlowClientPolicy::DropOldest,
            ping_interval: std::time::Duration::from_secs(30),
            ping_timeout: std::time::Duration::from_secs(10),
            enable_subscription_filter: false,
            replay_buffer_size: None,
        };
        WebSocketSinkConfig {
            mode,
            format: SinkFormat::Json,
            auth: None,
        }
    }

    /// Sink built from the default test schema/config without a metrics registry.
    fn fresh_sink() -> WebSocketSinkServer {
        WebSocketSinkServer::new(test_schema(), test_config(), None)
    }

    #[test]
    fn test_new() {
        let sink = fresh_sink();
        assert_eq!(sink.state(), ConnectorState::Created);
        assert_eq!(sink.connected_clients(), 0);
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let sink = WebSocketSinkServer::new(Arc::clone(&expected), test_config(), None);
        assert_eq!(sink.schema(), expected);
    }

    #[test]
    fn test_capabilities() {
        let caps = fresh_sink().capabilities();
        assert!(!caps.exactly_once);
        assert!(!caps.upsert);
    }

    #[test]
    fn test_health_created() {
        assert_eq!(fresh_sink().health_check(), HealthStatus::Unknown);
    }

    #[test]
    fn test_metrics_initial() {
        let snapshot = fresh_sink().metrics();
        assert_eq!(snapshot.records_total, 0);
    }
}