Skip to main content

laminar_connectors/websocket/
source.rs

1//! WebSocket source connector — client mode.
2//!
3//! [`WebSocketSource`] connects to an external WebSocket server (e.g., exchange
4//! market data feeds) and produces Arrow `RecordBatch` data via the
5//! [`SourceConnector`] trait.
6//!
7//! # Delivery Guarantees
8//!
9//! WebSocket is non-replayable. This connector provides **at-most-once** or
10//! **best-effort** delivery. On recovery, data gaps should be expected.
11
12use std::sync::Arc;
13
14use arrow_schema::SchemaRef;
15use async_trait::async_trait;
16use crossfire::{mpsc, AsyncRx, MAsyncTx, TryRecvError, TrySendError};
17use futures_util::{SinkExt, StreamExt};
18use tokio::sync::Notify;
19use tracing::{debug, info, warn};
20
21use crate::checkpoint::SourceCheckpoint;
22use crate::config::{ConnectorConfig, ConnectorState};
23use crate::connector::{SourceBatch, SourceConnector};
24use crate::error::ConnectorError;
25use crate::health::HealthStatus;
26use crate::metrics::ConnectorMetrics;
27
28use crate::schema::json::decoder::JsonDecoderConfig;
29
30use super::backpressure::WsBackpressure;
31use super::checkpoint::WebSocketSourceCheckpoint;
32use super::connection::ConnectionManager;
33use super::metrics::WebSocketSourceMetrics;
34use super::parser::MessageParser;
35use super::source_config::{ReconnectConfig, SourceMode, WebSocketSourceConfig};
36
/// Item passed over the bounded channel from the spawned reader task to
/// `poll_batch()`.
enum WsMessage {
    /// Raw bytes of one Text or Binary WebSocket message (Text payloads are
    /// forwarded as their UTF-8 bytes).
    Data(Vec<u8>),
    /// Notice that the connection ended or could not be established, with a
    /// human-readable reason.
    Disconnected(String),
}
44
45/// WebSocket source connector in client mode.
46///
47/// Connects to one or more external WebSocket server URLs and consumes
48/// messages, converting them to Arrow `RecordBatch` data.
49///
50/// All WebSocket I/O runs in a spawned Tokio task (Ring 2). Parsed data
51/// is delivered to `poll_batch()` via a bounded channel.
52pub struct WebSocketSource {
53    /// Parsed configuration.
54    config: WebSocketSourceConfig,
55    /// Output Arrow schema.
56    schema: SchemaRef,
57    /// Message parser (JSON/CSV/Binary → Arrow).
58    parser: MessageParser,
59    /// Connector lifecycle state.
60    state: ConnectorState,
61    /// Metrics.
62    metrics: WebSocketSourceMetrics,
63    /// Checkpoint state.
64    checkpoint_state: WebSocketSourceCheckpoint,
65    /// Bounded channel receiver for messages from the WS reader task.
66    rx: Option<AsyncRx<mpsc::Array<WsMessage>>>,
67    /// Shutdown signal sender.
68    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
69    /// Handle to the spawned reader task.
70    reader_handle: Option<tokio::task::JoinHandle<()>>,
71    /// Buffer of raw messages accumulated between `poll_batch()` calls.
72    message_buffer: Vec<Vec<u8>>,
73    /// Maximum records per batch.
74    max_batch_size: usize,
75    /// Notification handle signalled when data arrives from the reader task.
76    data_ready: Arc<Notify>,
77}
78
79impl WebSocketSource {
80    /// Creates a new WebSocket source connector in client mode.
81    #[must_use]
82    pub fn new(
83        schema: SchemaRef,
84        config: WebSocketSourceConfig,
85        registry: Option<&prometheus::Registry>,
86    ) -> Self {
87        let parser = MessageParser::new(
88            schema.clone(),
89            config.format.clone(),
90            JsonDecoderConfig::default(),
91        );
92
93        Self {
94            config,
95            schema,
96            parser,
97            state: ConnectorState::Created,
98            metrics: WebSocketSourceMetrics::new(registry),
99            checkpoint_state: WebSocketSourceCheckpoint::default(),
100            rx: None,
101            shutdown_tx: None,
102            reader_handle: None,
103            message_buffer: Vec::new(),
104            max_batch_size: 1000,
105            data_ready: Arc::new(Notify::new()),
106        }
107    }
108
109    /// Returns the current connector state.
110    #[must_use]
111    pub fn state(&self) -> ConnectorState {
112        self.state
113    }
114
115    /// Spawns the WebSocket reader task that connects to the server and
116    /// feeds messages through the bounded channel.
117    #[allow(
118        clippy::too_many_arguments,
119        clippy::too_many_lines,
120        clippy::unused_self
121    )]
122    fn spawn_reader(
123        &self,
124        urls: Vec<String>,
125        subscribe_message: Option<String>,
126        reconnect: ReconnectConfig,
127        max_message_size: usize,
128        on_backpressure: WsBackpressure,
129        tx: MAsyncTx<mpsc::Array<WsMessage>>,
130        mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
131        data_ready: Arc<Notify>,
132    ) -> tokio::task::JoinHandle<()> {
133        tokio::spawn(async move {
134            let mut conn_mgr = ConnectionManager::new(urls, reconnect);
135
136            'outer: loop {
137                // Check shutdown.
138                if *shutdown_rx.borrow() {
139                    break;
140                }
141
142                let url = conn_mgr.current_url().to_string();
143                info!(url = %url, "connecting to WebSocket server");
144
145                // Attempt connection with frame-level size cap.
146                let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
147                ws_config.max_message_size = Some(max_message_size);
148                ws_config.max_frame_size = Some(max_message_size);
149                let ws_stream = match tokio_tungstenite::connect_async_with_config(
150                    &url,
151                    Some(ws_config),
152                    true, // disable Nagle for low latency
153                )
154                .await
155                {
156                    Ok((stream, _response)) => {
157                        conn_mgr.reset();
158                        info!(url = %url, "WebSocket connection established");
159                        stream
160                    }
161                    Err(e) => {
162                        warn!(url = %url, error = %e, "WebSocket connection failed");
163                        if let Some(delay) = conn_mgr.next_backoff() {
164                            tokio::select! {
165                                () = tokio::time::sleep(delay) => continue,
166                                _ = shutdown_rx.changed() => break,
167                            }
168                        } else {
169                            let _ = tx
170                                .send(WsMessage::Disconnected(format!(
171                                    "connection failed, no more retries: {e}"
172                                )))
173                                .await;
174                            break;
175                        }
176                    }
177                };
178
179                let (mut write, mut read) = ws_stream.split();
180
181                // Send subscription message if configured.
182                if let Some(ref msg) = subscribe_message {
183                    if let Err(e) = write
184                        .send(tungstenite::Message::Text(msg.clone().into()))
185                        .await
186                    {
187                        warn!(error = %e, "failed to send subscription message");
188                        if let Some(delay) = conn_mgr.next_backoff() {
189                            tokio::select! {
190                                () = tokio::time::sleep(delay) => continue,
191                                _ = shutdown_rx.changed() => break,
192                            }
193                        }
194                        continue;
195                    }
196                    debug!("subscription message sent");
197                }
198
199                // Read loop.
200                loop {
201                    tokio::select! {
202                        msg = read.next() => {
203                            match msg {
204                                Some(Ok(tungstenite::Message::Text(text))) => {
205                                    let payload = text.as_bytes().to_vec();
206                                    if payload.len() > max_message_size {
207                                        warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
208                                        continue;
209                                    }
210                                    if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
211                                        break 'outer;
212                                    }
213                                }
214                                Some(Ok(tungstenite::Message::Binary(data))) => {
215                                    let payload = data.to_vec();
216                                    if payload.len() > max_message_size {
217                                        warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
218                                        continue;
219                                    }
220                                    if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
221                                        break 'outer;
222                                    }
223                                }
224                                Some(Ok(tungstenite::Message::Ping(data))) => {
225                                    let _ = write.send(tungstenite::Message::Pong(data)).await;
226                                }
227                                Some(Ok(tungstenite::Message::Close(_))) => {
228                                    info!(url = %url, "server sent Close frame");
229                                    break;
230                                }
231                                Some(Ok(_)) => {} // Pong, Frame — ignore
232                                Some(Err(e)) => {
233                                    warn!(url = %url, error = %e, "WebSocket read error");
234                                    break;
235                                }
236                                None => {
237                                    info!(url = %url, "WebSocket stream ended");
238                                    break;
239                                }
240                            }
241                        }
242                        _ = shutdown_rx.changed() => {
243                            debug!("shutdown signal received in reader");
244                            let _ = write.send(tungstenite::Message::Close(None)).await;
245                            break 'outer;
246                        }
247                    }
248                }
249
250                // Disconnected — attempt reconnect.
251                let _ = tx
252                    .send(WsMessage::Disconnected(format!("disconnected from {url}")))
253                    .await;
254
255                if let Some(delay) = conn_mgr.next_backoff() {
256                    tokio::select! {
257                        () = tokio::time::sleep(delay) => {},
258                        _ = shutdown_rx.changed() => break,
259                    }
260                } else {
261                    break;
262                }
263            }
264        })
265    }
266}
267
268/// Sends a message through the channel, applying the backpressure strategy
269/// if the channel is full. Signals `data_ready` on successful send so the
270/// pipeline coordinator wakes immediately.
271///
272/// Returns `Err(())` if the channel is closed (shutdown).
273async fn send_with_backpressure(
274    tx: &MAsyncTx<mpsc::Array<WsMessage>>,
275    msg: WsMessage,
276    strategy: &WsBackpressure,
277    data_ready: &Notify,
278) -> Result<(), ()> {
279    let result = match strategy {
280        WsBackpressure::Block => tx.send(msg).await.map_err(|_| ()),
281        WsBackpressure::DropNewest => match tx.try_send(msg) {
282            Ok(()) | Err(TrySendError::Full(_)) => Ok(()),
283            Err(TrySendError::Disconnected(_)) => Err(()),
284        },
285        // TODO(F006): implement DropOldest, Buffer, Sample properly.
286        WsBackpressure::DropOldest
287        | WsBackpressure::Buffer { .. }
288        | WsBackpressure::Sample { .. } => match tx.try_send(msg) {
289            Ok(()) | Err(TrySendError::Full(_)) => Ok(()),
290            Err(TrySendError::Disconnected(_)) => Err(()),
291        },
292    };
293    if result.is_ok() {
294        data_ready.notify_one();
295    }
296    result
297}
298
299#[async_trait]
300impl SourceConnector for WebSocketSource {
301    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
302        self.state = ConnectorState::Initializing;
303
304        // If config has properties, re-parse (supports runtime config via SQL WITH).
305        if !config.properties().is_empty() {
306            self.config = WebSocketSourceConfig::from_config(config)?;
307        }
308
309        // Override schema from SQL DDL if provided.
310        if let Some(schema) = config.arrow_schema() {
311            info!(
312                fields = schema.fields().len(),
313                "using SQL-defined schema for deserialization"
314            );
315            self.schema = schema;
316            let decoder_config = JsonDecoderConfig::from_connector_config(config);
317            self.parser = MessageParser::new(
318                self.schema.clone(),
319                self.config.format.clone(),
320                decoder_config,
321            );
322        }
323
324        let mode = &self.config.mode;
325        let (urls, subscribe_message, reconnect, ping_interval, ping_timeout) = match mode {
326            SourceMode::Client {
327                urls,
328                subscribe_message,
329                reconnect,
330                ping_interval,
331                ping_timeout,
332            } => (
333                urls.clone(),
334                subscribe_message.clone(),
335                reconnect.clone(),
336                *ping_interval,
337                *ping_timeout,
338            ),
339            SourceMode::Server { .. } => {
340                return Err(ConnectorError::ConfigurationError(
341                    "WebSocketSource is for client mode; use WebSocketSourceServer for server mode"
342                        .into(),
343                ));
344            }
345        };
346
347        if urls.is_empty() {
348            return Err(ConnectorError::ConfigurationError(
349                "at least one WebSocket URL is required".into(),
350            ));
351        }
352
353        // ping_interval and ping_timeout are accepted by config but not yet
354        // wired to WebSocket ping/pong frames in client-mode source.
355        // Tungstenite handles pong replies automatically for incoming pings.
356        let _ = (ping_interval, ping_timeout);
357
358        info!(
359            urls = ?urls,
360            format = ?self.config.format,
361            backpressure = ?self.config.on_backpressure,
362            "opening WebSocket source connector (client mode)"
363        );
364
365        if matches!(
366            self.config.on_backpressure,
367            WsBackpressure::DropOldest
368                | WsBackpressure::Buffer { .. }
369                | WsBackpressure::Sample { .. }
370        ) {
371            warn!(
372                strategy = ?self.config.on_backpressure,
373                "backpressure strategy not implemented, falling back to DropNewest"
374            );
375        }
376
377        // Create bounded channel between reader task and poll_batch().
378        let channel_capacity = 10_000;
379        let (tx, rx) = mpsc::bounded_async::<WsMessage>(channel_capacity);
380
381        // Create shutdown signal.
382        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
383
384        // Spawn the reader task.
385        let handle = self.spawn_reader(
386            urls.clone(),
387            subscribe_message,
388            reconnect.clone(),
389            self.config.max_message_size,
390            self.config.on_backpressure.clone(),
391            tx,
392            shutdown_rx,
393            Arc::clone(&self.data_ready),
394        );
395
396        self.rx = Some(rx);
397        self.shutdown_tx = Some(shutdown_tx);
398        self.reader_handle = Some(handle);
399        self.state = ConnectorState::Running;
400
401        info!("WebSocket source connector opened successfully");
402        Ok(())
403    }
404
405    #[allow(clippy::cast_possible_truncation)]
406    async fn poll_batch(
407        &mut self,
408        max_records: usize,
409    ) -> Result<Option<SourceBatch>, ConnectorError> {
410        if self.state != ConnectorState::Running {
411            return Err(ConnectorError::InvalidState {
412                expected: "Running".into(),
413                actual: self.state.to_string(),
414            });
415        }
416
417        let rx = self
418            .rx
419            .as_mut()
420            .ok_or_else(|| ConnectorError::InvalidState {
421                expected: "channel initialized".into(),
422                actual: "channel is None".into(),
423            })?;
424
425        let limit = max_records.min(self.max_batch_size);
426
427        // Non-blocking drain: pull all available messages from the channel.
428        // The pipeline coordinator handles wake-up timing via data_ready_notify().
429        while self.message_buffer.len() < limit {
430            match rx.try_recv() {
431                Ok(WsMessage::Data(payload)) => {
432                    self.metrics.record_message(payload.len() as u64);
433                    self.message_buffer.push(payload);
434                }
435                Ok(WsMessage::Disconnected(reason)) => {
436                    self.metrics.record_reconnect();
437                    warn!(reason = %reason, "WebSocket disconnected");
438                    break;
439                }
440                Err(TryRecvError::Empty) => break,
441                Err(TryRecvError::Disconnected) => {
442                    // Channel closed — reader task ended.
443                    if self.message_buffer.is_empty() {
444                        self.state = ConnectorState::Failed;
445                        return Err(ConnectorError::ReadError(
446                            "WebSocket reader task terminated".into(),
447                        ));
448                    }
449                    // Drain the final batch before failing on next call.
450                    break;
451                }
452            }
453        }
454
455        if self.message_buffer.is_empty() {
456            return Ok(None);
457        }
458
459        // Parse the accumulated messages into a RecordBatch.
460        let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
461        let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
462            self.metrics.record_parse_error();
463        })?;
464
465        let num_rows = batch.num_rows();
466        self.message_buffer.clear();
467
468        // Update checkpoint state with wall-clock watermark.
469        // Ideally this would extract event time from the batch using the
470        // configured event_time_field, but that requires schema knowledge
471        // at this layer. The pipeline's WatermarkExtractor handles proper
472        // event-time watermarking at the SQL layer.
473        self.checkpoint_state.watermark = std::time::SystemTime::now()
474            .duration_since(std::time::UNIX_EPOCH)
475            .unwrap_or_default()
476            .as_millis() as i64;
477
478        debug!(records = num_rows, "polled batch from WebSocket");
479        Ok(Some(SourceBatch::new(batch)))
480    }
481
482    fn schema(&self) -> SchemaRef {
483        self.schema.clone()
484    }
485
486    fn checkpoint(&self) -> SourceCheckpoint {
487        // Epoch is managed by the checkpoint coordinator, not individual
488        // sources. Pass 0 here; the coordinator stamps the real epoch on
489        // the manifest. SourceCheckpoint.epoch is informational only.
490        self.checkpoint_state.to_source_checkpoint(0)
491    }
492
493    async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
494        info!(
495            epoch = checkpoint.epoch(),
496            "restoring WebSocket source from checkpoint (best-effort)"
497        );
498        self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
499
500        // WebSocket is non-replayable — log the gap.
501        warn!(
502            last_sequence = ?self.checkpoint_state.last_sequence,
503            last_event_time = ?self.checkpoint_state.last_event_time,
504            "WebSocket source restored; data gap expected (non-replayable transport)"
505        );
506
507        Ok(())
508    }
509
510    fn health_check(&self) -> HealthStatus {
511        match self.state {
512            ConnectorState::Running => HealthStatus::Healthy,
513            ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
514            ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
515            ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
516            ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
517            ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
518        }
519    }
520
521    fn metrics(&self) -> ConnectorMetrics {
522        self.metrics.to_connector_metrics()
523    }
524
525    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
526        Some(Arc::clone(&self.data_ready))
527    }
528
529    fn supports_replay(&self) -> bool {
530        false
531    }
532
533    async fn close(&mut self) -> Result<(), ConnectorError> {
534        info!("closing WebSocket source connector");
535
536        // Signal shutdown.
537        if let Some(tx) = self.shutdown_tx.take() {
538            let _ = tx.send(true);
539        }
540
541        // Wait for the reader task to finish.
542        if let Some(handle) = self.reader_handle.take() {
543            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
544        }
545
546        self.rx = None;
547        self.state = ConnectorState::Closed;
548        info!("WebSocket source connector closed");
549        Ok(())
550    }
551}
552
553impl std::fmt::Debug for WebSocketSource {
554    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555        f.debug_struct("WebSocketSource")
556            .field("state", &self.state)
557            .field("mode", &"client")
558            .field("format", &self.config.format)
559            .finish_non_exhaustive()
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Schema with two nullable UTF-8 columns: `id` and `value`.
    fn test_schema() -> SchemaRef {
        let fields: Vec<Field> = ["id", "value"]
            .iter()
            .map(|&name| Field::new(name, DataType::Utf8, true))
            .collect();
        Arc::new(Schema::new(fields))
    }

    /// Client-mode config pointing at a local URL, JSON format, blocking
    /// backpressure.
    fn test_config() -> WebSocketSourceConfig {
        let mode = SourceMode::Client {
            urls: vec!["ws://localhost:9090".into()],
            subscribe_message: None,
            reconnect: ReconnectConfig::default(),
            ping_interval: std::time::Duration::from_secs(30),
            ping_timeout: std::time::Duration::from_secs(10),
        };
        WebSocketSourceConfig {
            mode,
            format: MessageFormat::Json,
            on_backpressure: WsBackpressure::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    /// Unopened source built from the default test schema and config.
    fn fresh_source() -> WebSocketSource {
        WebSocketSource::new(test_schema(), test_config(), None)
    }

    #[test]
    fn test_new_defaults() {
        assert_eq!(fresh_source().state(), ConnectorState::Created);
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let source = WebSocketSource::new(expected.clone(), test_config(), None);
        assert_eq!(source.schema(), expected);
    }

    #[test]
    fn test_checkpoint_empty() {
        let checkpoint = fresh_source().checkpoint();
        assert!(!checkpoint.is_empty()); // has websocket_state key
    }

    #[test]
    fn test_health_check_created() {
        assert_eq!(fresh_source().health_check(), HealthStatus::Unknown);
    }

    #[test]
    fn test_metrics_initial() {
        let snapshot = fresh_source().metrics();
        assert_eq!(snapshot.records_total, 0);
        assert_eq!(snapshot.bytes_total, 0);
    }
}