Skip to main content

laminar_connectors/websocket/
sink.rs

1//! WebSocket sink connector — server mode.
2//!
3//! [`WebSocketSinkServer`] hosts a WebSocket endpoint that connected clients
4//! subscribe to for streaming query results. Implements per-client isolation
5//! via the [`FanoutManager`].
6
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
10use arrow_array::RecordBatch;
11use arrow_schema::SchemaRef;
12use async_trait::async_trait;
13use bytes::Bytes;
14use futures_util::{SinkExt, StreamExt};
15use tokio::net::TcpListener;
16use tracing::{debug, info, warn};
17
18use crate::config::{ConnectorConfig, ConnectorState};
19use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
20use crate::error::ConnectorError;
21use crate::health::HealthStatus;
22use crate::metrics::ConnectorMetrics;
23
24use super::fanout::FanoutManager;
25use super::protocol::{ClientMessage, ServerMessage};
26use super::serializer::BatchSerializer;
27use super::sink_config::{SinkMode, SlowClientPolicy, WebSocketSinkConfig};
28use super::sink_metrics::WebSocketSinkMetrics;
29
30/// WebSocket sink connector in server mode.
31///
32/// Hosts a WebSocket server. Connected clients subscribe and receive
33/// streaming query results via the fan-out manager.
34pub struct WebSocketSinkServer {
35    /// Configuration.
36    config: WebSocketSinkConfig,
37    /// Input Arrow schema.
38    schema: SchemaRef,
39    /// Serializer for `RecordBatch` → JSON/Binary.
40    serializer: BatchSerializer,
41    /// Fan-out manager for per-client message distribution.
42    fanout: Arc<FanoutManager>,
43    /// Connector state.
44    state: ConnectorState,
45    /// Metrics.
46    metrics: Arc<WebSocketSinkMetrics>,
47    /// Current epoch.
48    current_epoch: u64,
49    /// Shutdown signal sender.
50    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
51    /// Acceptor task handle.
52    acceptor_handle: Option<tokio::task::JoinHandle<()>>,
53    /// Global sequence counter (shared with fanout).
54    sequence: Arc<AtomicU64>,
55}
56
57impl WebSocketSinkServer {
58    /// Creates a new WebSocket sink server connector.
59    #[must_use]
60    pub fn new(schema: SchemaRef, config: WebSocketSinkConfig) -> Self {
61        let serializer = BatchSerializer::new(config.format.clone());
62
63        let (buffer_capacity, policy, replay_size) = match &config.mode {
64            SinkMode::Server {
65                per_client_buffer,
66                slow_client_policy,
67                replay_buffer_size,
68                ..
69            } => {
70                // Convert bytes to approximate message count (assume ~256 bytes/msg).
71                let msg_capacity = (*per_client_buffer / 256).max(1);
72                (
73                    msg_capacity,
74                    slow_client_policy.clone(),
75                    *replay_buffer_size,
76                )
77            }
78            SinkMode::Client { .. } => (1024, SlowClientPolicy::DropOldest, None),
79        };
80
81        let fanout = Arc::new(FanoutManager::new(policy, buffer_capacity, replay_size));
82
83        Self {
84            config,
85            schema,
86            serializer,
87            fanout,
88            state: ConnectorState::Created,
89            metrics: Arc::new(WebSocketSinkMetrics::new()),
90            current_epoch: 0,
91            shutdown_tx: None,
92            acceptor_handle: None,
93            sequence: Arc::new(AtomicU64::new(0)),
94        }
95    }
96
97    /// Returns the current connector state.
98    #[must_use]
99    pub fn state(&self) -> ConnectorState {
100        self.state
101    }
102
103    /// Returns the number of connected clients.
104    #[must_use]
105    pub fn connected_clients(&self) -> usize {
106        self.fanout.client_count()
107    }
108
109    /// Returns a reference to the fan-out manager.
110    #[must_use]
111    pub fn fanout(&self) -> &Arc<FanoutManager> {
112        &self.fanout
113    }
114}
115
116#[async_trait]
117#[allow(clippy::too_many_lines)]
118impl SinkConnector for WebSocketSinkServer {
119    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
120        self.state = ConnectorState::Initializing;
121
122        // If config has properties, re-parse (supports runtime config via SQL WITH).
123        if !config.properties().is_empty() {
124            self.config = WebSocketSinkConfig::from_config(config)?;
125        }
126
127        let (bind_address, max_connections, _path, ping_interval, ping_timeout) = match &self
128            .config
129            .mode
130        {
131            SinkMode::Server {
132                bind_address,
133                max_connections,
134                path,
135                ping_interval,
136                ping_timeout,
137                ..
138            } => (
139                bind_address.clone(),
140                *max_connections,
141                path.clone(),
142                *ping_interval,
143                *ping_timeout,
144            ),
145            SinkMode::Client { .. } => {
146                return Err(ConnectorError::ConfigurationError(
147                        "WebSocketSinkServer is for server mode; use WebSocketSinkClient for client mode".into(),
148                    ));
149            }
150        };
151
152        info!(
153            bind = %bind_address,
154            max_connections,
155            format = ?self.config.format,
156            "opening WebSocket sink server"
157        );
158
159        let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
160            ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
161        })?;
162
163        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
164        let fanout = Arc::clone(&self.fanout);
165        let metrics = Arc::clone(&self.metrics);
166
167        let handle = tokio::spawn(async move {
168            let mut shutdown_rx = shutdown_rx;
169
170            loop {
171                tokio::select! {
172                    accept_result = listener.accept() => {
173                        match accept_result {
174                            Ok((stream, addr)) => {
175                                if fanout.client_count() >= max_connections {
176                                    warn!(addr = %addr, "rejecting: max_connections exceeded");
177                                    drop(stream);
178                                    continue;
179                                }
180
181                                let _ = stream.set_nodelay(true);
182                                let fanout = Arc::clone(&fanout);
183                                let metrics = metrics.clone();
184                                let mut client_shutdown = shutdown_rx.clone();
185                                let client_ping_interval = ping_interval;
186                                let client_ping_timeout = ping_timeout;
187
188                                tokio::spawn(async move {
189                                    // Cap incoming frames at 1 MiB — clients only send
190                                    // small control messages (subscribe/unsubscribe/ping).
191                                    let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
192                                    ws_config.max_message_size = Some(1024 * 1024);
193                                    ws_config.max_frame_size = Some(1024 * 1024);
194                                    let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
195                                        Ok(ws) => ws,
196                                        Err(e) => {
197                                            warn!(addr = %addr, error = %e, "handshake failed");
198                                            return;
199                                        }
200                                    };
201
202                                    let (mut write, mut read) = ws_stream.split();
203
204                                    // Wait for a subscribe message or auto-subscribe.
205                                    let sub_id = format!("sub_{}", addr.port());
206                                    let filter = None;
207
208                                    // Check for initial subscribe message (with timeout).
209                                    let (filter, last_seq) = match tokio::time::timeout(
210                                        std::time::Duration::from_secs(5),
211                                        read.next(),
212                                    )
213                                    .await
214                                    {
215                                        Ok(Some(Ok(tungstenite::Message::Text(text)))) => {
216                                            match serde_json::from_str::<ClientMessage>(text.as_ref()) {
217                                                Ok(ClientMessage::Subscribe {
218                                                    filter,
219                                                    last_sequence,
220                                                    ..
221                                                }) => (filter, last_sequence),
222                                                _ => (None, None),
223                                            }
224                                        }
225                                        Ok(Some(Err(e))) => {
226                                            warn!(addr = %addr, error = %e, "client read error during subscribe, rejecting");
227                                            return;
228                                        }
229                                        _ => (filter, None),
230                                    };
231
232                                    // Register the client.
233                                    let (client_id, rx) =
234                                        fanout.add_client(sub_id.clone(), filter, None);
235
236                                    metrics.record_connect();
237
238                                    // Send subscription confirmation.
239                                    let confirm = ServerMessage::Subscribed {
240                                        subscription_id: sub_id.clone(),
241                                    };
242                                    if let Ok(json) = serde_json::to_string(&confirm) {
243                                        let _ = write
244                                            .send(tungstenite::Message::Text(json.into()))
245                                            .await;
246                                    }
247
248                                    // Replay if requested.
249                                    if let Some(seq) = last_seq {
250                                        let replay_msgs = fanout.replay_from(seq);
251                                        metrics.record_replay();
252                                        for (_seq, data) in replay_msgs {
253                                            if write
254                                                .send(tungstenite::Message::Text(
255                                                    String::from_utf8_lossy(&data).into_owned().into(),
256                                                ))
257                                                .await
258                                                .is_err()
259                                            {
260                                                break;
261                                            }
262                                        }
263                                    }
264
265                                    // Fan-out loop with ping/pong heartbeats.
266                                    let mut ping_ticker = tokio::time::interval(client_ping_interval);
267                                    ping_ticker.tick().await; // consume initial immediate tick
268                                    let mut awaiting_pong = false;
269                                    let mut last_ping_sent = tokio::time::Instant::now();
270
271                                    loop {
272                                        tokio::select! {
273                                            Some(data) = rx.recv() => {
274                                                let text = String::from_utf8_lossy(&data).into_owned();
275                                                if write.send(tungstenite::Message::Text(text.into())).await.is_err() {
276                                                    break;
277                                                }
278                                                metrics.record_send(data.len() as u64);
279                                            }
280                                            msg = read.next() => {
281                                                match msg {
282                                                    Some(Ok(tungstenite::Message::Close(_))) | None => break,
283                                                    Some(Ok(tungstenite::Message::Pong(_))) => {
284                                                        awaiting_pong = false;
285                                                    }
286                                                    Some(Ok(tungstenite::Message::Text(text))) => {
287                                                        if let Ok(ClientMessage::Unsubscribe { .. }) =
288                                                            serde_json::from_str::<ClientMessage>(text.as_ref())
289                                                        {
290                                                            break;
291                                                        }
292                                                    }
293                                                    Some(Err(e)) => {
294                                                        warn!(addr = %addr, error = %e, "client read error, disconnecting");
295                                                        break;
296                                                    }
297                                                    _ => {}
298                                                }
299                                            }
300                                            _ = ping_ticker.tick() => {
301                                                if awaiting_pong && last_ping_sent.elapsed() > client_ping_timeout {
302                                                    debug!(addr = %addr, "ping timeout — disconnecting");
303                                                    metrics.record_ping_timeout();
304                                                    break;
305                                                }
306                                                if write.send(tungstenite::Message::Ping(bytes::Bytes::new())).await.is_err() {
307                                                    break;
308                                                }
309                                                awaiting_pong = true;
310                                                last_ping_sent = tokio::time::Instant::now();
311                                            }
312                                            _ = client_shutdown.changed() => break,
313                                        }
314                                    }
315
316                                    fanout.remove_client(client_id);
317                                    metrics.record_disconnect();
318                                    debug!(addr = %addr, "sink client disconnected");
319                                });
320                            }
321                            Err(e) => {
322                                warn!(error = %e, "accept error");
323                            }
324                        }
325                    }
326                    _ = shutdown_rx.changed() => {
327                        info!("sink server acceptor shutting down");
328                        break;
329                    }
330                }
331            }
332        });
333
334        self.shutdown_tx = Some(shutdown_tx);
335        self.acceptor_handle = Some(handle);
336        self.state = ConnectorState::Running;
337
338        info!(bind = %bind_address, "WebSocket sink server started");
339        Ok(())
340    }
341
342    #[allow(clippy::cast_possible_truncation)]
343    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
344        if self.state != ConnectorState::Running {
345            return Err(ConnectorError::InvalidState {
346                expected: "Running".into(),
347                actual: self.state.to_string(),
348            });
349        }
350
351        if self.fanout.client_count() == 0 {
352            // No clients connected — discard.
353            return Ok(WriteResult::new(0, 0));
354        }
355
356        // Serialize the batch to JSON.
357        let json = self.serializer.serialize_to_json(batch)?;
358        let seq = self.sequence.fetch_add(1, Ordering::Relaxed) + 1;
359
360        let msg = ServerMessage::Data {
361            subscription_id: String::new(), // broadcast to all
362            data: json,
363            sequence: seq,
364            watermark: None,
365        };
366
367        let serialized = serde_json::to_vec(&msg)
368            .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;
369
370        let bytes_len = serialized.len() as u64;
371        let data = Bytes::from(serialized);
372
373        let result = self.fanout.broadcast(data);
374
375        self.metrics.record_send(bytes_len);
376        if result.dropped > 0 {
377            for _ in 0..result.dropped {
378                self.metrics.record_drop();
379            }
380        }
381
382        debug!(
383            records = batch.num_rows(),
384            sent = result.sent,
385            dropped = result.dropped,
386            sequence = result.sequence,
387            "broadcast batch to WebSocket clients"
388        );
389
390        Ok(WriteResult::new(batch.num_rows(), bytes_len))
391    }
392
393    fn schema(&self) -> SchemaRef {
394        self.schema.clone()
395    }
396
397    fn health_check(&self) -> HealthStatus {
398        match self.state {
399            ConnectorState::Running => HealthStatus::Healthy,
400            ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
401            ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
402            ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
403            ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
404            ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
405        }
406    }
407
408    fn metrics(&self) -> ConnectorMetrics {
409        self.metrics.to_connector_metrics()
410    }
411
412    fn capabilities(&self) -> SinkConnectorCapabilities {
413        SinkConnectorCapabilities::default()
414        // WebSocket sink: no exactly-once, no upsert, no changelog
415    }
416
417    async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
418        self.current_epoch = epoch;
419        Ok(())
420    }
421
422    async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
423        Ok(())
424    }
425
426    async fn close(&mut self) -> Result<(), ConnectorError> {
427        info!("closing WebSocket sink server");
428
429        if let Some(tx) = self.shutdown_tx.take() {
430            let _ = tx.send(true);
431        }
432
433        if let Some(handle) = self.acceptor_handle.take() {
434            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
435        }
436
437        self.state = ConnectorState::Closed;
438        info!("WebSocket sink server closed");
439        Ok(())
440    }
441}
442
443impl std::fmt::Debug for WebSocketSinkServer {
444    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445        f.debug_struct("WebSocketSinkServer")
446            .field("state", &self.state)
447            .field("connected_clients", &self.connected_clients())
448            .field("format", &self.config.format)
449            .field("current_epoch", &self.current_epoch)
450            .finish_non_exhaustive()
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::super::sink_config::SinkFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Two-column schema (`id: Int64`, `value: Utf8`, both non-nullable).
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Server-mode JSON configuration bound to an ephemeral localhost port.
    fn test_config() -> WebSocketSinkConfig {
        let mode = SinkMode::Server {
            bind_address: "127.0.0.1:0".into(),
            path: None,
            max_connections: 100,
            per_client_buffer: 262_144,
            slow_client_policy: SlowClientPolicy::DropOldest,
            ping_interval: std::time::Duration::from_secs(30),
            ping_timeout: std::time::Duration::from_secs(10),
            enable_subscription_filter: false,
            replay_buffer_size: None,
        };

        WebSocketSinkConfig {
            mode,
            format: SinkFormat::Json,
            auth: None,
        }
    }

    /// Fresh, unopened sink built from the shared schema and configuration.
    fn new_sink() -> WebSocketSinkServer {
        WebSocketSinkServer::new(test_schema(), test_config())
    }

    #[test]
    fn test_new() {
        let server = new_sink();
        assert_eq!(server.state(), ConnectorState::Created);
        assert_eq!(server.connected_clients(), 0);
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let server = WebSocketSinkServer::new(expected.clone(), test_config());
        assert_eq!(server.schema(), expected);
    }

    #[test]
    fn test_capabilities() {
        let capabilities = new_sink().capabilities();
        assert!(!capabilities.exactly_once);
        assert!(!capabilities.upsert);
    }

    #[test]
    fn test_health_created() {
        assert_eq!(new_sink().health_check(), HealthStatus::Unknown);
    }

    #[test]
    fn test_metrics_initial() {
        let snapshot = new_sink().metrics();
        assert_eq!(snapshot.records_total, 0);
    }
}