Skip to main content

laminar_connectors/websocket/
source_server.rs

1//! WebSocket source connector — server mode.
2//!
3//! [`WebSocketSourceServer`] listens on a TCP port and accepts incoming
4//! WebSocket connections from clients pushing data (e.g., `IoT` sensors,
5//! browser events).
6
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use crossfire::{mpsc, AsyncRx, TryRecvError};
13use futures_util::StreamExt;
14use tokio::net::TcpListener;
15use tokio::sync::Notify;
16use tracing::{debug, info, warn};
17
18use crate::checkpoint::SourceCheckpoint;
19use crate::config::{ConnectorConfig, ConnectorState};
20use crate::connector::{SourceBatch, SourceConnector};
21use crate::error::ConnectorError;
22use crate::health::HealthStatus;
23use crate::metrics::ConnectorMetrics;
24
25use crate::schema::json::decoder::JsonDecoderConfig;
26
27use super::checkpoint::WebSocketSourceCheckpoint;
28use super::metrics::WebSocketSourceMetrics;
29use super::parser::MessageParser;
30use super::source_config::{SourceMode, WebSocketSourceConfig};
31
32/// WebSocket source connector in server mode.
33///
34/// Binds a TCP listener and accepts incoming WebSocket connections.
35/// All clients' messages feed into the same bounded channel for
36/// `poll_batch()` consumption.
37pub struct WebSocketSourceServer {
38    /// Parsed configuration.
39    config: WebSocketSourceConfig,
40    /// Output Arrow schema.
41    schema: SchemaRef,
42    /// Message parser.
43    parser: MessageParser,
44    /// Connector lifecycle state.
45    state: ConnectorState,
46    /// Metrics.
47    metrics: WebSocketSourceMetrics,
48    /// Checkpoint state.
49    checkpoint_state: WebSocketSourceCheckpoint,
50    /// Bounded channel receiver for messages from client handler tasks.
51    rx: Option<AsyncRx<mpsc::Array<Vec<u8>>>>,
52    /// Shutdown signal sender.
53    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
54    /// Handle to the spawned acceptor task.
55    acceptor_handle: Option<tokio::task::JoinHandle<()>>,
56    /// Message buffer between polls.
57    message_buffer: Vec<Vec<u8>>,
58    /// Connected client count (shared with acceptor task).
59    connected_clients: Arc<AtomicU64>,
60    /// Maximum records per batch.
61    max_batch_size: usize,
62    /// Notification handle signalled when data arrives from client handler tasks.
63    data_ready: Arc<Notify>,
64}
65
66impl WebSocketSourceServer {
67    /// Creates a new WebSocket source connector in server mode.
68    #[must_use]
69    pub fn new(
70        schema: SchemaRef,
71        config: WebSocketSourceConfig,
72        registry: Option<&prometheus::Registry>,
73    ) -> Self {
74        let parser = MessageParser::new(
75            schema.clone(),
76            config.format.clone(),
77            JsonDecoderConfig::default(),
78        );
79
80        Self {
81            config,
82            schema,
83            parser,
84            state: ConnectorState::Created,
85            metrics: WebSocketSourceMetrics::new(registry),
86            checkpoint_state: WebSocketSourceCheckpoint::default(),
87            rx: None,
88            shutdown_tx: None,
89            acceptor_handle: None,
90            message_buffer: Vec::new(),
91            connected_clients: Arc::new(AtomicU64::new(0)),
92            max_batch_size: 1000,
93            data_ready: Arc::new(Notify::new()),
94        }
95    }
96
97    /// Returns the current connector state.
98    #[must_use]
99    pub fn state(&self) -> ConnectorState {
100        self.state
101    }
102
103    /// Returns the number of currently connected clients.
104    #[must_use]
105    pub fn connected_clients(&self) -> u64 {
106        self.connected_clients.load(Ordering::Relaxed)
107    }
108}
109
110#[async_trait]
111#[allow(clippy::too_many_lines)]
112impl SourceConnector for WebSocketSourceServer {
113    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
114        self.state = ConnectorState::Initializing;
115
116        // If config has properties, re-parse (supports runtime config via SQL WITH).
117        if !config.properties().is_empty() {
118            self.config = WebSocketSourceConfig::from_config(config)?;
119        }
120
121        let (bind_address, max_connections, _path) = match &self.config.mode {
122            SourceMode::Server {
123                bind_address,
124                max_connections,
125                path,
126            } => (bind_address.clone(), *max_connections, path.clone()),
127            SourceMode::Client { .. } => {
128                return Err(ConnectorError::ConfigurationError(
129                    "WebSocketSourceServer is for server mode; use WebSocketSource for client mode"
130                        .into(),
131                ));
132            }
133        };
134
135        info!(
136            bind = %bind_address,
137            max_connections,
138            format = ?self.config.format,
139            "opening WebSocket source connector (server mode)"
140        );
141
142        let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
143            ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
144        })?;
145
146        let channel_capacity = 10_000;
147        let (tx, rx) = mpsc::bounded_async::<Vec<u8>>(channel_capacity);
148        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
149
150        let connected = Arc::clone(&self.connected_clients);
151        let max_msg_size = self.config.max_message_size;
152        let data_ready = Arc::clone(&self.data_ready);
153
154        let handle = tokio::spawn(async move {
155            let mut shutdown_rx = shutdown_rx;
156
157            loop {
158                tokio::select! {
159                    accept_result = listener.accept() => {
160                        match accept_result {
161                            Ok((stream, addr)) => {
162                                let current = connected.load(Ordering::Relaxed);
163                                if current >= max_connections as u64 {
164                                    warn!(
165                                        current_connections = current,
166                                        max = max_connections,
167                                        addr = %addr,
168                                        "rejecting connection: max_connections exceeded"
169                                    );
170                                    drop(stream);
171                                    continue;
172                                }
173
174                                // Set TCP_NODELAY for low latency.
175                                let _ = stream.set_nodelay(true);
176
177                                let tx = tx.clone();
178                                let connected = Arc::clone(&connected);
179                                let mut client_shutdown = shutdown_rx.clone();
180                                let data_ready = Arc::clone(&data_ready);
181
182                                connected.fetch_add(1, Ordering::Relaxed);
183                                debug!(addr = %addr, "accepted WebSocket client");
184
185                                tokio::spawn(async move {
186                                    let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
187                                    ws_config.max_message_size = Some(max_msg_size);
188                                    ws_config.max_frame_size = Some(max_msg_size);
189                                    let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
190                                        Ok(ws) => ws,
191                                        Err(e) => {
192                                            warn!(addr = %addr, error = %e, "WebSocket handshake failed");
193                                            connected.fetch_sub(1, Ordering::Relaxed);
194                                            return;
195                                        }
196                                    };
197
198                                    let (_write, mut read) = ws_stream.split();
199
200                                    loop {
201                                        tokio::select! {
202                                            msg = read.next() => {
203                                                match msg {
204                                                    Some(Ok(tungstenite::Message::Text(text))) => {
205                                                        let payload = text.as_bytes().to_vec();
206                                                        if payload.len() > max_msg_size {
207                                                            warn!(
208                                                                size = payload.len(),
209                                                                max = max_msg_size,
210                                                                addr = %addr,
211                                                                "dropping oversized text message"
212                                                            );
213                                                        } else if tx.send(payload).await.is_err() {
214                                                            break;
215                                                        } else {
216                                                            data_ready.notify_one();
217                                                        }
218                                                    }
219                                                    Some(Ok(tungstenite::Message::Binary(data))) => {
220                                                        let payload = data.to_vec();
221                                                        if payload.len() > max_msg_size {
222                                                            warn!(
223                                                                size = payload.len(),
224                                                                max = max_msg_size,
225                                                                addr = %addr,
226                                                                "dropping oversized binary message"
227                                                            );
228                                                        } else if tx.send(payload).await.is_err() {
229                                                            break;
230                                                        } else {
231                                                            data_ready.notify_one();
232                                                        }
233                                                    }
234                                                    Some(Ok(tungstenite::Message::Close(_))) | None => break,
235                                                    Some(Ok(_)) => {} // Ping/Pong handled by tungstenite
236                                                    Some(Err(e)) => {
237                                                        debug!(addr = %addr, error = %e, "client read error");
238                                                        break;
239                                                    }
240                                                }
241                                            }
242                                            _ = client_shutdown.changed() => break,
243                                        }
244                                    }
245
246                                    connected.fetch_sub(1, Ordering::Relaxed);
247                                    debug!(addr = %addr, "client disconnected");
248                                });
249                            }
250                            Err(e) => {
251                                warn!(error = %e, "accept error");
252                            }
253                        }
254                    }
255                    _ = shutdown_rx.changed() => {
256                        info!("acceptor shutting down");
257                        break;
258                    }
259                }
260            }
261        });
262
263        self.rx = Some(rx);
264        self.shutdown_tx = Some(shutdown_tx);
265        self.acceptor_handle = Some(handle);
266        self.state = ConnectorState::Running;
267
268        info!(bind = %bind_address, "WebSocket source server started");
269        Ok(())
270    }
271
272    #[allow(clippy::cast_possible_truncation)]
273    async fn poll_batch(
274        &mut self,
275        max_records: usize,
276    ) -> Result<Option<SourceBatch>, ConnectorError> {
277        if self.state != ConnectorState::Running {
278            return Err(ConnectorError::InvalidState {
279                expected: "Running".into(),
280                actual: self.state.to_string(),
281            });
282        }
283
284        let rx = self
285            .rx
286            .as_mut()
287            .ok_or_else(|| ConnectorError::InvalidState {
288                expected: "channel initialized".into(),
289                actual: "channel is None".into(),
290            })?;
291
292        let limit = max_records.min(self.max_batch_size);
293
294        // Non-blocking drain: pull all available messages from the channel.
295        // The pipeline coordinator handles wake-up timing via data_ready_notify().
296        while self.message_buffer.len() < limit {
297            match rx.try_recv() {
298                Ok(payload) => {
299                    self.metrics.record_message(payload.len() as u64);
300                    self.message_buffer.push(payload);
301                }
302                Err(TryRecvError::Empty) => break,
303                Err(TryRecvError::Disconnected) => {
304                    if self.message_buffer.is_empty() {
305                        self.state = ConnectorState::Failed;
306                        return Err(ConnectorError::ReadError(
307                            "WebSocket source server acceptor terminated".into(),
308                        ));
309                    }
310                    // Buffer has data already dequeued — break so we can
311                    // parse and return the final batch. Next poll_batch()
312                    // call will see Disconnected again with an empty buffer
313                    // and transition to Failed.
314                    break;
315                }
316            }
317        }
318
319        self.metrics
320            .set_connected_clients(self.connected_clients.load(Ordering::Relaxed));
321
322        if self.message_buffer.is_empty() {
323            return Ok(None);
324        }
325
326        let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
327        let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
328            self.metrics.record_parse_error();
329        })?;
330
331        let num_rows = batch.num_rows();
332        self.message_buffer.clear();
333
334        // Use event time from batch data if configured, otherwise wall-clock.
335        if let Some(ref field) = self.config.event_time_field {
336            if let Some(max_ts) = super::parser::extract_max_event_time(&batch, field)? {
337                self.checkpoint_state.watermark = max_ts;
338            }
339        } else {
340            self.checkpoint_state.watermark = std::time::SystemTime::now()
341                .duration_since(std::time::UNIX_EPOCH)
342                .unwrap_or_default()
343                .as_millis() as i64;
344        }
345
346        debug!(
347            records = num_rows,
348            clients = self.connected_clients(),
349            "polled batch from WebSocket server source"
350        );
351        Ok(Some(SourceBatch::new(batch)))
352    }
353
354    fn schema(&self) -> SchemaRef {
355        self.schema.clone()
356    }
357
358    fn checkpoint(&self) -> SourceCheckpoint {
359        self.checkpoint_state.to_source_checkpoint(0)
360    }
361
362    async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
363        info!(
364            epoch = checkpoint.epoch(),
365            "restoring WebSocket source server from checkpoint (best-effort)"
366        );
367        self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
368        warn!("WebSocket source server restored; data gap expected (non-replayable)");
369        Ok(())
370    }
371
372    fn health_check(&self) -> HealthStatus {
373        match self.state {
374            ConnectorState::Running => HealthStatus::Healthy,
375            ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
376            ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
377            ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
378            ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
379            ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
380        }
381    }
382
383    fn metrics(&self) -> ConnectorMetrics {
384        self.metrics.to_connector_metrics()
385    }
386
387    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
388        Some(Arc::clone(&self.data_ready))
389    }
390
391    fn supports_replay(&self) -> bool {
392        false
393    }
394
395    async fn close(&mut self) -> Result<(), ConnectorError> {
396        info!("closing WebSocket source server");
397
398        if let Some(tx) = self.shutdown_tx.take() {
399            let _ = tx.send(true);
400        }
401
402        if let Some(handle) = self.acceptor_handle.take() {
403            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
404        }
405
406        self.rx = None;
407        self.state = ConnectorState::Closed;
408        info!("WebSocket source server closed");
409        Ok(())
410    }
411}
412
413impl std::fmt::Debug for WebSocketSourceServer {
414    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415        f.debug_struct("WebSocketSourceServer")
416            .field("state", &self.state)
417            .field("connected_clients", &self.connected_clients())
418            .finish_non_exhaustive()
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Two nullable UTF-8 columns, `id` and `value`.
    fn test_schema() -> SchemaRef {
        let fields: Vec<Field> = ["id", "value"]
            .into_iter()
            .map(|name| Field::new(name, DataType::Utf8, true))
            .collect();
        Arc::new(Schema::new(fields))
    }

    /// Server-mode JSON config on an ephemeral localhost port.
    fn test_config() -> WebSocketSourceConfig {
        let mode = SourceMode::Server {
            bind_address: "127.0.0.1:0".into(),
            max_connections: 100,
            path: None,
        };
        WebSocketSourceConfig {
            mode,
            format: MessageFormat::Json,
            on_backpressure: super::super::backpressure::WsBackpressure::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    /// Server built from the given schema, the default test config and no registry.
    fn make_server(schema: SchemaRef) -> WebSocketSourceServer {
        WebSocketSourceServer::new(schema, test_config(), None)
    }

    #[test]
    fn test_new() {
        let server = make_server(test_schema());
        assert_eq!(server.state(), ConnectorState::Created);
        assert_eq!(server.connected_clients(), 0);
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let server = make_server(expected.clone());
        assert_eq!(server.schema(), expected);
    }

    #[test]
    fn test_health_created() {
        let server = make_server(test_schema());
        assert_eq!(server.health_check(), HealthStatus::Unknown);
    }
}