Skip to main content

laminar_connectors/websocket/
source.rs

1//! WebSocket source connector — client mode.
2//!
3//! [`WebSocketSource`] connects to an external WebSocket server (e.g., exchange
4//! market data feeds) and produces Arrow `RecordBatch` data via the
5//! [`SourceConnector`] trait.
6//!
7//! # Delivery Guarantees
8//!
9//! WebSocket is non-replayable. This connector provides **at-most-once** or
10//! **best-effort** delivery. On recovery, data gaps should be expected.
11
12use std::sync::Arc;
13
14use arrow_schema::SchemaRef;
15use async_trait::async_trait;
16use crossfire::{mpsc, AsyncRx, MAsyncTx, TryRecvError, TrySendError};
17use futures_util::{SinkExt, StreamExt};
18use tokio::sync::Notify;
19use tracing::{debug, info, warn};
20
21use crate::checkpoint::SourceCheckpoint;
22use crate::config::{ConnectorConfig, ConnectorState};
23use crate::connector::{SourceBatch, SourceConnector};
24use crate::error::ConnectorError;
25
26use crate::schema::json::decoder::JsonDecoderConfig;
27
28use super::backpressure::WsBackpressure;
29use super::checkpoint::WebSocketSourceCheckpoint;
30use super::connection::ConnectionManager;
31use super::metrics::WebSocketSourceMetrics;
32use super::parser::MessageParser;
33use super::source_config::{ReconnectConfig, SourceMode, WebSocketSourceConfig};
34
/// Item carried over the bounded channel from the spawned WS reader task
/// to the connector's `poll_batch()`.
enum WsMessage {
    /// Raw bytes of one WebSocket text or binary message.
    Data(Vec<u8>),
    /// The reader lost (or gave up on) its connection; the string is a
    /// human-readable reason used for logging.
    Disconnected(String),
}
42
43/// WebSocket source connector in client mode.
44///
45/// Connects to one or more external WebSocket server URLs and consumes
46/// messages, converting them to Arrow `RecordBatch` data.
47///
48/// All WebSocket I/O runs in a spawned Tokio task (Ring 2). Parsed data
49/// is delivered to `poll_batch()` via a bounded channel.
50pub struct WebSocketSource {
51    /// Parsed configuration.
52    config: WebSocketSourceConfig,
53    /// Output Arrow schema.
54    schema: SchemaRef,
55    /// Message parser (JSON/CSV/Binary → Arrow).
56    parser: MessageParser,
57    /// Connector lifecycle state.
58    state: ConnectorState,
59    /// Metrics.
60    metrics: WebSocketSourceMetrics,
61    /// Checkpoint state.
62    checkpoint_state: WebSocketSourceCheckpoint,
63    /// Bounded channel receiver for messages from the WS reader task.
64    rx: Option<AsyncRx<mpsc::Array<WsMessage>>>,
65    /// Shutdown signal sender.
66    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
67    /// Handle to the spawned reader task.
68    reader_handle: Option<tokio::task::JoinHandle<()>>,
69    /// Buffer of raw messages accumulated between `poll_batch()` calls.
70    message_buffer: Vec<Vec<u8>>,
71    /// Maximum records per batch.
72    max_batch_size: usize,
73    /// Notification handle signalled when data arrives from the reader task.
74    data_ready: Arc<Notify>,
75}
76
77impl WebSocketSource {
78    /// Creates a new WebSocket source connector in client mode.
79    #[must_use]
80    pub fn new(
81        schema: SchemaRef,
82        config: WebSocketSourceConfig,
83        registry: Option<&prometheus::Registry>,
84    ) -> Self {
85        let parser = MessageParser::new(
86            schema.clone(),
87            config.format.clone(),
88            JsonDecoderConfig::default(),
89        );
90
91        Self {
92            config,
93            schema,
94            parser,
95            state: ConnectorState::Created,
96            metrics: WebSocketSourceMetrics::new(registry),
97            checkpoint_state: WebSocketSourceCheckpoint::default(),
98            rx: None,
99            shutdown_tx: None,
100            reader_handle: None,
101            message_buffer: Vec::new(),
102            max_batch_size: 1000,
103            data_ready: Arc::new(Notify::new()),
104        }
105    }
106
107    /// Returns the current connector state.
108    #[must_use]
109    pub fn state(&self) -> ConnectorState {
110        self.state
111    }
112
113    /// Spawns the WebSocket reader task that connects to the server and
114    /// feeds messages through the bounded channel.
115    #[allow(
116        clippy::too_many_arguments,
117        clippy::too_many_lines,
118        clippy::unused_self
119    )]
120    fn spawn_reader(
121        &self,
122        urls: Vec<String>,
123        subscribe_message: Option<String>,
124        reconnect: ReconnectConfig,
125        max_message_size: usize,
126        on_backpressure: WsBackpressure,
127        tx: MAsyncTx<mpsc::Array<WsMessage>>,
128        mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
129        data_ready: Arc<Notify>,
130    ) -> tokio::task::JoinHandle<()> {
131        tokio::spawn(async move {
132            let mut conn_mgr = ConnectionManager::new(urls, reconnect);
133
134            'outer: loop {
135                // Check shutdown.
136                if *shutdown_rx.borrow() {
137                    break;
138                }
139
140                let url = conn_mgr.current_url().to_string();
141                info!(url = %url, "connecting to WebSocket server");
142
143                // Attempt connection with frame-level size cap.
144                let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
145                ws_config.max_message_size = Some(max_message_size);
146                ws_config.max_frame_size = Some(max_message_size);
147                let ws_stream = match tokio_tungstenite::connect_async_with_config(
148                    &url,
149                    Some(ws_config),
150                    true, // disable Nagle for low latency
151                )
152                .await
153                {
154                    Ok((stream, _response)) => {
155                        conn_mgr.reset();
156                        info!(url = %url, "WebSocket connection established");
157                        stream
158                    }
159                    Err(e) => {
160                        warn!(url = %url, error = %e, "WebSocket connection failed");
161                        if let Some(delay) = conn_mgr.next_backoff() {
162                            tokio::select! {
163                                () = tokio::time::sleep(delay) => continue,
164                                _ = shutdown_rx.changed() => break,
165                            }
166                        } else {
167                            let _ = tx
168                                .send(WsMessage::Disconnected(format!(
169                                    "connection failed, no more retries: {e}"
170                                )))
171                                .await;
172                            break;
173                        }
174                    }
175                };
176
177                let (mut write, mut read) = ws_stream.split();
178
179                // Send subscription message if configured.
180                if let Some(ref msg) = subscribe_message {
181                    if let Err(e) = write
182                        .send(tungstenite::Message::Text(msg.clone().into()))
183                        .await
184                    {
185                        warn!(error = %e, "failed to send subscription message");
186                        if let Some(delay) = conn_mgr.next_backoff() {
187                            tokio::select! {
188                                () = tokio::time::sleep(delay) => continue,
189                                _ = shutdown_rx.changed() => break,
190                            }
191                        }
192                        continue;
193                    }
194                    debug!("subscription message sent");
195                }
196
197                // Read loop.
198                loop {
199                    tokio::select! {
200                        msg = read.next() => {
201                            match msg {
202                                Some(Ok(tungstenite::Message::Text(text))) => {
203                                    let payload = text.as_bytes().to_vec();
204                                    if payload.len() > max_message_size {
205                                        warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
206                                        continue;
207                                    }
208                                    if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
209                                        break 'outer;
210                                    }
211                                }
212                                Some(Ok(tungstenite::Message::Binary(data))) => {
213                                    let payload = data.to_vec();
214                                    if payload.len() > max_message_size {
215                                        warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
216                                        continue;
217                                    }
218                                    if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
219                                        break 'outer;
220                                    }
221                                }
222                                Some(Ok(tungstenite::Message::Ping(data))) => {
223                                    let _ = write.send(tungstenite::Message::Pong(data)).await;
224                                }
225                                Some(Ok(tungstenite::Message::Close(_))) => {
226                                    info!(url = %url, "server sent Close frame");
227                                    break;
228                                }
229                                Some(Ok(_)) => {} // Pong, Frame — ignore
230                                Some(Err(e)) => {
231                                    warn!(url = %url, error = %e, "WebSocket read error");
232                                    break;
233                                }
234                                None => {
235                                    info!(url = %url, "WebSocket stream ended");
236                                    break;
237                                }
238                            }
239                        }
240                        _ = shutdown_rx.changed() => {
241                            debug!("shutdown signal received in reader");
242                            let _ = write.send(tungstenite::Message::Close(None)).await;
243                            break 'outer;
244                        }
245                    }
246                }
247
248                // Disconnected — attempt reconnect.
249                let _ = tx
250                    .send(WsMessage::Disconnected(format!("disconnected from {url}")))
251                    .await;
252
253                if let Some(delay) = conn_mgr.next_backoff() {
254                    tokio::select! {
255                        () = tokio::time::sleep(delay) => {},
256                        _ = shutdown_rx.changed() => break,
257                    }
258                } else {
259                    break;
260                }
261            }
262        })
263    }
264}
265
266/// Sends a message through the channel, applying the backpressure strategy
267/// if the channel is full. Signals `data_ready` on successful send so the
268/// pipeline coordinator wakes immediately.
269///
270/// Returns `Err(())` if the channel is closed (shutdown).
271async fn send_with_backpressure(
272    tx: &MAsyncTx<mpsc::Array<WsMessage>>,
273    msg: WsMessage,
274    strategy: &WsBackpressure,
275    data_ready: &Notify,
276) -> Result<(), ()> {
277    let result = match strategy {
278        WsBackpressure::Block => tx.send(msg).await.map_err(|_| ()),
279        WsBackpressure::DropNewest => match tx.try_send(msg) {
280            Ok(()) | Err(TrySendError::Full(_)) => Ok(()),
281            Err(TrySendError::Disconnected(_)) => Err(()),
282        },
283        // TODO(F006): implement DropOldest, Buffer, Sample properly.
284        WsBackpressure::DropOldest
285        | WsBackpressure::Buffer { .. }
286        | WsBackpressure::Sample { .. } => match tx.try_send(msg) {
287            Ok(()) | Err(TrySendError::Full(_)) => Ok(()),
288            Err(TrySendError::Disconnected(_)) => Err(()),
289        },
290    };
291    if result.is_ok() {
292        data_ready.notify_one();
293    }
294    result
295}
296
297#[async_trait]
298impl SourceConnector for WebSocketSource {
299    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
300        self.state = ConnectorState::Initializing;
301
302        // If config has properties, re-parse (supports runtime config via SQL WITH).
303        if !config.properties().is_empty() {
304            self.config = WebSocketSourceConfig::from_config(config)?;
305        }
306
307        // Override schema from SQL DDL if provided.
308        if let Some(schema) = config.arrow_schema() {
309            info!(
310                fields = schema.fields().len(),
311                "using SQL-defined schema for deserialization"
312            );
313            self.schema = schema;
314            let decoder_config = JsonDecoderConfig::from_connector_config(config);
315            self.parser = MessageParser::new(
316                self.schema.clone(),
317                self.config.format.clone(),
318                decoder_config,
319            );
320        }
321
322        let mode = &self.config.mode;
323        let (urls, subscribe_message, reconnect, ping_interval, ping_timeout) = match mode {
324            SourceMode::Client {
325                urls,
326                subscribe_message,
327                reconnect,
328                ping_interval,
329                ping_timeout,
330            } => (
331                urls.clone(),
332                subscribe_message.clone(),
333                reconnect.clone(),
334                *ping_interval,
335                *ping_timeout,
336            ),
337            SourceMode::Server { .. } => {
338                return Err(ConnectorError::ConfigurationError(
339                    "WebSocketSource is for client mode; use WebSocketSourceServer for server mode"
340                        .into(),
341                ));
342            }
343        };
344
345        if urls.is_empty() {
346            return Err(ConnectorError::ConfigurationError(
347                "at least one WebSocket URL is required".into(),
348            ));
349        }
350
351        // ping_interval and ping_timeout are accepted by config but not yet
352        // wired to WebSocket ping/pong frames in client-mode source.
353        // Tungstenite handles pong replies automatically for incoming pings.
354        let _ = (ping_interval, ping_timeout);
355
356        info!(
357            urls = ?urls,
358            format = ?self.config.format,
359            backpressure = ?self.config.on_backpressure,
360            "opening WebSocket source connector (client mode)"
361        );
362
363        if matches!(
364            self.config.on_backpressure,
365            WsBackpressure::DropOldest
366                | WsBackpressure::Buffer { .. }
367                | WsBackpressure::Sample { .. }
368        ) {
369            warn!(
370                strategy = ?self.config.on_backpressure,
371                "backpressure strategy not implemented, falling back to DropNewest"
372            );
373        }
374
375        // Create bounded channel between reader task and poll_batch().
376        let channel_capacity = 10_000;
377        let (tx, rx) = mpsc::bounded_async::<WsMessage>(channel_capacity);
378
379        // Create shutdown signal.
380        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
381
382        // Spawn the reader task.
383        let handle = self.spawn_reader(
384            urls.clone(),
385            subscribe_message,
386            reconnect.clone(),
387            self.config.max_message_size,
388            self.config.on_backpressure.clone(),
389            tx,
390            shutdown_rx,
391            Arc::clone(&self.data_ready),
392        );
393
394        self.rx = Some(rx);
395        self.shutdown_tx = Some(shutdown_tx);
396        self.reader_handle = Some(handle);
397        self.state = ConnectorState::Running;
398
399        info!("WebSocket source connector opened successfully");
400        Ok(())
401    }
402
403    #[allow(clippy::cast_possible_truncation)]
404    async fn poll_batch(
405        &mut self,
406        max_records: usize,
407    ) -> Result<Option<SourceBatch>, ConnectorError> {
408        if self.state != ConnectorState::Running {
409            return Err(ConnectorError::InvalidState {
410                expected: "Running".into(),
411                actual: self.state.to_string(),
412            });
413        }
414
415        let rx = self
416            .rx
417            .as_mut()
418            .ok_or_else(|| ConnectorError::InvalidState {
419                expected: "channel initialized".into(),
420                actual: "channel is None".into(),
421            })?;
422
423        let limit = max_records.min(self.max_batch_size);
424
425        // Non-blocking drain: pull all available messages from the channel.
426        // The pipeline coordinator handles wake-up timing via data_ready_notify().
427        while self.message_buffer.len() < limit {
428            match rx.try_recv() {
429                Ok(WsMessage::Data(payload)) => {
430                    self.metrics.record_message(payload.len() as u64);
431                    self.message_buffer.push(payload);
432                }
433                Ok(WsMessage::Disconnected(reason)) => {
434                    self.metrics.record_reconnect();
435                    warn!(reason = %reason, "WebSocket disconnected");
436                    break;
437                }
438                Err(TryRecvError::Empty) => break,
439                Err(TryRecvError::Disconnected) => {
440                    // Channel closed — reader task ended.
441                    if self.message_buffer.is_empty() {
442                        self.state = ConnectorState::Failed;
443                        return Err(ConnectorError::ReadError(
444                            "WebSocket reader task terminated".into(),
445                        ));
446                    }
447                    // Drain the final batch before failing on next call.
448                    break;
449                }
450            }
451        }
452
453        if self.message_buffer.is_empty() {
454            return Ok(None);
455        }
456
457        // Parse the accumulated messages into a RecordBatch.
458        let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
459        let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
460            self.metrics.record_parse_error();
461        })?;
462
463        let num_rows = batch.num_rows();
464        self.message_buffer.clear();
465
466        // Update checkpoint state with wall-clock watermark.
467        // Ideally this would extract event time from the batch using the
468        // configured event_time_field, but that requires schema knowledge
469        // at this layer. The pipeline's WatermarkExtractor handles proper
470        // event-time watermarking at the SQL layer.
471        self.checkpoint_state.watermark = std::time::SystemTime::now()
472            .duration_since(std::time::UNIX_EPOCH)
473            .unwrap_or_default()
474            .as_millis() as i64;
475
476        debug!(records = num_rows, "polled batch from WebSocket");
477        Ok(Some(SourceBatch::new(batch)))
478    }
479
480    fn schema(&self) -> SchemaRef {
481        self.schema.clone()
482    }
483
484    fn checkpoint(&self) -> SourceCheckpoint {
485        // Epoch is managed by the checkpoint coordinator, not individual
486        // sources. Pass 0 here; the coordinator stamps the real epoch on
487        // the manifest. SourceCheckpoint.epoch is informational only.
488        self.checkpoint_state.to_source_checkpoint(0)
489    }
490
491    async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
492        info!(
493            epoch = checkpoint.epoch(),
494            "restoring WebSocket source from checkpoint (best-effort)"
495        );
496        self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
497
498        // WebSocket is non-replayable — log the gap.
499        warn!(
500            last_sequence = ?self.checkpoint_state.last_sequence,
501            last_event_time = ?self.checkpoint_state.last_event_time,
502            "WebSocket source restored; data gap expected (non-replayable transport)"
503        );
504
505        Ok(())
506    }
507
508    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
509        Some(Arc::clone(&self.data_ready))
510    }
511
512    fn supports_replay(&self) -> bool {
513        false
514    }
515
516    async fn close(&mut self) -> Result<(), ConnectorError> {
517        info!("closing WebSocket source connector");
518
519        // Signal shutdown.
520        if let Some(tx) = self.shutdown_tx.take() {
521            let _ = tx.send(true);
522        }
523
524        // Wait for the reader task to finish.
525        if let Some(handle) = self.reader_handle.take() {
526            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
527        }
528
529        self.rx = None;
530        self.state = ConnectorState::Closed;
531        info!("WebSocket source connector closed");
532        Ok(())
533    }
534}
535
536impl std::fmt::Debug for WebSocketSource {
537    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
538        f.debug_struct("WebSocketSource")
539            .field("state", &self.state)
540            .field("mode", &"client")
541            .field("format", &self.config.format)
542            .finish_non_exhaustive()
543    }
544}
545
#[cfg(test)]
mod tests {
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use std::time::Duration;

