Skip to main content

laminar_connectors/websocket/
sink_client.rs

1//! WebSocket sink connector — client mode.
2//!
3//! [`WebSocketSinkClient`] pushes streaming query output to an external
4//! WebSocket server by connecting as a client.
5
6use std::collections::VecDeque;
7use std::time::Duration;
8
9use arrow_array::RecordBatch;
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use futures_util::{SinkExt, StreamExt};
13use tracing::{debug, info, warn};
14
15use crate::config::{ConnectorConfig, ConnectorState};
16use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
17use crate::error::ConnectorError;
18use crate::health::HealthStatus;
19use crate::metrics::ConnectorMetrics;
20
21use super::connection::ConnectionManager;
22use super::serializer::BatchSerializer;
23use super::sink_config::{SinkMode, WebSocketSinkConfig};
24use super::sink_metrics::WebSocketSinkMetrics;
25
26/// Type alias for the split WebSocket sink half.
27type WsSink = futures_util::stream::SplitSink<
28    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
29    tungstenite::Message,
30>;
31
32/// WebSocket sink connector in client mode.
33///
34/// Connects to an external WebSocket server and pushes serialized
35/// `RecordBatch` data as text or binary messages.
36pub struct WebSocketSinkClient {
37    /// Configuration.
38    config: WebSocketSinkConfig,
39    /// Input Arrow schema.
40    schema: SchemaRef,
41    /// Serializer for `RecordBatch` → messages.
42    serializer: BatchSerializer,
43    /// Connection manager for reconnection.
44    conn_mgr: Option<ConnectionManager>,
45    /// WebSocket sink (write half).
46    ws_sink: Option<WsSink>,
47    /// Connector state.
48    state: ConnectorState,
49    /// Metrics.
50    metrics: WebSocketSinkMetrics,
51    /// Current epoch.
52    current_epoch: u64,
53    /// Buffer for messages while disconnected.
54    disconnect_buffer: VecDeque<String>,
55    /// Max buffer size in bytes when disconnected.
56    max_buffer_bytes: usize,
57    /// Current buffered bytes.
58    buffered_bytes: usize,
59}
60
61impl WebSocketSinkClient {
62    /// Creates a new WebSocket sink client connector.
63    #[must_use]
64    pub fn new(
65        schema: SchemaRef,
66        config: WebSocketSinkConfig,
67        registry: Option<&prometheus::Registry>,
68    ) -> Self {
69        let serializer = BatchSerializer::new(config.format.clone());
70
71        let max_buffer_bytes = match &config.mode {
72            SinkMode::Client {
73                buffer_on_disconnect,
74                ..
75            } => buffer_on_disconnect.unwrap_or(0),
76            SinkMode::Server { .. } => 0,
77        };
78
79        Self {
80            config,
81            schema,
82            serializer,
83            conn_mgr: None,
84            ws_sink: None,
85            state: ConnectorState::Created,
86            metrics: WebSocketSinkMetrics::new(registry),
87            current_epoch: 0,
88            disconnect_buffer: VecDeque::new(),
89            max_buffer_bytes,
90            buffered_bytes: 0,
91        }
92    }
93
94    /// Returns the current connector state.
95    #[must_use]
96    pub fn state(&self) -> ConnectorState {
97        self.state
98    }
99
100    /// Attempts to reconnect and flush the disconnect buffer.
101    async fn try_reconnect(&mut self) -> Result<(), ConnectorError> {
102        let conn_mgr = self
103            .conn_mgr
104            .as_mut()
105            .ok_or_else(|| ConnectorError::InvalidState {
106                expected: "connection manager initialized".into(),
107                actual: "None".into(),
108            })?;
109
110        let url = conn_mgr.current_url().to_string();
111        info!(url = %url, "attempting WebSocket reconnection");
112
113        match tokio_tungstenite::connect_async(&url).await {
114            Ok((stream, _)) => {
115                conn_mgr.reset();
116                let (sink, _read) = stream.split();
117                self.ws_sink = Some(sink);
118                self.metrics.record_connect();
119                info!(url = %url, "WebSocket reconnected");
120
121                // Flush disconnect buffer.
122                self.flush_disconnect_buffer().await?;
123                Ok(())
124            }
125            Err(e) => {
126                self.metrics.record_disconnect();
127                Err(ConnectorError::ConnectionFailed(format!(
128                    "reconnection to {url} failed: {e}"
129                )))
130            }
131        }
132    }
133
134    /// Flushes buffered messages that accumulated during disconnection.
135    async fn flush_disconnect_buffer(&mut self) -> Result<(), ConnectorError> {
136        if self.disconnect_buffer.is_empty() {
137            return Ok(());
138        }
139
140        let sink = self
141            .ws_sink
142            .as_mut()
143            .ok_or_else(|| ConnectorError::InvalidState {
144                expected: "ws_sink initialized".into(),
145                actual: "None".into(),
146            })?;
147
148        let count = self.disconnect_buffer.len();
149        debug!(buffered_messages = count, "flushing disconnect buffer");
150
151        while let Some(msg) = self.disconnect_buffer.pop_front() {
152            self.buffered_bytes -= msg.len();
153            if let Err(e) = sink.send(tungstenite::Message::Text(msg.into())).await {
154                warn!(error = %e, "failed to flush buffered message");
155                return Err(ConnectorError::WriteError(format!(
156                    "buffer flush failed: {e}"
157                )));
158            }
159        }
160
161        Ok(())
162    }
163
164    /// Buffers a message for later delivery (when disconnected).
165    fn buffer_message(&mut self, msg: String) {
166        if self.max_buffer_bytes == 0 {
167            return; // buffering disabled
168        }
169
170        let msg_len = msg.len();
171
172        // Evict oldest if buffer would exceed limit.
173        while self.buffered_bytes + msg_len > self.max_buffer_bytes {
174            if let Some(old) = self.disconnect_buffer.pop_front() {
175                self.buffered_bytes -= old.len();
176            } else {
177                break;
178            }
179        }
180
181        if msg_len <= self.max_buffer_bytes {
182            self.buffered_bytes += msg_len;
183            self.disconnect_buffer.push_back(msg);
184        }
185    }
186}
187
188#[async_trait]
189impl SinkConnector for WebSocketSinkClient {
190    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
191        self.state = ConnectorState::Initializing;
192
193        // If config has properties, re-parse (supports runtime config via SQL WITH).
194        if !config.properties().is_empty() {
195            self.config = WebSocketSinkConfig::from_config(config)?;
196        }
197
198        let (url, reconnect) = match &self.config.mode {
199            SinkMode::Client { url, reconnect, .. } => (url.clone(), reconnect.clone()),
200            SinkMode::Server { .. } => {
201                return Err(ConnectorError::ConfigurationError(
202                    "WebSocketSinkClient is for client mode; use WebSocketSinkServer for server mode".into(),
203                ));
204            }
205        };
206
207        info!(url = %url, "opening WebSocket sink client");
208
209        let (stream, _response) = tokio_tungstenite::connect_async(&url).await.map_err(|e| {
210            ConnectorError::ConnectionFailed(format!("failed to connect to {url}: {e}"))
211        })?;
212
213        let (sink, _read) = stream.split();
214        self.ws_sink = Some(sink);
215        self.conn_mgr = Some(ConnectionManager::new(vec![url.clone()], reconnect));
216        self.state = ConnectorState::Running;
217        self.metrics.record_connect();
218
219        info!(url = %url, "WebSocket sink client connected");
220        Ok(())
221    }
222
223    #[allow(clippy::cast_possible_truncation)]
224    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
225        if self.state != ConnectorState::Running {
226            return Err(ConnectorError::InvalidState {
227                expected: "Running".into(),
228                actual: self.state.to_string(),
229            });
230        }
231
232        let rows = self.serializer.serialize_rows(batch)?;
233        let mut bytes_written: u64 = 0;
234        let mut records_written: usize = 0;
235
236        for (i, row) in rows.iter().enumerate() {
237            if let Some(ref mut sink) = self.ws_sink {
238                match sink
239                    .send(tungstenite::Message::Text(row.clone().into()))
240                    .await
241                {
242                    Ok(()) => {
243                        bytes_written += row.len() as u64;
244                        records_written += 1;
245                        self.metrics.record_send(row.len() as u64);
246                    }
247                    Err(e) => {
248                        warn!(error = %e, "send failed, buffering and attempting reconnect");
249                        self.ws_sink = None;
250                        self.buffer_message(row.clone());
251
252                        // Try to reconnect.
253                        if self.try_reconnect().await.is_err() {
254                            // Buffer rows after the failed one.
255                            for remaining in &rows[i + 1..] {
256                                self.buffer_message(remaining.clone());
257                            }
258                            return Ok(WriteResult::new(records_written, bytes_written));
259                        }
260                    }
261                }
262            } else {
263                self.buffer_message(row.clone());
264            }
265        }
266
267        debug!(
268            records = records_written,
269            bytes = bytes_written,
270            "wrote batch to WebSocket"
271        );
272
273        Ok(WriteResult::new(records_written, bytes_written))
274    }
275
276    fn schema(&self) -> SchemaRef {
277        self.schema.clone()
278    }
279
280    fn health_check(&self) -> HealthStatus {
281        match self.state {
282            ConnectorState::Running => {
283                if self.ws_sink.is_some() {
284                    HealthStatus::Healthy
285                } else {
286                    HealthStatus::Degraded("disconnected, buffering".into())
287                }
288            }
289            ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
290            ConnectorState::Paused => HealthStatus::Degraded("paused".into()),
291            ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
292            ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
293            ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
294        }
295    }
296
297    fn metrics(&self) -> ConnectorMetrics {
298        self.metrics.to_connector_metrics()
299    }
300
301    fn capabilities(&self) -> SinkConnectorCapabilities {
302        SinkConnectorCapabilities::new(Duration::from_secs(10))
303    }
304
305    async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
306        self.current_epoch = epoch;
307        Ok(())
308    }
309
310    async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
311        // Flush the WebSocket.
312        if let Some(ref mut sink) = self.ws_sink {
313            sink.flush()
314                .await
315                .map_err(|e| ConnectorError::WriteError(format!("flush failed: {e}")))?;
316        }
317        Ok(())
318    }
319
320    async fn close(&mut self) -> Result<(), ConnectorError> {
321        info!("closing WebSocket sink client");
322
323        if let Some(ref mut sink) = self.ws_sink {
324            let _ = sink.send(tungstenite::Message::Close(None)).await;
325        }
326
327        self.ws_sink = None;
328        self.state = ConnectorState::Closed;
329        info!("WebSocket sink client closed");
330        Ok(())
331    }
332}
333
334impl std::fmt::Debug for WebSocketSinkClient {
335    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336        f.debug_struct("WebSocketSinkClient")
337            .field("state", &self.state)
338            .field("connected", &self.ws_sink.is_some())
339            .field("buffered_messages", &self.disconnect_buffer.len())
340            .field("current_epoch", &self.current_epoch)
341            .finish_non_exhaustive()
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::super::source_config::ReconnectConfig;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Two non-nullable columns: `id: Int64` and `value: Utf8`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Client-mode JSON config for a local URL with the given disconnect-buffer budget.
    fn client_config(buffer_on_disconnect: Option<usize>) -> WebSocketSinkConfig {
        WebSocketSinkConfig {
            mode: SinkMode::Client {
                url: "ws://localhost:9090".into(),
                reconnect: ReconnectConfig::default(),
                buffer_on_disconnect,
                batch_interval: None,
                batch_max_size: None,
            },
            format: super::super::sink_config::SinkFormat::Json,
            auth: None,
        }
    }

    /// Default test config: 1 MiB disconnect buffer.
    fn test_config() -> WebSocketSinkConfig {
        client_config(Some(1_048_576))
    }

    /// Builds a client over `test_schema()` with no metrics registry.
    fn client(config: WebSocketSinkConfig) -> WebSocketSinkClient {
        WebSocketSinkClient::new(test_schema(), config, None)
    }

    #[test]
    fn test_new() {
        let sink = client(test_config());
        assert_eq!(sink.state(), ConnectorState::Created);
        assert!(sink.ws_sink.is_none());
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let sink = WebSocketSinkClient::new(expected.clone(), test_config(), None);
        assert_eq!(sink.schema(), expected);
    }

    #[test]
    fn test_buffer_message() {
        let mut sink = client(test_config());
        for word in ["hello", "world"] {
            sink.buffer_message(word.into());
        }
        assert_eq!(sink.disconnect_buffer.len(), 2);
        assert_eq!(sink.buffered_bytes, 10);
    }

    #[test]
    fn test_buffer_eviction() {
        // 10-byte budget: the third 5-byte message must push out the first.
        let mut sink = client(client_config(Some(10)));

        for chunk in ["12345", "67890", "abcde"] {
            sink.buffer_message(chunk.into());
        }

        assert_eq!(sink.disconnect_buffer.len(), 2);
        assert_eq!(sink.disconnect_buffer[0], "67890");
        assert_eq!(sink.disconnect_buffer[1], "abcde");
    }

    #[test]
    fn test_buffer_disabled() {
        let mut sink = client(client_config(None));
        sink.buffer_message("hello".into());
        assert!(sink.disconnect_buffer.is_empty());
    }

    #[test]
    fn test_capabilities() {
        let caps = client(test_config()).capabilities();
        assert!(!caps.exactly_once);
    }

    #[test]
    fn test_health_created() {
        let sink = client(test_config());
        assert_eq!(sink.health_check(), HealthStatus::Unknown);
    }
}