Skip to main content

laminar_connectors/websocket/
sink_client.rs

1//! WebSocket sink connector — client mode.
2//!
3//! [`WebSocketSinkClient`] pushes streaming query output to an external
4//! WebSocket server by connecting as a client.
5
6use std::collections::VecDeque;
7use std::time::Duration;
8
9use arrow_array::RecordBatch;
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use futures_util::{SinkExt, StreamExt};
13use tracing::{debug, info, warn};
14
15use crate::config::{ConnectorConfig, ConnectorState};
16use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
17use crate::error::ConnectorError;
18
19use super::connection::ConnectionManager;
20use super::serializer::BatchSerializer;
21use super::sink_config::{SinkMode, WebSocketSinkConfig};
22use super::sink_metrics::WebSocketSinkMetrics;
23
24/// Type alias for the split WebSocket sink half.
25type WsSink = futures_util::stream::SplitSink<
26    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
27    tungstenite::Message,
28>;
29
30/// WebSocket sink connector in client mode.
31///
32/// Connects to an external WebSocket server and pushes serialized
33/// `RecordBatch` data as text or binary messages.
34pub struct WebSocketSinkClient {
35    /// Configuration.
36    config: WebSocketSinkConfig,
37    /// Input Arrow schema.
38    schema: SchemaRef,
39    /// Serializer for `RecordBatch` → messages.
40    serializer: BatchSerializer,
41    /// Connection manager for reconnection.
42    conn_mgr: Option<ConnectionManager>,
43    /// WebSocket sink (write half).
44    ws_sink: Option<WsSink>,
45    /// Connector state.
46    state: ConnectorState,
47    /// Metrics.
48    metrics: WebSocketSinkMetrics,
49    /// Current epoch.
50    current_epoch: u64,
51    /// Buffer for messages while disconnected.
52    disconnect_buffer: VecDeque<String>,
53    /// Max buffer size in bytes when disconnected.
54    max_buffer_bytes: usize,
55    /// Current buffered bytes.
56    buffered_bytes: usize,
57}
58
59impl WebSocketSinkClient {
60    /// Creates a new WebSocket sink client connector.
61    #[must_use]
62    pub fn new(
63        schema: SchemaRef,
64        config: WebSocketSinkConfig,
65        registry: Option<&prometheus::Registry>,
66    ) -> Self {
67        let serializer = BatchSerializer::new(config.format.clone());
68
69        let max_buffer_bytes = match &config.mode {
70            SinkMode::Client {
71                buffer_on_disconnect,
72                ..
73            } => buffer_on_disconnect.unwrap_or(0),
74            SinkMode::Server { .. } => 0,
75        };
76
77        Self {
78            config,
79            schema,
80            serializer,
81            conn_mgr: None,
82            ws_sink: None,
83            state: ConnectorState::Created,
84            metrics: WebSocketSinkMetrics::new(registry),
85            current_epoch: 0,
86            disconnect_buffer: VecDeque::new(),
87            max_buffer_bytes,
88            buffered_bytes: 0,
89        }
90    }
91
92    /// Returns the current connector state.
93    #[must_use]
94    pub fn state(&self) -> ConnectorState {
95        self.state
96    }
97
98    /// Attempts to reconnect and flush the disconnect buffer.
99    async fn try_reconnect(&mut self) -> Result<(), ConnectorError> {
100        let conn_mgr = self
101            .conn_mgr
102            .as_mut()
103            .ok_or_else(|| ConnectorError::InvalidState {
104                expected: "connection manager initialized".into(),
105                actual: "None".into(),
106            })?;
107
108        let url = conn_mgr.current_url().to_string();
109        info!(url = %url, "attempting WebSocket reconnection");
110
111        match tokio_tungstenite::connect_async(&url).await {
112            Ok((stream, _)) => {
113                conn_mgr.reset();
114                let (sink, _read) = stream.split();
115                self.ws_sink = Some(sink);
116                self.metrics.record_connect();
117                info!(url = %url, "WebSocket reconnected");
118
119                // Flush disconnect buffer.
120                self.flush_disconnect_buffer().await?;
121                Ok(())
122            }
123            Err(e) => {
124                self.metrics.record_disconnect();
125                Err(ConnectorError::ConnectionFailed(format!(
126                    "reconnection to {url} failed: {e}"
127                )))
128            }
129        }
130    }
131
132    /// Flushes buffered messages that accumulated during disconnection.
133    async fn flush_disconnect_buffer(&mut self) -> Result<(), ConnectorError> {
134        if self.disconnect_buffer.is_empty() {
135            return Ok(());
136        }
137
138        let sink = self
139            .ws_sink
140            .as_mut()
141            .ok_or_else(|| ConnectorError::InvalidState {
142                expected: "ws_sink initialized".into(),
143                actual: "None".into(),
144            })?;
145
146        let count = self.disconnect_buffer.len();
147        debug!(buffered_messages = count, "flushing disconnect buffer");
148
149        while let Some(msg) = self.disconnect_buffer.pop_front() {
150            self.buffered_bytes -= msg.len();
151            if let Err(e) = sink.send(tungstenite::Message::Text(msg.into())).await {
152                warn!(error = %e, "failed to flush buffered message");
153                return Err(ConnectorError::WriteError(format!(
154                    "buffer flush failed: {e}"
155                )));
156            }
157        }
158
159        Ok(())
160    }
161
162    /// Buffers a message for later delivery (when disconnected).
163    fn buffer_message(&mut self, msg: String) {
164        if self.max_buffer_bytes == 0 {
165            return; // buffering disabled
166        }
167
168        let msg_len = msg.len();
169
170        // Evict oldest if buffer would exceed limit.
171        while self.buffered_bytes + msg_len > self.max_buffer_bytes {
172            if let Some(old) = self.disconnect_buffer.pop_front() {
173                self.buffered_bytes -= old.len();
174            } else {
175                break;
176            }
177        }
178
179        if msg_len <= self.max_buffer_bytes {
180            self.buffered_bytes += msg_len;
181            self.disconnect_buffer.push_back(msg);
182        }
183    }
184}
185
186#[async_trait]
187impl SinkConnector for WebSocketSinkClient {
188    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
189        self.state = ConnectorState::Initializing;
190
191        // If config has properties, re-parse (supports runtime config via SQL WITH).
192        if !config.properties().is_empty() {
193            self.config = WebSocketSinkConfig::from_config(config)?;
194        }
195
196        let (url, reconnect) = match &self.config.mode {
197            SinkMode::Client { url, reconnect, .. } => (url.clone(), reconnect.clone()),
198            SinkMode::Server { .. } => {
199                return Err(ConnectorError::ConfigurationError(
200                    "WebSocketSinkClient is for client mode; use WebSocketSinkServer for server mode".into(),
201                ));
202            }
203        };
204
205        info!(url = %url, "opening WebSocket sink client");
206
207        let (stream, _response) = tokio_tungstenite::connect_async(&url).await.map_err(|e| {
208            ConnectorError::ConnectionFailed(format!("failed to connect to {url}: {e}"))
209        })?;
210
211        let (sink, _read) = stream.split();
212        self.ws_sink = Some(sink);
213        self.conn_mgr = Some(ConnectionManager::new(vec![url.clone()], reconnect));
214        self.state = ConnectorState::Running;
215        self.metrics.record_connect();
216
217        info!(url = %url, "WebSocket sink client connected");
218        Ok(())
219    }
220
221    #[allow(clippy::cast_possible_truncation)]
222    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
223        if self.state != ConnectorState::Running {
224            return Err(ConnectorError::InvalidState {
225                expected: "Running".into(),
226                actual: self.state.to_string(),
227            });
228        }
229
230        let rows = self.serializer.serialize_rows(batch)?;
231        let mut bytes_written: u64 = 0;
232        let mut records_written: usize = 0;
233
234        for (i, row) in rows.iter().enumerate() {
235            if let Some(ref mut sink) = self.ws_sink {
236                match sink
237                    .send(tungstenite::Message::Text(row.clone().into()))
238                    .await
239                {
240                    Ok(()) => {
241                        bytes_written += row.len() as u64;
242                        records_written += 1;
243                        self.metrics.record_send(row.len() as u64);
244                    }
245                    Err(e) => {
246                        warn!(error = %e, "send failed, buffering and attempting reconnect");
247                        self.ws_sink = None;
248                        self.buffer_message(row.clone());
249
250                        // Try to reconnect.
251                        if self.try_reconnect().await.is_err() {
252                            // Buffer rows after the failed one.
253                            for remaining in &rows[i + 1..] {
254                                self.buffer_message(remaining.clone());
255                            }
256                            return Ok(WriteResult::new(records_written, bytes_written));
257                        }
258                    }
259                }
260            } else {
261                self.buffer_message(row.clone());
262            }
263        }
264
265        debug!(
266            records = records_written,
267            bytes = bytes_written,
268            "wrote batch to WebSocket"
269        );
270
271        Ok(WriteResult::new(records_written, bytes_written))
272    }
273
274    fn schema(&self) -> SchemaRef {
275        self.schema.clone()
276    }
277
278    fn capabilities(&self) -> SinkConnectorCapabilities {
279        SinkConnectorCapabilities::new(Duration::from_secs(10))
280    }
281
282    async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
283        self.current_epoch = epoch;
284        Ok(())
285    }
286
287    async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
288        // Flush the WebSocket.
289        if let Some(ref mut sink) = self.ws_sink {
290            sink.flush()
291                .await
292                .map_err(|e| ConnectorError::WriteError(format!("flush failed: {e}")))?;
293        }
294        Ok(())
295    }
296
297    async fn close(&mut self) -> Result<(), ConnectorError> {
298        info!("closing WebSocket sink client");
299
300        if let Some(ref mut sink) = self.ws_sink {
301            let _ = sink.send(tungstenite::Message::Close(None)).await;
302        }
303
304        self.ws_sink = None;
305        self.state = ConnectorState::Closed;
306        info!("WebSocket sink client closed");
307        Ok(())
308    }
309}
310
311impl std::fmt::Debug for WebSocketSinkClient {
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        f.debug_struct("WebSocketSinkClient")
314            .field("state", &self.state)
315            .field("connected", &self.ws_sink.is_some())
316            .field("buffered_messages", &self.disconnect_buffer.len())
317            .field("current_epoch", &self.current_epoch)
318            .finish_non_exhaustive()
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::super::source_config::ReconnectConfig;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Two-column schema (`id: Int64`, `value: Utf8`) shared by all tests.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// JSON client-mode config pointing at localhost with the given
    /// disconnect-buffer limit.
    fn client_config(buffer_on_disconnect: Option<usize>) -> WebSocketSinkConfig {
        WebSocketSinkConfig {
            mode: SinkMode::Client {
                url: "ws://localhost:9090".into(),
                reconnect: ReconnectConfig::default(),
                buffer_on_disconnect,
                batch_interval: None,
                batch_max_size: None,
            },
            format: super::super::sink_config::SinkFormat::Json,
            auth: None,
        }
    }

    /// Default test config: 1 MiB disconnect buffer.
    fn test_config() -> WebSocketSinkConfig {
        client_config(Some(1_048_576))
    }

    #[test]
    fn test_new() {
        let client = WebSocketSinkClient::new(test_schema(), test_config(), None);
        assert_eq!(client.state(), ConnectorState::Created);
        assert!(client.ws_sink.is_none());
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let client = WebSocketSinkClient::new(expected.clone(), test_config(), None);
        assert_eq!(client.schema(), expected);
    }

    #[test]
    fn test_buffer_message() {
        let mut client = WebSocketSinkClient::new(test_schema(), test_config(), None);
        for text in ["hello", "world"] {
            client.buffer_message(text.into());
        }
        assert_eq!(client.disconnect_buffer.len(), 2);
        assert_eq!(client.buffered_bytes, 10);
    }

    #[test]
    fn test_buffer_eviction() {
        // Room for exactly two 5-byte messages.
        let mut client = WebSocketSinkClient::new(test_schema(), client_config(Some(10)), None);

        client.buffer_message("12345".into()); // 5 bytes
        client.buffer_message("67890".into()); // 5 bytes, total 10
        client.buffer_message("abcde".into()); // evicts "12345"

        let queued: Vec<&str> = client
            .disconnect_buffer
            .iter()
            .map(String::as_str)
            .collect();
        assert_eq!(queued.len(), 2);
        assert_eq!(queued, ["67890", "abcde"]);
    }

    #[test]
    fn test_buffer_disabled() {
        // No limit configured means buffering is off.
        let mut client = WebSocketSinkClient::new(test_schema(), client_config(None), None);
        client.buffer_message("hello".into());
        assert!(client.disconnect_buffer.is_empty());
    }

    #[test]
    fn test_capabilities() {
        let client = WebSocketSinkClient::new(test_schema(), test_config(), None);
        let caps = client.capabilities();
        assert!(!caps.exactly_once);
    }
}