    /// Two nullable UTF-8 columns: `id` and `value`.
    fn test_schema() -> SchemaRef {
        let fields: Vec<Field> = ["id", "value"]
            .into_iter()
            .map(|name| Field::new(name, DataType::Utf8, true))
            .collect();
        Arc::new(Schema::new(fields))
    }

    /// Client-mode JSON configuration pointing at a local placeholder URL.
    fn test_config() -> WebSocketSourceConfig {
        let mode = SourceMode::Client {
            urls: vec!["ws://localhost:9090".into()],
            subscribe_message: None,
            reconnect: ReconnectConfig::default(),
            ping_interval: Duration::from_secs(30),
            ping_timeout: Duration::from_secs(10),
        };

        WebSocketSourceConfig {
            mode,
            format: MessageFormat::Json,
            on_backpressure: WsBackpressure::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    /// A freshly constructed source starts out in `Created`.
    #[test]
    fn test_new_defaults() {
        let source = WebSocketSource::new(test_schema(), test_config(), None);
        assert_eq!(source.state(), ConnectorState::Created);
    }

    /// `schema()` hands back the schema given to the constructor.
    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let source = WebSocketSource::new(expected.clone(), test_config(), None);
        assert_eq!(source.schema(), expected);
    }

    /// Even an unopened source yields a non-empty checkpoint.
    #[test]
    fn test_checkpoint_empty() {
        let source = WebSocketSource::new(test_schema(), test_config(), None);
        let cp = source.checkpoint();
        assert!(!cp.is_empty()); // has websocket_state key
    }
}