Skip to main content

laminar_connectors/websocket/
source_server.rs

1//! WebSocket source connector — server mode.
2//!
3//! [`WebSocketSourceServer`] listens on a TCP port and accepts incoming
4//! WebSocket connections from clients pushing data (e.g., `IoT` sensors,
5//! browser events).
6
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use crossfire::{mpsc, AsyncRx, TryRecvError};
13use futures_util::StreamExt;
14use tokio::net::TcpListener;
15use tokio::sync::Notify;
16use tracing::{debug, info, warn};
17
18use crate::checkpoint::SourceCheckpoint;
19use crate::config::{ConnectorConfig, ConnectorState};
20use crate::connector::{SourceBatch, SourceConnector};
21use crate::error::ConnectorError;
22
23use crate::schema::json::decoder::JsonDecoderConfig;
24
25use super::checkpoint::WebSocketSourceCheckpoint;
26use super::metrics::WebSocketSourceMetrics;
27use super::parser::MessageParser;
28use super::source_config::{SourceMode, WebSocketSourceConfig};
29
30/// WebSocket source connector in server mode.
31///
32/// Binds a TCP listener and accepts incoming WebSocket connections.
33/// All clients' messages feed into the same bounded channel for
34/// `poll_batch()` consumption.
35pub struct WebSocketSourceServer {
36    /// Parsed configuration.
37    config: WebSocketSourceConfig,
38    /// Output Arrow schema.
39    schema: SchemaRef,
40    /// Message parser.
41    parser: MessageParser,
42    /// Connector lifecycle state.
43    state: ConnectorState,
44    /// Metrics.
45    metrics: WebSocketSourceMetrics,
46    /// Checkpoint state.
47    checkpoint_state: WebSocketSourceCheckpoint,
48    /// Bounded channel receiver for messages from client handler tasks.
49    rx: Option<AsyncRx<mpsc::Array<Vec<u8>>>>,
50    /// Shutdown signal sender.
51    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
52    /// Handle to the spawned acceptor task.
53    acceptor_handle: Option<tokio::task::JoinHandle<()>>,
54    /// Message buffer between polls.
55    message_buffer: Vec<Vec<u8>>,
56    /// Connected client count (shared with acceptor task).
57    connected_clients: Arc<AtomicU64>,
58    /// Maximum records per batch.
59    max_batch_size: usize,
60    /// Notification handle signalled when data arrives from client handler tasks.
61    data_ready: Arc<Notify>,
62}
63
64impl WebSocketSourceServer {
65    /// Creates a new WebSocket source connector in server mode.
66    #[must_use]
67    pub fn new(
68        schema: SchemaRef,
69        config: WebSocketSourceConfig,
70        registry: Option<&prometheus::Registry>,
71    ) -> Self {
72        let parser = MessageParser::new(
73            schema.clone(),
74            config.format.clone(),
75            JsonDecoderConfig::default(),
76        );
77
78        Self {
79            config,
80            schema,
81            parser,
82            state: ConnectorState::Created,
83            metrics: WebSocketSourceMetrics::new(registry),
84            checkpoint_state: WebSocketSourceCheckpoint::default(),
85            rx: None,
86            shutdown_tx: None,
87            acceptor_handle: None,
88            message_buffer: Vec::new(),
89            connected_clients: Arc::new(AtomicU64::new(0)),
90            max_batch_size: 1000,
91            data_ready: Arc::new(Notify::new()),
92        }
93    }
94
95    /// Returns the current connector state.
96    #[must_use]
97    pub fn state(&self) -> ConnectorState {
98        self.state
99    }
100
101    /// Returns the number of currently connected clients.
102    #[must_use]
103    pub fn connected_clients(&self) -> u64 {
104        self.connected_clients.load(Ordering::Relaxed)
105    }
106}
107
108#[async_trait]
109#[allow(clippy::too_many_lines)]
110impl SourceConnector for WebSocketSourceServer {
111    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
112        self.state = ConnectorState::Initializing;
113
114        // If config has properties, re-parse (supports runtime config via SQL WITH).
115        if !config.properties().is_empty() {
116            self.config = WebSocketSourceConfig::from_config(config)?;
117        }
118
119        let (bind_address, max_connections, _path) = match &self.config.mode {
120            SourceMode::Server {
121                bind_address,
122                max_connections,
123                path,
124            } => (bind_address.clone(), *max_connections, path.clone()),
125            SourceMode::Client { .. } => {
126                return Err(ConnectorError::ConfigurationError(
127                    "WebSocketSourceServer is for server mode; use WebSocketSource for client mode"
128                        .into(),
129                ));
130            }
131        };
132
133        info!(
134            bind = %bind_address,
135            max_connections,
136            format = ?self.config.format,
137            "opening WebSocket source connector (server mode)"
138        );
139
140        let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
141            ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
142        })?;
143
144        let channel_capacity = 10_000;
145        let (tx, rx) = mpsc::bounded_async::<Vec<u8>>(channel_capacity);
146        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
147
148        let connected = Arc::clone(&self.connected_clients);
149        let max_msg_size = self.config.max_message_size;
150        let data_ready = Arc::clone(&self.data_ready);
151
152        let handle = tokio::spawn(async move {
153            let mut shutdown_rx = shutdown_rx;
154
155            loop {
156                tokio::select! {
157                    accept_result = listener.accept() => {
158                        match accept_result {
159                            Ok((stream, addr)) => {
160                                let current = connected.load(Ordering::Relaxed);
161                                if current >= max_connections as u64 {
162                                    warn!(
163                                        current_connections = current,
164                                        max = max_connections,
165                                        addr = %addr,
166                                        "rejecting connection: max_connections exceeded"
167                                    );
168                                    drop(stream);
169                                    continue;
170                                }
171
172                                // Set TCP_NODELAY for low latency.
173                                let _ = stream.set_nodelay(true);
174
175                                let tx = tx.clone();
176                                let connected = Arc::clone(&connected);
177                                let mut client_shutdown = shutdown_rx.clone();
178                                let data_ready = Arc::clone(&data_ready);
179
180                                connected.fetch_add(1, Ordering::Relaxed);
181                                debug!(addr = %addr, "accepted WebSocket client");
182
183                                tokio::spawn(async move {
184                                    let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
185                                    ws_config.max_message_size = Some(max_msg_size);
186                                    ws_config.max_frame_size = Some(max_msg_size);
187                                    let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
188                                        Ok(ws) => ws,
189                                        Err(e) => {
190                                            warn!(addr = %addr, error = %e, "WebSocket handshake failed");
191                                            connected.fetch_sub(1, Ordering::Relaxed);
192                                            return;
193                                        }
194                                    };
195
196                                    let (_write, mut read) = ws_stream.split();
197
198                                    loop {
199                                        tokio::select! {
200                                            msg = read.next() => {
201                                                match msg {
202                                                    Some(Ok(tungstenite::Message::Text(text))) => {
203                                                        let payload = text.as_bytes().to_vec();
204                                                        if payload.len() > max_msg_size {
205                                                            warn!(
206                                                                size = payload.len(),
207                                                                max = max_msg_size,
208                                                                addr = %addr,
209                                                                "dropping oversized text message"
210                                                            );
211                                                        } else if tx.send(payload).await.is_err() {
212                                                            break;
213                                                        } else {
214                                                            data_ready.notify_one();
215                                                        }
216                                                    }
217                                                    Some(Ok(tungstenite::Message::Binary(data))) => {
218                                                        let payload = data.to_vec();
219                                                        if payload.len() > max_msg_size {
220                                                            warn!(
221                                                                size = payload.len(),
222                                                                max = max_msg_size,
223                                                                addr = %addr,
224                                                                "dropping oversized binary message"
225                                                            );
226                                                        } else if tx.send(payload).await.is_err() {
227                                                            break;
228                                                        } else {
229                                                            data_ready.notify_one();
230                                                        }
231                                                    }
232                                                    Some(Ok(tungstenite::Message::Close(_))) | None => break,
233                                                    Some(Ok(_)) => {} // Ping/Pong handled by tungstenite
234                                                    Some(Err(e)) => {
235                                                        debug!(addr = %addr, error = %e, "client read error");
236                                                        break;
237                                                    }
238                                                }
239                                            }
240                                            _ = client_shutdown.changed() => break,
241                                        }
242                                    }
243
244                                    connected.fetch_sub(1, Ordering::Relaxed);
245                                    debug!(addr = %addr, "client disconnected");
246                                });
247                            }
248                            Err(e) => {
249                                warn!(error = %e, "accept error");
250                            }
251                        }
252                    }
253                    _ = shutdown_rx.changed() => {
254                        info!("acceptor shutting down");
255                        break;
256                    }
257                }
258            }
259        });
260
261        self.rx = Some(rx);
262        self.shutdown_tx = Some(shutdown_tx);
263        self.acceptor_handle = Some(handle);
264        self.state = ConnectorState::Running;
265
266        info!(bind = %bind_address, "WebSocket source server started");
267        Ok(())
268    }
269
270    #[allow(clippy::cast_possible_truncation)]
271    async fn poll_batch(
272        &mut self,
273        max_records: usize,
274    ) -> Result<Option<SourceBatch>, ConnectorError> {
275        if self.state != ConnectorState::Running {
276            return Err(ConnectorError::InvalidState {
277                expected: "Running".into(),
278                actual: self.state.to_string(),
279            });
280        }
281
282        let rx = self
283            .rx
284            .as_mut()
285            .ok_or_else(|| ConnectorError::InvalidState {
286                expected: "channel initialized".into(),
287                actual: "channel is None".into(),
288            })?;
289
290        let limit = max_records.min(self.max_batch_size);
291
292        // Non-blocking drain: pull all available messages from the channel.
293        // The pipeline coordinator handles wake-up timing via data_ready_notify().
294        while self.message_buffer.len() < limit {
295            match rx.try_recv() {
296                Ok(payload) => {
297                    self.metrics.record_message(payload.len() as u64);
298                    self.message_buffer.push(payload);
299                }
300                Err(TryRecvError::Empty) => break,
301                Err(TryRecvError::Disconnected) => {
302                    if self.message_buffer.is_empty() {
303                        self.state = ConnectorState::Failed;
304                        return Err(ConnectorError::ReadError(
305                            "WebSocket source server acceptor terminated".into(),
306                        ));
307                    }
308                    // Buffer has data already dequeued — break so we can
309                    // parse and return the final batch. Next poll_batch()
310                    // call will see Disconnected again with an empty buffer
311                    // and transition to Failed.
312                    break;
313                }
314            }
315        }
316
317        self.metrics
318            .set_connected_clients(self.connected_clients.load(Ordering::Relaxed));
319
320        if self.message_buffer.is_empty() {
321            return Ok(None);
322        }
323
324        let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
325        let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
326            self.metrics.record_parse_error();
327        })?;
328
329        let num_rows = batch.num_rows();
330        self.message_buffer.clear();
331
332        // Use event time from batch data if configured, otherwise wall-clock.
333        if let Some(ref field) = self.config.event_time_field {
334            if let Some(max_ts) = super::parser::extract_max_event_time(&batch, field)? {
335                self.checkpoint_state.watermark = max_ts;
336            }
337        } else {
338            self.checkpoint_state.watermark = std::time::SystemTime::now()
339                .duration_since(std::time::UNIX_EPOCH)
340                .unwrap_or_default()
341                .as_millis() as i64;
342        }
343
344        debug!(
345            records = num_rows,
346            clients = self.connected_clients(),
347            "polled batch from WebSocket server source"
348        );
349        Ok(Some(SourceBatch::new(batch)))
350    }
351
352    fn schema(&self) -> SchemaRef {
353        self.schema.clone()
354    }
355
356    fn checkpoint(&self) -> SourceCheckpoint {
357        self.checkpoint_state.to_source_checkpoint(0)
358    }
359
360    async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
361        info!(
362            epoch = checkpoint.epoch(),
363            "restoring WebSocket source server from checkpoint (best-effort)"
364        );
365        self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
366        warn!("WebSocket source server restored; data gap expected (non-replayable)");
367        Ok(())
368    }
369
370    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
371        Some(Arc::clone(&self.data_ready))
372    }
373
374    fn supports_replay(&self) -> bool {
375        false
376    }
377
378    async fn close(&mut self) -> Result<(), ConnectorError> {
379        info!("closing WebSocket source server");
380
381        if let Some(tx) = self.shutdown_tx.take() {
382            let _ = tx.send(true);
383        }
384
385        if let Some(handle) = self.acceptor_handle.take() {
386            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
387        }
388
389        self.rx = None;
390        self.state = ConnectorState::Closed;
391        info!("WebSocket source server closed");
392        Ok(())
393    }
394}
395
396impl std::fmt::Debug for WebSocketSourceServer {
397    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398        f.debug_struct("WebSocketSourceServer")
399            .field("state", &self.state)
400            .field("connected_clients", &self.connected_clients())
401            .finish_non_exhaustive()
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    // `now_or_never` drives futures that complete without suspending, so the
    // lifecycle paths below need no async runtime.
    use futures_util::FutureExt;

    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ]))
    }

    fn test_config() -> WebSocketSourceConfig {
        WebSocketSourceConfig {
            mode: SourceMode::Server {
                bind_address: "127.0.0.1:0".into(),
                max_connections: 100,
                path: None,
            },
            format: MessageFormat::Json,
            on_backpressure: super::super::backpressure::WsBackpressure::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    fn test_server() -> WebSocketSourceServer {
        WebSocketSourceServer::new(test_schema(), test_config(), None)
    }

    #[test]
    fn test_new() {
        let server = test_server();
        assert_eq!(server.state(), ConnectorState::Created);
        assert_eq!(server.connected_clients(), 0);
    }

    #[test]
    fn test_schema_returned() {
        let schema = test_schema();
        let server = WebSocketSourceServer::new(schema.clone(), test_config(), None);
        assert_eq!(server.schema(), schema);
    }

    #[test]
    fn test_does_not_support_replay() {
        assert!(!test_server().supports_replay());
    }

    #[test]
    fn test_data_ready_notify_is_shared() {
        let server = test_server();
        let first = server
            .data_ready_notify()
            .expect("server mode always exposes a data-ready handle");
        let second = server
            .data_ready_notify()
            .expect("server mode always exposes a data-ready handle");
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_poll_before_open_is_invalid_state() {
        let mut server = test_server();
        let result = server
            .poll_batch(10)
            .now_or_never()
            .expect("poll_batch must not suspend before open");
        assert!(matches!(result, Err(ConnectorError::InvalidState { .. })));
        assert_eq!(server.state(), ConnectorState::Created);
    }

    #[test]
    fn test_close_without_open() {
        let mut server = test_server();
        let result = server
            .close()
            .now_or_never()
            .expect("close must not suspend when nothing was opened");
        assert!(result.is_ok());
        assert_eq!(server.state(), ConnectorState::Closed);
    }

    #[test]
    fn test_debug_reports_client_count() {
        let rendered = format!("{:?}", test_server());
        assert!(rendered.starts_with("WebSocketSourceServer"));
        assert!(rendered.contains("connected_clients: 0"));
    }
}