Skip to main content

laminar_connectors/websocket/
sink_client.rs

1//! WebSocket sink connector — client mode.
2//!
3//! [`WebSocketSinkClient`] pushes streaming query output to an external
4//! WebSocket server by connecting as a client.
5
6use std::collections::VecDeque;
7
8use arrow_array::RecordBatch;
9use arrow_schema::SchemaRef;
10use async_trait::async_trait;
11use futures_util::{SinkExt, StreamExt};
12use tracing::{debug, info, warn};
13
14use crate::config::{ConnectorConfig, ConnectorState};
15use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
16use crate::error::ConnectorError;
17use crate::health::HealthStatus;
18use crate::metrics::ConnectorMetrics;
19
20use super::connection::ConnectionManager;
21use super::serializer::BatchSerializer;
22use super::sink_config::{SinkMode, WebSocketSinkConfig};
23use super::sink_metrics::WebSocketSinkMetrics;
24
25/// Type alias for the split WebSocket sink half.
26type WsSink = futures_util::stream::SplitSink<
27    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
28    tungstenite::Message,
29>;
30
31/// WebSocket sink connector in client mode.
32///
33/// Connects to an external WebSocket server and pushes serialized
34/// `RecordBatch` data as text or binary messages.
35pub struct WebSocketSinkClient {
36    /// Configuration.
37    config: WebSocketSinkConfig,
38    /// Input Arrow schema.
39    schema: SchemaRef,
40    /// Serializer for `RecordBatch` → messages.
41    serializer: BatchSerializer,
42    /// Connection manager for reconnection.
43    conn_mgr: Option<ConnectionManager>,
44    /// WebSocket sink (write half).
45    ws_sink: Option<WsSink>,
46    /// Connector state.
47    state: ConnectorState,
48    /// Metrics.
49    metrics: WebSocketSinkMetrics,
50    /// Current epoch.
51    current_epoch: u64,
52    /// Buffer for messages while disconnected.
53    disconnect_buffer: VecDeque<String>,
54    /// Max buffer size in bytes when disconnected.
55    max_buffer_bytes: usize,
56    /// Current buffered bytes.
57    buffered_bytes: usize,
58}
59
60impl WebSocketSinkClient {
61    /// Creates a new WebSocket sink client connector.
62    #[must_use]
63    pub fn new(schema: SchemaRef, config: WebSocketSinkConfig) -> Self {
64        let serializer = BatchSerializer::new(config.format.clone());
65
66        let max_buffer_bytes = match &config.mode {
67            SinkMode::Client {
68                buffer_on_disconnect,
69                ..
70            } => buffer_on_disconnect.unwrap_or(0),
71            SinkMode::Server { .. } => 0,
72        };
73
74        Self {
75            config,
76            schema,
77            serializer,
78            conn_mgr: None,
79            ws_sink: None,
80            state: ConnectorState::Created,
81            metrics: WebSocketSinkMetrics::new(),
82            current_epoch: 0,
83            disconnect_buffer: VecDeque::new(),
84            max_buffer_bytes,
85            buffered_bytes: 0,
86        }
87    }
88
89    /// Returns the current connector state.
90    #[must_use]
91    pub fn state(&self) -> ConnectorState {
92        self.state
93    }
94
95    /// Attempts to reconnect and flush the disconnect buffer.
96    async fn try_reconnect(&mut self) -> Result<(), ConnectorError> {
97        let conn_mgr = self
98            .conn_mgr
99            .as_mut()
100            .ok_or_else(|| ConnectorError::InvalidState {
101                expected: "connection manager initialized".into(),
102                actual: "None".into(),
103            })?;
104
105        let url = conn_mgr.current_url().to_string();
106        info!(url = %url, "attempting WebSocket reconnection");
107
108        match tokio_tungstenite::connect_async(&url).await {
109            Ok((stream, _)) => {
110                conn_mgr.reset();
111                let (sink, _read) = stream.split();
112                self.ws_sink = Some(sink);
113                self.metrics.record_connect();
114                info!(url = %url, "WebSocket reconnected");
115
116                // Flush disconnect buffer.
117                self.flush_disconnect_buffer().await?;
118                Ok(())
119            }
120            Err(e) => {
121                self.metrics.record_disconnect();
122                Err(ConnectorError::ConnectionFailed(format!(
123                    "reconnection to {url} failed: {e}"
124                )))
125            }
126        }
127    }
128
129    /// Flushes buffered messages that accumulated during disconnection.
130    async fn flush_disconnect_buffer(&mut self) -> Result<(), ConnectorError> {
131        if self.disconnect_buffer.is_empty() {
132            return Ok(());
133        }
134
135        let sink = self
136            .ws_sink
137            .as_mut()
138            .ok_or_else(|| ConnectorError::InvalidState {
139                expected: "ws_sink initialized".into(),
140                actual: "None".into(),
141            })?;
142
143        let count = self.disconnect_buffer.len();
144        debug!(buffered_messages = count, "flushing disconnect buffer");
145
146        while let Some(msg) = self.disconnect_buffer.pop_front() {
147            self.buffered_bytes -= msg.len();
148            if let Err(e) = sink.send(tungstenite::Message::Text(msg.into())).await {
149                warn!(error = %e, "failed to flush buffered message");
150                return Err(ConnectorError::WriteError(format!(
151                    "buffer flush failed: {e}"
152                )));
153            }
154        }
155
156        Ok(())
157    }
158
159    /// Buffers a message for later delivery (when disconnected).
160    fn buffer_message(&mut self, msg: String) {
161        if self.max_buffer_bytes == 0 {
162            return; // buffering disabled
163        }
164
165        let msg_len = msg.len();
166
167        // Evict oldest if buffer would exceed limit.
168        while self.buffered_bytes + msg_len > self.max_buffer_bytes {
169            if let Some(old) = self.disconnect_buffer.pop_front() {
170                self.buffered_bytes -= old.len();
171            } else {
172                break;
173            }
174        }
175
176        if msg_len <= self.max_buffer_bytes {
177            self.buffered_bytes += msg_len;
178            self.disconnect_buffer.push_back(msg);
179        }
180    }
181}
182
183#[async_trait]
184impl SinkConnector for WebSocketSinkClient {
185    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
186        self.state = ConnectorState::Initializing;
187
188        // If config has properties, re-parse (supports runtime config via SQL WITH).
189        if !config.properties().is_empty() {
190            self.config = WebSocketSinkConfig::from_config(config)?;
191        }
192
193        let (url, reconnect) = match &self.config.mode {
194            SinkMode::Client { url, reconnect, .. } => (url.clone(), reconnect.clone()),
195            SinkMode::Server { .. } => {
196                return Err(ConnectorError::ConfigurationError(
197                    "WebSocketSinkClient is for client mode; use WebSocketSinkServer for server mode".into(),
198                ));
199            }
200        };
201
202        info!(url = %url, "opening WebSocket sink client");
203
204        let (stream, _response) = tokio_tungstenite::connect_async(&url).await.map_err(|e| {
205            ConnectorError::ConnectionFailed(format!("failed to connect to {url}: {e}"))
206        })?;
207
208        let (sink, _read) = stream.split();
209        self.ws_sink = Some(sink);
210        self.conn_mgr = Some(ConnectionManager::new(vec![url.clone()], reconnect));
211        self.state = ConnectorState::Running;
212        self.metrics.record_connect();
213
214        info!(url = %url, "WebSocket sink client connected");
215        Ok(())
216    }
217
218    #[allow(clippy::cast_possible_truncation)]
219    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
220        if self.state != ConnectorState::Running {
221            return Err(ConnectorError::InvalidState {
222                expected: "Running".into(),
223                actual: self.state.to_string(),
224            });
225        }
226
227        let rows = self.serializer.serialize_rows(batch)?;
228        let mut bytes_written: u64 = 0;
229        let mut records_written: usize = 0;
230
231        for (i, row) in rows.iter().enumerate() {
232            if let Some(ref mut sink) = self.ws_sink {
233                match sink
234                    .send(tungstenite::Message::Text(row.clone().into()))
235                    .await
236                {
237                    Ok(()) => {
238                        bytes_written += row.len() as u64;
239                        records_written += 1;
240                        self.metrics.record_send(row.len() as u64);
241                    }
242                    Err(e) => {
243                        warn!(error = %e, "send failed, buffering and attempting reconnect");
244                        self.ws_sink = None;
245                        self.buffer_message(row.clone());
246
247                        // Try to reconnect.
248                        if self.try_reconnect().await.is_err() {
249                            // Buffer rows after the failed one.
250                            for remaining in &rows[i + 1..] {
251                                self.buffer_message(remaining.clone());
252                            }
253                            return Ok(WriteResult::new(records_written, bytes_written));
254                        }
255                    }
256                }
257            } else {
258                self.buffer_message(row.clone());
259            }
260        }
261
262        debug!(
263            records = records_written,
264            bytes = bytes_written,
265            "wrote batch to WebSocket"
266        );
267
268        Ok(WriteResult::new(records_written, bytes_written))
269    }
270
271    fn schema(&self) -> SchemaRef {
272        self.schema.clone()
273    }
274
275    fn health_check(&self) -> HealthStatus {
276        match self.state {
277            ConnectorState::Running => {
278                if self.ws_sink.is_some() {
279                    HealthStatus::Healthy
280                } else {
281                    HealthStatus::Degraded("disconnected, buffering".into())
282                }
283            }
284            ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
285            ConnectorState::Paused => HealthStatus::Degraded("paused".into()),
286            ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
287            ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
288            ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
289        }
290    }
291
292    fn metrics(&self) -> ConnectorMetrics {
293        self.metrics.to_connector_metrics()
294    }
295
296    fn capabilities(&self) -> SinkConnectorCapabilities {
297        SinkConnectorCapabilities::default()
298    }
299
300    async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
301        self.current_epoch = epoch;
302        Ok(())
303    }
304
305    async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
306        // Flush the WebSocket.
307        if let Some(ref mut sink) = self.ws_sink {
308            sink.flush()
309                .await
310                .map_err(|e| ConnectorError::WriteError(format!("flush failed: {e}")))?;
311        }
312        Ok(())
313    }
314
315    async fn close(&mut self) -> Result<(), ConnectorError> {
316        info!("closing WebSocket sink client");
317
318        if let Some(ref mut sink) = self.ws_sink {
319            let _ = sink.send(tungstenite::Message::Close(None)).await;
320        }
321
322        self.ws_sink = None;
323        self.state = ConnectorState::Closed;
324        info!("WebSocket sink client closed");
325        Ok(())
326    }
327}
328
329impl std::fmt::Debug for WebSocketSinkClient {
330    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        f.debug_struct("WebSocketSinkClient")
332            .field("state", &self.state)
333            .field("connected", &self.ws_sink.is_some())
334            .field("buffered_messages", &self.disconnect_buffer.len())
335            .field("current_epoch", &self.current_epoch)
336            .finish_non_exhaustive()
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::super::sink_config::SinkFormat;
    use super::super::source_config::ReconnectConfig;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Two-column schema (`id: Int64`, `value: Utf8`) shared by the tests.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Client-mode JSON config with the given disconnect-buffer limit.
    fn client_config(buffer_on_disconnect: Option<usize>) -> WebSocketSinkConfig {
        WebSocketSinkConfig {
            mode: SinkMode::Client {
                url: "ws://localhost:9090".into(),
                reconnect: ReconnectConfig::default(),
                buffer_on_disconnect,
                batch_interval: None,
                batch_max_size: None,
            },
            format: SinkFormat::Json,
            auth: None,
        }
    }

    /// Default test config: 1 MB disconnect buffer.
    fn test_config() -> WebSocketSinkConfig {
        client_config(Some(1_048_576))
    }

    /// Sink built from the test schema and the given buffer limit.
    fn sink_with_buffer(limit: Option<usize>) -> WebSocketSinkClient {
        WebSocketSinkClient::new(test_schema(), client_config(limit))
    }

    #[test]
    fn test_new() {
        let sink = WebSocketSinkClient::new(test_schema(), test_config());
        assert_eq!(sink.state(), ConnectorState::Created);
        assert!(sink.ws_sink.is_none());
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let sink = WebSocketSinkClient::new(expected.clone(), test_config());
        assert_eq!(sink.schema(), expected);
    }

    #[test]
    fn test_buffer_message() {
        let mut sink = WebSocketSinkClient::new(test_schema(), test_config());
        for word in ["hello", "world"] {
            sink.buffer_message(word.into());
        }
        assert_eq!(sink.disconnect_buffer.len(), 2);
        assert_eq!(sink.buffered_bytes, 10);
    }

    #[test]
    fn test_buffer_eviction() {
        // 10-byte limit: the third 5-byte message pushes out the first one.
        let mut sink = sink_with_buffer(Some(10));

        for chunk in ["12345", "67890", "abcde"] {
            sink.buffer_message(chunk.into());
        }

        let queued: Vec<&str> = sink.disconnect_buffer.iter().map(String::as_str).collect();
        assert_eq!(queued, ["67890", "abcde"]);
    }

    #[test]
    fn test_buffer_disabled() {
        // No limit configured means buffering is switched off.
        let mut sink = sink_with_buffer(None);
        sink.buffer_message("hello".into());
        assert!(sink.disconnect_buffer.is_empty());
    }

    #[test]
    fn test_capabilities() {
        let caps = WebSocketSinkClient::new(test_schema(), test_config()).capabilities();
        assert!(!caps.exactly_once);
    }

    #[test]
    fn test_health_created() {
        let sink = WebSocketSinkClient::new(test_schema(), test_config());
        assert_eq!(sink.health_check(), HealthStatus::Unknown);
    }
}