Skip to main content

laminar_connectors/websocket/
sink.rs

//! WebSocket sink connector — server mode.
//!
//! [`WebSocketSinkServer`] hosts a WebSocket endpoint that connected clients
//! subscribe to for streaming query results. Implements per-client isolation
//! via the [`FanoutManager`].
6
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use bytes::Bytes;
15use futures_util::{SinkExt, StreamExt};
16use tokio::net::TcpListener;
17use tracing::{debug, info, warn};
18
19use crate::config::{ConnectorConfig, ConnectorState};
20use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
21use crate::error::ConnectorError;
22
23use super::fanout::FanoutManager;
24use super::protocol::{ClientMessage, ServerMessage};
25use super::serializer::BatchSerializer;
26use super::sink_config::{SinkMode, SlowClientPolicy, WebSocketSinkConfig};
27use super::sink_metrics::WebSocketSinkMetrics;
28
29/// WebSocket sink connector in server mode.
30///
31/// Hosts a WebSocket server. Connected clients subscribe and receive
32/// streaming query results via the fan-out manager.
33pub struct WebSocketSinkServer {
34    /// Configuration.
35    config: WebSocketSinkConfig,
36    /// Input Arrow schema.
37    schema: SchemaRef,
38    /// Serializer for `RecordBatch` → JSON/Binary.
39    serializer: BatchSerializer,
40    /// Fan-out manager for per-client message distribution.
41    fanout: Arc<FanoutManager>,
42    /// Connector state.
43    state: ConnectorState,
44    /// Metrics.
45    metrics: Arc<WebSocketSinkMetrics>,
46    /// Current epoch.
47    current_epoch: u64,
48    /// Shutdown signal sender.
49    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
50    /// Acceptor task handle.
51    acceptor_handle: Option<tokio::task::JoinHandle<()>>,
52    /// Global sequence counter (shared with fanout).
53    sequence: Arc<AtomicU64>,
54}
55
56impl WebSocketSinkServer {
57    /// Creates a new WebSocket sink server connector.
58    #[must_use]
59    pub fn new(
60        schema: SchemaRef,
61        config: WebSocketSinkConfig,
62        registry: Option<&prometheus::Registry>,
63    ) -> Self {
64        let serializer = BatchSerializer::new(config.format.clone());
65
66        let (buffer_capacity, policy, replay_size) = match &config.mode {
67            SinkMode::Server {
68                per_client_buffer,
69                slow_client_policy,
70                replay_buffer_size,
71                ..
72            } => {
73                // Convert bytes to approximate message count (assume ~256 bytes/msg).
74                let msg_capacity = (*per_client_buffer / 256).max(1);
75                (
76                    msg_capacity,
77                    slow_client_policy.clone(),
78                    *replay_buffer_size,
79                )
80            }
81            SinkMode::Client { .. } => (1024, SlowClientPolicy::DropOldest, None),
82        };
83
84        let fanout = Arc::new(FanoutManager::new(policy, buffer_capacity, replay_size));
85
86        Self {
87            config,
88            schema,
89            serializer,
90            fanout,
91            state: ConnectorState::Created,
92            metrics: Arc::new(WebSocketSinkMetrics::new(registry)),
93            current_epoch: 0,
94            shutdown_tx: None,
95            acceptor_handle: None,
96            sequence: Arc::new(AtomicU64::new(0)),
97        }
98    }
99
100    /// Returns the current connector state.
101    #[must_use]
102    pub fn state(&self) -> ConnectorState {
103        self.state
104    }
105
106    /// Returns the number of connected clients.
107    #[must_use]
108    pub fn connected_clients(&self) -> usize {
109        self.fanout.client_count()
110    }
111
112    /// Returns a reference to the fan-out manager.
113    #[must_use]
114    pub fn fanout(&self) -> &Arc<FanoutManager> {
115        &self.fanout
116    }
117}
118
119#[async_trait]
120#[allow(clippy::too_many_lines)]
121impl SinkConnector for WebSocketSinkServer {
122    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
123        self.state = ConnectorState::Initializing;
124
125        // If config has properties, re-parse (supports runtime config via SQL WITH).
126        if !config.properties().is_empty() {
127            self.config = WebSocketSinkConfig::from_config(config)?;
128        }
129
130        let (bind_address, max_connections, _path, ping_interval, ping_timeout) = match &self
131            .config
132            .mode
133        {
134            SinkMode::Server {
135                bind_address,
136                max_connections,
137                path,
138                ping_interval,
139                ping_timeout,
140                ..
141            } => (
142                bind_address.clone(),
143                *max_connections,
144                path.clone(),
145                *ping_interval,
146                *ping_timeout,
147            ),
148            SinkMode::Client { .. } => {
149                return Err(ConnectorError::ConfigurationError(
150                        "WebSocketSinkServer is for server mode; use WebSocketSinkClient for client mode".into(),
151                    ));
152            }
153        };
154
155        info!(
156            bind = %bind_address,
157            max_connections,
158            format = ?self.config.format,
159            "opening WebSocket sink server"
160        );
161
162        let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
163            ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
164        })?;
165
166        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
167        let fanout = Arc::clone(&self.fanout);
168        let metrics = Arc::clone(&self.metrics);
169
170        let handle = tokio::spawn(async move {
171            let mut shutdown_rx = shutdown_rx;
172
173            loop {
174                tokio::select! {
175                    accept_result = listener.accept() => {
176                        match accept_result {
177                            Ok((stream, addr)) => {
178                                if fanout.client_count() >= max_connections {
179                                    warn!(addr = %addr, "rejecting: max_connections exceeded");
180                                    drop(stream);
181                                    continue;
182                                }
183
184                                let _ = stream.set_nodelay(true);
185                                let fanout = Arc::clone(&fanout);
186                                let metrics = metrics.clone();
187                                let mut client_shutdown = shutdown_rx.clone();
188                                let client_ping_interval = ping_interval;
189                                let client_ping_timeout = ping_timeout;
190
191                                tokio::spawn(async move {
192                                    // Cap incoming frames at 1 MiB — clients only send
193                                    // small control messages (subscribe/unsubscribe/ping).
194                                    let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
195                                    ws_config.max_message_size = Some(1024 * 1024);
196                                    ws_config.max_frame_size = Some(1024 * 1024);
197                                    let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
198                                        Ok(ws) => ws,
199                                        Err(e) => {
200                                            warn!(addr = %addr, error = %e, "handshake failed");
201                                            return;
202                                        }
203                                    };
204
205                                    let (mut write, mut read) = ws_stream.split();
206
207                                    // Wait for a subscribe message or auto-subscribe.
208                                    let sub_id = format!("sub_{}", addr.port());
209                                    let filter = None;
210
211                                    // Check for initial subscribe message (with timeout).
212                                    let (filter, last_seq) = match tokio::time::timeout(
213                                        std::time::Duration::from_secs(5),
214                                        read.next(),
215                                    )
216                                    .await
217                                    {
218                                        Ok(Some(Ok(tungstenite::Message::Text(text)))) => {
219                                            match serde_json::from_str::<ClientMessage>(text.as_ref()) {
220                                                Ok(ClientMessage::Subscribe {
221                                                    filter,
222                                                    last_sequence,
223                                                    ..
224                                                }) => (filter, last_sequence),
225                                                _ => (None, None),
226                                            }
227                                        }
228                                        Ok(Some(Err(e))) => {
229                                            warn!(addr = %addr, error = %e, "client read error during subscribe, rejecting");
230                                            return;
231                                        }
232                                        _ => (filter, None),
233                                    };
234
235                                    // Register the client.
236                                    let (client_id, rx) =
237                                        fanout.add_client(sub_id.clone(), filter, None);
238
239                                    metrics.record_connect();
240
241                                    // Send subscription confirmation.
242                                    let confirm = ServerMessage::Subscribed {
243                                        subscription_id: sub_id.clone(),
244                                    };
245                                    if let Ok(json) = serde_json::to_string(&confirm) {
246                                        let _ = write
247                                            .send(tungstenite::Message::Text(json.into()))
248                                            .await;
249                                    }
250
251                                    // Replay if requested.
252                                    if let Some(seq) = last_seq {
253                                        let replay_msgs = fanout.replay_from(seq);
254                                        metrics.record_replay();
255                                        for (_seq, data) in replay_msgs {
256                                            if write
257                                                .send(tungstenite::Message::Text(
258                                                    String::from_utf8_lossy(&data).into_owned().into(),
259                                                ))
260                                                .await
261                                                .is_err()
262                                            {
263                                                break;
264                                            }
265                                        }
266                                    }
267
268                                    // Fan-out loop with ping/pong heartbeats.
269                                    let mut ping_ticker = tokio::time::interval(client_ping_interval);
270                                    ping_ticker.tick().await; // consume initial immediate tick
271                                    let mut awaiting_pong = false;
272                                    let mut last_ping_sent = tokio::time::Instant::now();
273
274                                    loop {
275                                        tokio::select! {
276                                            Some(data) = rx.recv() => {
277                                                let text = String::from_utf8_lossy(&data).into_owned();
278                                                if write.send(tungstenite::Message::Text(text.into())).await.is_err() {
279                                                    break;
280                                                }
281                                                metrics.record_send(data.len() as u64);
282                                            }
283                                            msg = read.next() => {
284                                                match msg {
285                                                    Some(Ok(tungstenite::Message::Close(_))) | None => break,
286                                                    Some(Ok(tungstenite::Message::Pong(_))) => {
287                                                        awaiting_pong = false;
288                                                    }
289                                                    Some(Ok(tungstenite::Message::Text(text))) => {
290                                                        if let Ok(ClientMessage::Unsubscribe { .. }) =
291                                                            serde_json::from_str::<ClientMessage>(text.as_ref())
292                                                        {
293                                                            break;
294                                                        }
295                                                    }
296                                                    Some(Err(e)) => {
297                                                        warn!(addr = %addr, error = %e, "client read error, disconnecting");
298                                                        break;
299                                                    }
300                                                    _ => {}
301                                                }
302                                            }
303                                            _ = ping_ticker.tick() => {
304                                                if awaiting_pong && last_ping_sent.elapsed() > client_ping_timeout {
305                                                    debug!(addr = %addr, "ping timeout — disconnecting");
306                                                    metrics.record_ping_timeout();
307                                                    break;
308                                                }
309                                                if write.send(tungstenite::Message::Ping(bytes::Bytes::new())).await.is_err() {
310                                                    break;
311                                                }
312                                                awaiting_pong = true;
313                                                last_ping_sent = tokio::time::Instant::now();
314                                            }
315                                            _ = client_shutdown.changed() => break,
316                                        }
317                                    }
318
319                                    fanout.remove_client(client_id);
320                                    metrics.record_disconnect();
321                                    debug!(addr = %addr, "sink client disconnected");
322                                });
323                            }
324                            Err(e) => {
325                                warn!(error = %e, "accept error");
326                            }
327                        }
328                    }
329                    _ = shutdown_rx.changed() => {
330                        info!("sink server acceptor shutting down");
331                        break;
332                    }
333                }
334            }
335        });
336
337        self.shutdown_tx = Some(shutdown_tx);
338        self.acceptor_handle = Some(handle);
339        self.state = ConnectorState::Running;
340
341        info!(bind = %bind_address, "WebSocket sink server started");
342        Ok(())
343    }
344
345    #[allow(clippy::cast_possible_truncation)]
346    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
347        if self.state != ConnectorState::Running {
348            return Err(ConnectorError::InvalidState {
349                expected: "Running".into(),
350                actual: self.state.to_string(),
351            });
352        }
353
354        if self.fanout.client_count() == 0 {
355            // No clients connected — discard.
356            return Ok(WriteResult::new(0, 0));
357        }
358
359        // Serialize the batch to JSON.
360        let json = self.serializer.serialize_to_json(batch)?;
361        let seq = self.sequence.fetch_add(1, Ordering::Relaxed) + 1;
362
363        let msg = ServerMessage::Data {
364            subscription_id: String::new(), // broadcast to all
365            data: json,
366            sequence: seq,
367            watermark: None,
368        };
369
370        let serialized = serde_json::to_vec(&msg)
371            .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;
372
373        let bytes_len = serialized.len() as u64;
374        let data = Bytes::from(serialized);
375
376        let result = self.fanout.broadcast(data);
377
378        self.metrics.record_send(bytes_len);
379        if result.dropped > 0 {
380            for _ in 0..result.dropped {
381                self.metrics.record_drop();
382            }
383        }
384
385        debug!(
386            records = batch.num_rows(),
387            sent = result.sent,
388            dropped = result.dropped,
389            sequence = result.sequence,
390            "broadcast batch to WebSocket clients"
391        );
392
393        Ok(WriteResult::new(batch.num_rows(), bytes_len))
394    }
395
396    fn schema(&self) -> SchemaRef {
397        self.schema.clone()
398    }
399
400    fn capabilities(&self) -> SinkConnectorCapabilities {
401        // In-memory fanout — timeout is effectively unreachable.
402        SinkConnectorCapabilities::new(Duration::from_secs(10))
403    }
404
405    async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
406        self.current_epoch = epoch;
407        Ok(())
408    }
409
410    async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
411        Ok(())
412    }
413
414    async fn close(&mut self) -> Result<(), ConnectorError> {
415        info!("closing WebSocket sink server");
416
417        if let Some(tx) = self.shutdown_tx.take() {
418            let _ = tx.send(true);
419        }
420
421        if let Some(handle) = self.acceptor_handle.take() {
422            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
423        }
424
425        self.state = ConnectorState::Closed;
426        info!("WebSocket sink server closed");
427        Ok(())
428    }
429}
430
431impl std::fmt::Debug for WebSocketSinkServer {
432    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433        f.debug_struct("WebSocketSinkServer")
434            .field("state", &self.state)
435            .field("connected_clients", &self.connected_clients())
436            .field("format", &self.config.format)
437            .field("current_epoch", &self.current_epoch)
438            .finish_non_exhaustive()
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::super::sink_config::SinkFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Two non-nullable columns: `id: Int64` and `value: Utf8`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Server-mode JSON config bound to an ephemeral loopback port.
    fn test_config() -> WebSocketSinkConfig {
        let mode = SinkMode::Server {
            bind_address: "127.0.0.1:0".into(),
            path: None,
            max_connections: 100,
            per_client_buffer: 262_144,
            slow_client_policy: SlowClientPolicy::DropOldest,
            ping_interval: Duration::from_secs(30),
            ping_timeout: Duration::from_secs(10),
            enable_subscription_filter: false,
            replay_buffer_size: None,
        };

        WebSocketSinkConfig {
            mode,
            format: SinkFormat::Json,
            auth: None,
        }
    }

    /// A freshly built sink is `Created` and has no clients.
    #[test]
    fn test_new() {
        let server = WebSocketSinkServer::new(test_schema(), test_config(), None);

        assert_eq!(server.state(), ConnectorState::Created);
        assert_eq!(server.connected_clients(), 0);
    }

    /// `schema()` hands back the schema given to the constructor.
    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let server = WebSocketSinkServer::new(expected.clone(), test_config(), None);

        assert_eq!(server.schema(), expected);
    }

    /// The sink advertises neither exactly-once nor upsert support.
    #[test]
    fn test_capabilities() {
        let server = WebSocketSinkServer::new(test_schema(), test_config(), None);
        let reported = server.capabilities();

        assert!(!reported.exactly_once);
        assert!(!reported.upsert);
    }
}
494}