Skip to main content

laminar_connectors/websocket/
source_server.rs

1//! WebSocket source connector — server mode.
2//!
3//! [`WebSocketSourceServer`] listens on a TCP port and accepts incoming
4//! WebSocket connections from clients pushing data (e.g., `IoT` sensors,
5//! browser events).
6
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use futures_util::StreamExt;
13use tokio::net::TcpListener;
14use tokio::sync::{mpsc, Notify};
15use tracing::{debug, info, warn};
16
17use crate::checkpoint::SourceCheckpoint;
18use crate::config::{ConnectorConfig, ConnectorState};
19use crate::connector::{SourceBatch, SourceConnector};
20use crate::error::ConnectorError;
21use crate::health::HealthStatus;
22use crate::metrics::ConnectorMetrics;
23
24use crate::schema::json::decoder::JsonDecoderConfig;
25
26use super::checkpoint::WebSocketSourceCheckpoint;
27use super::metrics::WebSocketSourceMetrics;
28use super::parser::MessageParser;
29use super::source_config::{SourceMode, WebSocketSourceConfig};
30
31/// WebSocket source connector in server mode.
32///
33/// Binds a TCP listener and accepts incoming WebSocket connections.
34/// All clients' messages feed into the same bounded channel for
35/// `poll_batch()` consumption.
36pub struct WebSocketSourceServer {
37    /// Parsed configuration.
38    config: WebSocketSourceConfig,
39    /// Output Arrow schema.
40    schema: SchemaRef,
41    /// Message parser.
42    parser: MessageParser,
43    /// Connector lifecycle state.
44    state: ConnectorState,
45    /// Metrics.
46    metrics: WebSocketSourceMetrics,
47    /// Checkpoint state.
48    checkpoint_state: WebSocketSourceCheckpoint,
49    /// Bounded channel receiver for messages from client handler tasks.
50    rx: Option<mpsc::Receiver<Vec<u8>>>,
51    /// Shutdown signal sender.
52    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
53    /// Handle to the spawned acceptor task.
54    acceptor_handle: Option<tokio::task::JoinHandle<()>>,
55    /// Message buffer between polls.
56    message_buffer: Vec<Vec<u8>>,
57    /// Connected client count (shared with acceptor task).
58    connected_clients: Arc<AtomicU64>,
59    /// Maximum records per batch.
60    max_batch_size: usize,
61    /// Notification handle signalled when data arrives from client handler tasks.
62    data_ready: Arc<Notify>,
63}
64
65impl WebSocketSourceServer {
66    /// Creates a new WebSocket source connector in server mode.
67    #[must_use]
68    pub fn new(schema: SchemaRef, config: WebSocketSourceConfig) -> Self {
69        let parser = MessageParser::new(
70            schema.clone(),
71            config.format.clone(),
72            JsonDecoderConfig::default(),
73        );
74
75        Self {
76            config,
77            schema,
78            parser,
79            state: ConnectorState::Created,
80            metrics: WebSocketSourceMetrics::new(),
81            checkpoint_state: WebSocketSourceCheckpoint::default(),
82            rx: None,
83            shutdown_tx: None,
84            acceptor_handle: None,
85            message_buffer: Vec::new(),
86            connected_clients: Arc::new(AtomicU64::new(0)),
87            max_batch_size: 1000,
88            data_ready: Arc::new(Notify::new()),
89        }
90    }
91
92    /// Returns the current connector state.
93    #[must_use]
94    pub fn state(&self) -> ConnectorState {
95        self.state
96    }
97
98    /// Returns the number of currently connected clients.
99    #[must_use]
100    pub fn connected_clients(&self) -> u64 {
101        self.connected_clients.load(Ordering::Relaxed)
102    }
103}
104
105#[async_trait]
106#[allow(clippy::too_many_lines)]
107impl SourceConnector for WebSocketSourceServer {
108    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
109        self.state = ConnectorState::Initializing;
110
111        // If config has properties, re-parse (supports runtime config via SQL WITH).
112        if !config.properties().is_empty() {
113            self.config = WebSocketSourceConfig::from_config(config)?;
114        }
115
116        let (bind_address, max_connections, _path) = match &self.config.mode {
117            SourceMode::Server {
118                bind_address,
119                max_connections,
120                path,
121            } => (bind_address.clone(), *max_connections, path.clone()),
122            SourceMode::Client { .. } => {
123                return Err(ConnectorError::ConfigurationError(
124                    "WebSocketSourceServer is for server mode; use WebSocketSource for client mode"
125                        .into(),
126                ));
127            }
128        };
129
130        info!(
131            bind = %bind_address,
132            max_connections,
133            format = ?self.config.format,
134            "opening WebSocket source connector (server mode)"
135        );
136
137        let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
138            ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
139        })?;
140
141        let channel_capacity = 10_000;
142        let (tx, rx) = mpsc::channel(channel_capacity);
143        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
144
145        let connected = Arc::clone(&self.connected_clients);
146        let max_msg_size = self.config.max_message_size;
147        let data_ready = Arc::clone(&self.data_ready);
148
149        let handle = tokio::spawn(async move {
150            let mut shutdown_rx = shutdown_rx;
151
152            loop {
153                tokio::select! {
154                    accept_result = listener.accept() => {
155                        match accept_result {
156                            Ok((stream, addr)) => {
157                                let current = connected.load(Ordering::Relaxed);
158                                if current >= max_connections as u64 {
159                                    warn!(
160                                        current_connections = current,
161                                        max = max_connections,
162                                        addr = %addr,
163                                        "rejecting connection: max_connections exceeded"
164                                    );
165                                    drop(stream);
166                                    continue;
167                                }
168
169                                // Set TCP_NODELAY for low latency.
170                                let _ = stream.set_nodelay(true);
171
172                                let tx = tx.clone();
173                                let connected = Arc::clone(&connected);
174                                let mut client_shutdown = shutdown_rx.clone();
175                                let data_ready = Arc::clone(&data_ready);
176
177                                connected.fetch_add(1, Ordering::Relaxed);
178                                debug!(addr = %addr, "accepted WebSocket client");
179
180                                tokio::spawn(async move {
181                                    let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
182                                    ws_config.max_message_size = Some(max_msg_size);
183                                    ws_config.max_frame_size = Some(max_msg_size);
184                                    let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
185                                        Ok(ws) => ws,
186                                        Err(e) => {
187                                            warn!(addr = %addr, error = %e, "WebSocket handshake failed");
188                                            connected.fetch_sub(1, Ordering::Relaxed);
189                                            return;
190                                        }
191                                    };
192
193                                    let (_write, mut read) = ws_stream.split();
194
195                                    loop {
196                                        tokio::select! {
197                                            msg = read.next() => {
198                                                match msg {
199                                                    Some(Ok(tungstenite::Message::Text(text))) => {
200                                                        let payload = text.as_bytes().to_vec();
201                                                        if payload.len() > max_msg_size {
202                                                            warn!(
203                                                                size = payload.len(),
204                                                                max = max_msg_size,
205                                                                addr = %addr,
206                                                                "dropping oversized text message"
207                                                            );
208                                                        } else if tx.send(payload).await.is_err() {
209                                                            break;
210                                                        } else {
211                                                            data_ready.notify_one();
212                                                        }
213                                                    }
214                                                    Some(Ok(tungstenite::Message::Binary(data))) => {
215                                                        let payload = data.to_vec();
216                                                        if payload.len() > max_msg_size {
217                                                            warn!(
218                                                                size = payload.len(),
219                                                                max = max_msg_size,
220                                                                addr = %addr,
221                                                                "dropping oversized binary message"
222                                                            );
223                                                        } else if tx.send(payload).await.is_err() {
224                                                            break;
225                                                        } else {
226                                                            data_ready.notify_one();
227                                                        }
228                                                    }
229                                                    Some(Ok(tungstenite::Message::Close(_))) | None => break,
230                                                    Some(Ok(_)) => {} // Ping/Pong handled by tungstenite
231                                                    Some(Err(e)) => {
232                                                        debug!(addr = %addr, error = %e, "client read error");
233                                                        break;
234                                                    }
235                                                }
236                                            }
237                                            _ = client_shutdown.changed() => break,
238                                        }
239                                    }
240
241                                    connected.fetch_sub(1, Ordering::Relaxed);
242                                    debug!(addr = %addr, "client disconnected");
243                                });
244                            }
245                            Err(e) => {
246                                warn!(error = %e, "accept error");
247                            }
248                        }
249                    }
250                    _ = shutdown_rx.changed() => {
251                        info!("acceptor shutting down");
252                        break;
253                    }
254                }
255            }
256        });
257
258        self.rx = Some(rx);
259        self.shutdown_tx = Some(shutdown_tx);
260        self.acceptor_handle = Some(handle);
261        self.state = ConnectorState::Running;
262
263        info!(bind = %bind_address, "WebSocket source server started");
264        Ok(())
265    }
266
267    #[allow(clippy::cast_possible_truncation)]
268    async fn poll_batch(
269        &mut self,
270        max_records: usize,
271    ) -> Result<Option<SourceBatch>, ConnectorError> {
272        if self.state != ConnectorState::Running {
273            return Err(ConnectorError::InvalidState {
274                expected: "Running".into(),
275                actual: self.state.to_string(),
276            });
277        }
278
279        let rx = self
280            .rx
281            .as_mut()
282            .ok_or_else(|| ConnectorError::InvalidState {
283                expected: "channel initialized".into(),
284                actual: "channel is None".into(),
285            })?;
286
287        let limit = max_records.min(self.max_batch_size);
288
289        // Non-blocking drain: pull all available messages from the channel.
290        // The pipeline coordinator handles wake-up timing via data_ready_notify().
291        while self.message_buffer.len() < limit {
292            match rx.try_recv() {
293                Ok(payload) => {
294                    self.metrics.record_message(payload.len() as u64);
295                    self.message_buffer.push(payload);
296                }
297                Err(mpsc::error::TryRecvError::Empty) => break,
298                Err(mpsc::error::TryRecvError::Disconnected) => {
299                    if self.message_buffer.is_empty() {
300                        self.state = ConnectorState::Failed;
301                        return Err(ConnectorError::ReadError(
302                            "WebSocket source server acceptor terminated".into(),
303                        ));
304                    }
305                    // Buffer has data already dequeued — break so we can
306                    // parse and return the final batch. Next poll_batch()
307                    // call will see Disconnected again with an empty buffer
308                    // and transition to Failed.
309                    break;
310                }
311            }
312        }
313
314        self.metrics
315            .set_connected_clients(self.connected_clients.load(Ordering::Relaxed));
316
317        if self.message_buffer.is_empty() {
318            return Ok(None);
319        }
320
321        let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
322        let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
323            self.metrics.record_parse_error();
324        })?;
325
326        let num_rows = batch.num_rows();
327        self.message_buffer.clear();
328
329        // Use event time from batch data if configured, otherwise wall-clock.
330        if let Some(ref field) = self.config.event_time_field {
331            if let Some(max_ts) = super::parser::extract_max_event_time(&batch, field) {
332                self.checkpoint_state.watermark = max_ts;
333            }
334        } else {
335            self.checkpoint_state.watermark = std::time::SystemTime::now()
336                .duration_since(std::time::UNIX_EPOCH)
337                .unwrap_or_default()
338                .as_millis() as i64;
339        }
340
341        debug!(
342            records = num_rows,
343            clients = self.connected_clients(),
344            "polled batch from WebSocket server source"
345        );
346        Ok(Some(SourceBatch::new(batch)))
347    }
348
349    fn schema(&self) -> SchemaRef {
350        self.schema.clone()
351    }
352
353    fn checkpoint(&self) -> SourceCheckpoint {
354        self.checkpoint_state.to_source_checkpoint(0)
355    }
356
357    async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
358        info!(
359            epoch = checkpoint.epoch(),
360            "restoring WebSocket source server from checkpoint (best-effort)"
361        );
362        self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
363        warn!("WebSocket source server restored; data gap expected (non-replayable)");
364        Ok(())
365    }
366
367    fn health_check(&self) -> HealthStatus {
368        match self.state {
369            ConnectorState::Running => HealthStatus::Healthy,
370            ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
371            ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
372            ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
373            ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
374            ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
375        }
376    }
377
378    fn metrics(&self) -> ConnectorMetrics {
379        self.metrics.to_connector_metrics()
380    }
381
382    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
383        Some(Arc::clone(&self.data_ready))
384    }
385
386    fn supports_replay(&self) -> bool {
387        false
388    }
389
390    async fn close(&mut self) -> Result<(), ConnectorError> {
391        info!("closing WebSocket source server");
392
393        if let Some(tx) = self.shutdown_tx.take() {
394            let _ = tx.send(true);
395        }
396
397        if let Some(handle) = self.acceptor_handle.take() {
398            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
399        }
400
401        self.rx = None;
402        self.state = ConnectorState::Closed;
403        info!("WebSocket source server closed");
404        Ok(())
405    }
406}
407
408impl std::fmt::Debug for WebSocketSourceServer {
409    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410        f.debug_struct("WebSocketSourceServer")
411            .field("state", &self.state)
412            .field("connected_clients", &self.connected_clients())
413            .finish_non_exhaustive()
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::super::backpressure::BackpressureStrategy;
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Schema with two nullable UTF-8 columns: `id` and `value`.
    fn test_schema() -> SchemaRef {
        let fields: Vec<Field> = ["id", "value"]
            .into_iter()
            .map(|name| Field::new(name, DataType::Utf8, true))
            .collect();
        Arc::new(Schema::new(fields))
    }

    /// Server-mode JSON config on an ephemeral loopback port with no auth.
    fn test_config() -> WebSocketSourceConfig {
        let mode = SourceMode::Server {
            bind_address: "127.0.0.1:0".into(),
            max_connections: 100,
            path: None,
        };
        WebSocketSourceConfig {
            mode,
            format: MessageFormat::Json,
            on_backpressure: BackpressureStrategy::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 << 20,
            auth: None,
        }
    }

    /// A freshly constructed, never-opened server.
    fn fresh_server() -> WebSocketSourceServer {
        WebSocketSourceServer::new(test_schema(), test_config())
    }

    #[test]
    fn test_new() {
        let server = fresh_server();
        assert_eq!(server.state(), ConnectorState::Created);
        assert_eq!(server.connected_clients(), 0);
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let server = WebSocketSourceServer::new(Arc::clone(&expected), test_config());
        assert_eq!(server.schema(), expected);
    }

    #[test]
    fn test_health_created() {
        assert_eq!(fresh_server().health_check(), HealthStatus::Unknown);
    }
}