Skip to main content

laminar_connectors/websocket/
source.rs

1//! WebSocket source connector — client mode.
2//!
3//! [`WebSocketSource`] connects to an external WebSocket server (e.g., exchange
4//! market data feeds) and produces Arrow `RecordBatch` data via the
5//! [`SourceConnector`] trait.
6//!
7//! # Delivery Guarantees
8//!
9//! WebSocket is non-replayable. This connector provides **at-most-once** or
10//! **best-effort** delivery. On recovery, data gaps should be expected.
11
12use std::sync::Arc;
13
14use arrow_schema::SchemaRef;
15use async_trait::async_trait;
16use futures_util::{SinkExt, StreamExt};
17use tokio::sync::{mpsc, Notify};
18use tracing::{debug, info, warn};
19
20use crate::checkpoint::SourceCheckpoint;
21use crate::config::{ConnectorConfig, ConnectorState};
22use crate::connector::{SourceBatch, SourceConnector};
23use crate::error::ConnectorError;
24use crate::health::HealthStatus;
25use crate::metrics::ConnectorMetrics;
26
27use crate::schema::json::decoder::JsonDecoderConfig;
28
29use super::backpressure::BackpressureStrategy;
30use super::checkpoint::WebSocketSourceCheckpoint;
31use super::connection::ConnectionManager;
32use super::metrics::WebSocketSourceMetrics;
33use super::parser::MessageParser;
34use super::source_config::{ReconnectConfig, SourceMode, WebSocketSourceConfig};
35
/// Item passed over the bounded channel from the background reader task
/// to [`WebSocketSource::poll_batch`](SourceConnector::poll_batch).
enum WsMessage {
    /// Payload bytes of one Text or Binary WebSocket frame.
    Data(Vec<u8>),
    /// Notice that the reader lost (or could not establish) its connection;
    /// the string carries a human-readable reason.
    Disconnected(String),
}
43
44/// WebSocket source connector in client mode.
45///
46/// Connects to one or more external WebSocket server URLs and consumes
47/// messages, converting them to Arrow `RecordBatch` data.
48///
49/// All WebSocket I/O runs in a spawned Tokio task (Ring 2). Parsed data
50/// is delivered to `poll_batch()` via a bounded channel.
51pub struct WebSocketSource {
52    /// Parsed configuration.
53    config: WebSocketSourceConfig,
54    /// Output Arrow schema.
55    schema: SchemaRef,
56    /// Message parser (JSON/CSV/Binary → Arrow).
57    parser: MessageParser,
58    /// Connector lifecycle state.
59    state: ConnectorState,
60    /// Metrics.
61    metrics: WebSocketSourceMetrics,
62    /// Checkpoint state.
63    checkpoint_state: WebSocketSourceCheckpoint,
64    /// Bounded channel receiver for messages from the WS reader task.
65    rx: Option<mpsc::Receiver<WsMessage>>,
66    /// Shutdown signal sender.
67    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
68    /// Handle to the spawned reader task.
69    reader_handle: Option<tokio::task::JoinHandle<()>>,
70    /// Buffer of raw messages accumulated between `poll_batch()` calls.
71    message_buffer: Vec<Vec<u8>>,
72    /// Maximum records per batch.
73    max_batch_size: usize,
74    /// Notification handle signalled when data arrives from the reader task.
75    data_ready: Arc<Notify>,
76}
77
78impl WebSocketSource {
79    /// Creates a new WebSocket source connector in client mode.
80    ///
81    /// # Arguments
82    ///
83    /// * `schema` - Arrow schema for output batches.
84    /// * `config` - WebSocket source configuration.
85    #[must_use]
86    pub fn new(schema: SchemaRef, config: WebSocketSourceConfig) -> Self {
87        let parser = MessageParser::new(
88            schema.clone(),
89            config.format.clone(),
90            JsonDecoderConfig::default(),
91        );
92
93        Self {
94            config,
95            schema,
96            parser,
97            state: ConnectorState::Created,
98            metrics: WebSocketSourceMetrics::new(),
99            checkpoint_state: WebSocketSourceCheckpoint::default(),
100            rx: None,
101            shutdown_tx: None,
102            reader_handle: None,
103            message_buffer: Vec::new(),
104            max_batch_size: 1000,
105            data_ready: Arc::new(Notify::new()),
106        }
107    }
108
109    /// Returns the current connector state.
110    #[must_use]
111    pub fn state(&self) -> ConnectorState {
112        self.state
113    }
114
115    /// Spawns the WebSocket reader task that connects to the server and
116    /// feeds messages through the bounded channel.
117    #[allow(
118        clippy::too_many_arguments,
119        clippy::too_many_lines,
120        clippy::unused_self
121    )]
122    fn spawn_reader(
123        &self,
124        urls: Vec<String>,
125        subscribe_message: Option<String>,
126        reconnect: ReconnectConfig,
127        max_message_size: usize,
128        on_backpressure: BackpressureStrategy,
129        tx: mpsc::Sender<WsMessage>,
130        mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
131        data_ready: Arc<Notify>,
132    ) -> tokio::task::JoinHandle<()> {
133        tokio::spawn(async move {
134            let mut conn_mgr = ConnectionManager::new(urls, reconnect);
135
136            'outer: loop {
137                // Check shutdown.
138                if *shutdown_rx.borrow() {
139                    break;
140                }
141
142                let url = conn_mgr.current_url().to_string();
143                info!(url = %url, "connecting to WebSocket server");
144
145                // Attempt connection with frame-level size cap.
146                let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
147                ws_config.max_message_size = Some(max_message_size);
148                ws_config.max_frame_size = Some(max_message_size);
149                let ws_stream = match tokio_tungstenite::connect_async_with_config(
150                    &url,
151                    Some(ws_config),
152                    true, // disable Nagle for low latency
153                )
154                .await
155                {
156                    Ok((stream, _response)) => {
157                        conn_mgr.reset();
158                        info!(url = %url, "WebSocket connection established");
159                        stream
160                    }
161                    Err(e) => {
162                        warn!(url = %url, error = %e, "WebSocket connection failed");
163                        if let Some(delay) = conn_mgr.next_backoff() {
164                            tokio::select! {
165                                () = tokio::time::sleep(delay) => continue,
166                                _ = shutdown_rx.changed() => break,
167                            }
168                        } else {
169                            let _ = tx
170                                .send(WsMessage::Disconnected(format!(
171                                    "connection failed, no more retries: {e}"
172                                )))
173                                .await;
174                            break;
175                        }
176                    }
177                };
178
179                let (mut write, mut read) = ws_stream.split();
180
181                // Send subscription message if configured.
182                if let Some(ref msg) = subscribe_message {
183                    if let Err(e) = write
184                        .send(tungstenite::Message::Text(msg.clone().into()))
185                        .await
186                    {
187                        warn!(error = %e, "failed to send subscription message");
188                        if let Some(delay) = conn_mgr.next_backoff() {
189                            tokio::select! {
190                                () = tokio::time::sleep(delay) => continue,
191                                _ = shutdown_rx.changed() => break,
192                            }
193                        }
194                        continue;
195                    }
196                    debug!("subscription message sent");
197                }
198
199                // Read loop.
200                loop {
201                    tokio::select! {
202                        msg = read.next() => {
203                            match msg {
204                                Some(Ok(tungstenite::Message::Text(text))) => {
205                                    let payload = text.as_bytes().to_vec();
206                                    if payload.len() > max_message_size {
207                                        warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
208                                        continue;
209                                    }
210                                    if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
211                                        break 'outer;
212                                    }
213                                }
214                                Some(Ok(tungstenite::Message::Binary(data))) => {
215                                    let payload = data.to_vec();
216                                    if payload.len() > max_message_size {
217                                        warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
218                                        continue;
219                                    }
220                                    if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
221                                        break 'outer;
222                                    }
223                                }
224                                Some(Ok(tungstenite::Message::Ping(data))) => {
225                                    let _ = write.send(tungstenite::Message::Pong(data)).await;
226                                }
227                                Some(Ok(tungstenite::Message::Close(_))) => {
228                                    info!(url = %url, "server sent Close frame");
229                                    break;
230                                }
231                                Some(Ok(_)) => {} // Pong, Frame — ignore
232                                Some(Err(e)) => {
233                                    warn!(url = %url, error = %e, "WebSocket read error");
234                                    break;
235                                }
236                                None => {
237                                    info!(url = %url, "WebSocket stream ended");
238                                    break;
239                                }
240                            }
241                        }
242                        _ = shutdown_rx.changed() => {
243                            debug!("shutdown signal received in reader");
244                            let _ = write.send(tungstenite::Message::Close(None)).await;
245                            break 'outer;
246                        }
247                    }
248                }
249
250                // Disconnected — attempt reconnect.
251                let _ = tx
252                    .send(WsMessage::Disconnected(format!("disconnected from {url}")))
253                    .await;
254
255                if let Some(delay) = conn_mgr.next_backoff() {
256                    tokio::select! {
257                        () = tokio::time::sleep(delay) => {},
258                        _ = shutdown_rx.changed() => break,
259                    }
260                } else {
261                    break;
262                }
263            }
264        })
265    }
266}
267
268/// Sends a message through the channel, applying the backpressure strategy
269/// if the channel is full. Signals `data_ready` on successful send so the
270/// pipeline coordinator wakes immediately.
271///
272/// Returns `Err(())` if the channel is closed (shutdown).
273async fn send_with_backpressure(
274    tx: &mpsc::Sender<WsMessage>,
275    msg: WsMessage,
276    strategy: &BackpressureStrategy,
277    data_ready: &Notify,
278) -> Result<(), ()> {
279    let result = match strategy {
280        BackpressureStrategy::Block => tx.send(msg).await.map_err(|_| ()),
281        BackpressureStrategy::DropNewest => match tx.try_send(msg) {
282            Ok(()) | Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(()),
283            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => Err(()),
284        },
285        BackpressureStrategy::DropOldest
286        | BackpressureStrategy::Buffer { .. }
287        | BackpressureStrategy::Sample { .. } => {
288            // These strategies are not yet differentiated from DropNewest.
289            // DropOldest would need drain-and-re-send, Buffer a secondary
290            // queue, and Sample a counter-based skip. All degrade to
291            // DropNewest (try_send, drop on full) for now.
292            tracing::debug!("backpressure strategy not fully implemented, using DropNewest");
293            match tx.try_send(msg) {
294                Ok(()) | Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(()),
295                Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => Err(()),
296            }
297        }
298    };
299    if result.is_ok() {
300        data_ready.notify_one();
301    }
302    result
303}
304
305#[async_trait]
306impl SourceConnector for WebSocketSource {
307    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
308        self.state = ConnectorState::Initializing;
309
310        // If config has properties, re-parse (supports runtime config via SQL WITH).
311        if !config.properties().is_empty() {
312            self.config = WebSocketSourceConfig::from_config(config)?;
313        }
314
315        // Override schema from SQL DDL if provided.
316        if let Some(schema) = config.arrow_schema() {
317            info!(
318                fields = schema.fields().len(),
319                "using SQL-defined schema for deserialization"
320            );
321            self.schema = schema;
322            let decoder_config = JsonDecoderConfig::from_connector_config(config);
323            self.parser = MessageParser::new(
324                self.schema.clone(),
325                self.config.format.clone(),
326                decoder_config,
327            );
328        }
329
330        let mode = &self.config.mode;
331        let (urls, subscribe_message, reconnect, ping_interval, ping_timeout) = match mode {
332            SourceMode::Client {
333                urls,
334                subscribe_message,
335                reconnect,
336                ping_interval,
337                ping_timeout,
338            } => (
339                urls.clone(),
340                subscribe_message.clone(),
341                reconnect.clone(),
342                *ping_interval,
343                *ping_timeout,
344            ),
345            SourceMode::Server { .. } => {
346                return Err(ConnectorError::ConfigurationError(
347                    "WebSocketSource is for client mode; use WebSocketSourceServer for server mode"
348                        .into(),
349                ));
350            }
351        };
352
353        if urls.is_empty() {
354            return Err(ConnectorError::ConfigurationError(
355                "at least one WebSocket URL is required".into(),
356            ));
357        }
358
359        // ping_interval and ping_timeout are accepted by config but not yet
360        // wired to WebSocket ping/pong frames in client-mode source.
361        // Tungstenite handles pong replies automatically for incoming pings.
362        let _ = (ping_interval, ping_timeout);
363
364        info!(
365            urls = ?urls,
366            format = ?self.config.format,
367            backpressure = ?self.config.on_backpressure,
368            "opening WebSocket source connector (client mode)"
369        );
370
371        // Create bounded channel between reader task and poll_batch().
372        let channel_capacity = 10_000;
373        let (tx, rx) = mpsc::channel(channel_capacity);
374
375        // Create shutdown signal.
376        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
377
378        // Spawn the reader task.
379        let handle = self.spawn_reader(
380            urls.clone(),
381            subscribe_message,
382            reconnect.clone(),
383            self.config.max_message_size,
384            self.config.on_backpressure.clone(),
385            tx,
386            shutdown_rx,
387            Arc::clone(&self.data_ready),
388        );
389
390        self.rx = Some(rx);
391        self.shutdown_tx = Some(shutdown_tx);
392        self.reader_handle = Some(handle);
393        self.state = ConnectorState::Running;
394
395        info!("WebSocket source connector opened successfully");
396        Ok(())
397    }
398
399    #[allow(clippy::cast_possible_truncation)]
400    async fn poll_batch(
401        &mut self,
402        max_records: usize,
403    ) -> Result<Option<SourceBatch>, ConnectorError> {
404        if self.state != ConnectorState::Running {
405            return Err(ConnectorError::InvalidState {
406                expected: "Running".into(),
407                actual: self.state.to_string(),
408            });
409        }
410
411        let rx = self
412            .rx
413            .as_mut()
414            .ok_or_else(|| ConnectorError::InvalidState {
415                expected: "channel initialized".into(),
416                actual: "channel is None".into(),
417            })?;
418
419        let limit = max_records.min(self.max_batch_size);
420
421        // Non-blocking drain: pull all available messages from the channel.
422        // The pipeline coordinator handles wake-up timing via data_ready_notify().
423        while self.message_buffer.len() < limit {
424            match rx.try_recv() {
425                Ok(WsMessage::Data(payload)) => {
426                    self.metrics.record_message(payload.len() as u64);
427                    self.message_buffer.push(payload);
428                }
429                Ok(WsMessage::Disconnected(reason)) => {
430                    self.metrics.record_reconnect();
431                    warn!(reason = %reason, "WebSocket disconnected");
432                    break;
433                }
434                Err(mpsc::error::TryRecvError::Empty) => break,
435                Err(mpsc::error::TryRecvError::Disconnected) => {
436                    // Channel closed — reader task ended.
437                    if self.message_buffer.is_empty() {
438                        self.state = ConnectorState::Failed;
439                        return Err(ConnectorError::ReadError(
440                            "WebSocket reader task terminated".into(),
441                        ));
442                    }
443                    // Drain the final batch before failing on next call.
444                    break;
445                }
446            }
447        }
448
449        if self.message_buffer.is_empty() {
450            return Ok(None);
451        }
452
453        // Parse the accumulated messages into a RecordBatch.
454        let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
455        let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
456            self.metrics.record_parse_error();
457        })?;
458
459        let num_rows = batch.num_rows();
460        self.message_buffer.clear();
461
462        // Update checkpoint state with wall-clock watermark.
463        // Ideally this would extract event time from the batch using the
464        // configured event_time_field, but that requires schema knowledge
465        // at this layer. The pipeline's WatermarkExtractor handles proper
466        // event-time watermarking at the SQL layer.
467        self.checkpoint_state.watermark = std::time::SystemTime::now()
468            .duration_since(std::time::UNIX_EPOCH)
469            .unwrap_or_default()
470            .as_millis() as i64;
471
472        debug!(records = num_rows, "polled batch from WebSocket");
473        Ok(Some(SourceBatch::new(batch)))
474    }
475
476    fn schema(&self) -> SchemaRef {
477        self.schema.clone()
478    }
479
480    fn checkpoint(&self) -> SourceCheckpoint {
481        // Epoch is managed by the checkpoint coordinator, not individual
482        // sources. Pass 0 here; the coordinator stamps the real epoch on
483        // the manifest. SourceCheckpoint.epoch is informational only.
484        self.checkpoint_state.to_source_checkpoint(0)
485    }
486
487    async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
488        info!(
489            epoch = checkpoint.epoch(),
490            "restoring WebSocket source from checkpoint (best-effort)"
491        );
492        self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
493
494        // WebSocket is non-replayable — log the gap.
495        warn!(
496            last_sequence = ?self.checkpoint_state.last_sequence,
497            last_event_time = ?self.checkpoint_state.last_event_time,
498            "WebSocket source restored; data gap expected (non-replayable transport)"
499        );
500
501        Ok(())
502    }
503
504    fn health_check(&self) -> HealthStatus {
505        match self.state {
506            ConnectorState::Running => HealthStatus::Healthy,
507            ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
508            ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
509            ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
510            ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
511            ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
512        }
513    }
514
515    fn metrics(&self) -> ConnectorMetrics {
516        self.metrics.to_connector_metrics()
517    }
518
519    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
520        Some(Arc::clone(&self.data_ready))
521    }
522
523    fn supports_replay(&self) -> bool {
524        false
525    }
526
527    async fn close(&mut self) -> Result<(), ConnectorError> {
528        info!("closing WebSocket source connector");
529
530        // Signal shutdown.
531        if let Some(tx) = self.shutdown_tx.take() {
532            let _ = tx.send(true);
533        }
534
535        // Wait for the reader task to finish.
536        if let Some(handle) = self.reader_handle.take() {
537            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
538        }
539
540        self.rx = None;
541        self.state = ConnectorState::Closed;
542        info!("WebSocket source connector closed");
543        Ok(())
544    }
545}
546
547impl std::fmt::Debug for WebSocketSource {
548    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549        f.debug_struct("WebSocketSource")
550            .field("state", &self.state)
551            .field("mode", &"client")
552            .field("format", &self.config.format)
553            .finish_non_exhaustive()
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Two nullable UTF-8 columns used by every test.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ]))
    }

    /// Client-mode JSON config pointing at a local URL; nothing is dialed
    /// because no test calls `open()`.
    fn test_config() -> WebSocketSourceConfig {
        WebSocketSourceConfig {
            mode: SourceMode::Client {
                urls: vec!["ws://localhost:9090".into()],
                subscribe_message: None,
                reconnect: ReconnectConfig::default(),
                ping_interval: std::time::Duration::from_secs(30),
                ping_timeout: std::time::Duration::from_secs(10),
            },
            format: MessageFormat::Json,
            on_backpressure: BackpressureStrategy::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    #[test]
    fn test_new_defaults() {
        let source = WebSocketSource::new(test_schema(), test_config());
        assert_eq!(source.state(), ConnectorState::Created);
    }

    #[test]
    fn test_schema_returned() {
        let schema = test_schema();
        let source = WebSocketSource::new(schema.clone(), test_config());
        assert_eq!(source.schema(), schema);
    }

    #[test]
    fn test_checkpoint_contains_state() {
        let source = WebSocketSource::new(test_schema(), test_config());
        let cp = source.checkpoint();
        assert!(!cp.is_empty()); // has websocket_state key
    }

    #[test]
    fn test_health_check_created() {
        let source = WebSocketSource::new(test_schema(), test_config());
        assert_eq!(source.health_check(), HealthStatus::Unknown);
    }

    #[test]
    fn test_metrics_initial() {
        let source = WebSocketSource::new(test_schema(), test_config());
        let m = source.metrics();
        assert_eq!(m.records_total, 0);
        assert_eq!(m.bytes_total, 0);
    }

    #[test]
    fn test_data_ready_notify_shares_handle() {
        let source = WebSocketSource::new(test_schema(), test_config());
        let notify = source
            .data_ready_notify()
            .expect("client-mode source always exposes a notify handle");
        assert!(Arc::ptr_eq(&notify, &source.data_ready));
    }

    #[test]
    fn test_does_not_support_replay() {
        let source = WebSocketSource::new(test_schema(), test_config());
        assert!(!source.supports_replay());
    }

    #[test]
    fn test_debug_output() {
        let source = WebSocketSource::new(test_schema(), test_config());
        let rendered = format!("{source:?}");
        assert!(rendered.starts_with("WebSocketSource"));
        assert!(rendered.contains("\"client\""));
    }
}