laminar_connectors/websocket/
sink_client.rs1use std::collections::VecDeque;
7use std::time::Duration;
8
9use arrow_array::RecordBatch;
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use futures_util::{SinkExt, StreamExt};
13use tracing::{debug, info, warn};
14
15use crate::config::{ConnectorConfig, ConnectorState};
16use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
17use crate::error::ConnectorError;
18
19use super::connection::ConnectionManager;
20use super::serializer::BatchSerializer;
21use super::sink_config::{SinkMode, WebSocketSinkConfig};
22use super::sink_metrics::WebSocketSinkMetrics;
23
24type WsSink = futures_util::stream::SplitSink<
26 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
27 tungstenite::Message,
28>;
29
30pub struct WebSocketSinkClient {
35 config: WebSocketSinkConfig,
37 schema: SchemaRef,
39 serializer: BatchSerializer,
41 conn_mgr: Option<ConnectionManager>,
43 ws_sink: Option<WsSink>,
45 state: ConnectorState,
47 metrics: WebSocketSinkMetrics,
49 current_epoch: u64,
51 disconnect_buffer: VecDeque<String>,
53 max_buffer_bytes: usize,
55 buffered_bytes: usize,
57}
58
59impl WebSocketSinkClient {
60 #[must_use]
62 pub fn new(
63 schema: SchemaRef,
64 config: WebSocketSinkConfig,
65 registry: Option<&prometheus::Registry>,
66 ) -> Self {
67 let serializer = BatchSerializer::new(config.format.clone());
68
69 let max_buffer_bytes = match &config.mode {
70 SinkMode::Client {
71 buffer_on_disconnect,
72 ..
73 } => buffer_on_disconnect.unwrap_or(0),
74 SinkMode::Server { .. } => 0,
75 };
76
77 Self {
78 config,
79 schema,
80 serializer,
81 conn_mgr: None,
82 ws_sink: None,
83 state: ConnectorState::Created,
84 metrics: WebSocketSinkMetrics::new(registry),
85 current_epoch: 0,
86 disconnect_buffer: VecDeque::new(),
87 max_buffer_bytes,
88 buffered_bytes: 0,
89 }
90 }
91
92 #[must_use]
94 pub fn state(&self) -> ConnectorState {
95 self.state
96 }
97
98 async fn try_reconnect(&mut self) -> Result<(), ConnectorError> {
100 let conn_mgr = self
101 .conn_mgr
102 .as_mut()
103 .ok_or_else(|| ConnectorError::InvalidState {
104 expected: "connection manager initialized".into(),
105 actual: "None".into(),
106 })?;
107
108 let url = conn_mgr.current_url().to_string();
109 info!(url = %url, "attempting WebSocket reconnection");
110
111 match tokio_tungstenite::connect_async(&url).await {
112 Ok((stream, _)) => {
113 conn_mgr.reset();
114 let (sink, _read) = stream.split();
115 self.ws_sink = Some(sink);
116 self.metrics.record_connect();
117 info!(url = %url, "WebSocket reconnected");
118
119 self.flush_disconnect_buffer().await?;
121 Ok(())
122 }
123 Err(e) => {
124 self.metrics.record_disconnect();
125 Err(ConnectorError::ConnectionFailed(format!(
126 "reconnection to {url} failed: {e}"
127 )))
128 }
129 }
130 }
131
132 async fn flush_disconnect_buffer(&mut self) -> Result<(), ConnectorError> {
134 if self.disconnect_buffer.is_empty() {
135 return Ok(());
136 }
137
138 let sink = self
139 .ws_sink
140 .as_mut()
141 .ok_or_else(|| ConnectorError::InvalidState {
142 expected: "ws_sink initialized".into(),
143 actual: "None".into(),
144 })?;
145
146 let count = self.disconnect_buffer.len();
147 debug!(buffered_messages = count, "flushing disconnect buffer");
148
149 while let Some(msg) = self.disconnect_buffer.pop_front() {
150 self.buffered_bytes -= msg.len();
151 if let Err(e) = sink.send(tungstenite::Message::Text(msg.into())).await {
152 warn!(error = %e, "failed to flush buffered message");
153 return Err(ConnectorError::WriteError(format!(
154 "buffer flush failed: {e}"
155 )));
156 }
157 }
158
159 Ok(())
160 }
161
162 fn buffer_message(&mut self, msg: String) {
164 if self.max_buffer_bytes == 0 {
165 return; }
167
168 let msg_len = msg.len();
169
170 while self.buffered_bytes + msg_len > self.max_buffer_bytes {
172 if let Some(old) = self.disconnect_buffer.pop_front() {
173 self.buffered_bytes -= old.len();
174 } else {
175 break;
176 }
177 }
178
179 if msg_len <= self.max_buffer_bytes {
180 self.buffered_bytes += msg_len;
181 self.disconnect_buffer.push_back(msg);
182 }
183 }
184}
185
186#[async_trait]
187impl SinkConnector for WebSocketSinkClient {
188 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
189 self.state = ConnectorState::Initializing;
190
191 if !config.properties().is_empty() {
193 self.config = WebSocketSinkConfig::from_config(config)?;
194 }
195
196 let (url, reconnect) = match &self.config.mode {
197 SinkMode::Client { url, reconnect, .. } => (url.clone(), reconnect.clone()),
198 SinkMode::Server { .. } => {
199 return Err(ConnectorError::ConfigurationError(
200 "WebSocketSinkClient is for client mode; use WebSocketSinkServer for server mode".into(),
201 ));
202 }
203 };
204
205 info!(url = %url, "opening WebSocket sink client");
206
207 let (stream, _response) = tokio_tungstenite::connect_async(&url).await.map_err(|e| {
208 ConnectorError::ConnectionFailed(format!("failed to connect to {url}: {e}"))
209 })?;
210
211 let (sink, _read) = stream.split();
212 self.ws_sink = Some(sink);
213 self.conn_mgr = Some(ConnectionManager::new(vec![url.clone()], reconnect));
214 self.state = ConnectorState::Running;
215 self.metrics.record_connect();
216
217 info!(url = %url, "WebSocket sink client connected");
218 Ok(())
219 }
220
221 #[allow(clippy::cast_possible_truncation)]
222 async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
223 if self.state != ConnectorState::Running {
224 return Err(ConnectorError::InvalidState {
225 expected: "Running".into(),
226 actual: self.state.to_string(),
227 });
228 }
229
230 let rows = self.serializer.serialize_rows(batch)?;
231 let mut bytes_written: u64 = 0;
232 let mut records_written: usize = 0;
233
234 for (i, row) in rows.iter().enumerate() {
235 if let Some(ref mut sink) = self.ws_sink {
236 match sink
237 .send(tungstenite::Message::Text(row.clone().into()))
238 .await
239 {
240 Ok(()) => {
241 bytes_written += row.len() as u64;
242 records_written += 1;
243 self.metrics.record_send(row.len() as u64);
244 }
245 Err(e) => {
246 warn!(error = %e, "send failed, buffering and attempting reconnect");
247 self.ws_sink = None;
248 self.buffer_message(row.clone());
249
250 if self.try_reconnect().await.is_err() {
252 for remaining in &rows[i + 1..] {
254 self.buffer_message(remaining.clone());
255 }
256 return Ok(WriteResult::new(records_written, bytes_written));
257 }
258 }
259 }
260 } else {
261 self.buffer_message(row.clone());
262 }
263 }
264
265 debug!(
266 records = records_written,
267 bytes = bytes_written,
268 "wrote batch to WebSocket"
269 );
270
271 Ok(WriteResult::new(records_written, bytes_written))
272 }
273
274 fn schema(&self) -> SchemaRef {
275 self.schema.clone()
276 }
277
278 fn capabilities(&self) -> SinkConnectorCapabilities {
279 SinkConnectorCapabilities::new(Duration::from_secs(10))
280 }
281
282 async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
283 self.current_epoch = epoch;
284 Ok(())
285 }
286
287 async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
288 if let Some(ref mut sink) = self.ws_sink {
290 sink.flush()
291 .await
292 .map_err(|e| ConnectorError::WriteError(format!("flush failed: {e}")))?;
293 }
294 Ok(())
295 }
296
297 async fn close(&mut self) -> Result<(), ConnectorError> {
298 info!("closing WebSocket sink client");
299
300 if let Some(ref mut sink) = self.ws_sink {
301 let _ = sink.send(tungstenite::Message::Close(None)).await;
302 }
303
304 self.ws_sink = None;
305 self.state = ConnectorState::Closed;
306 info!("WebSocket sink client closed");
307 Ok(())
308 }
309}
310
311impl std::fmt::Debug for WebSocketSinkClient {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 f.debug_struct("WebSocketSinkClient")
314 .field("state", &self.state)
315 .field("connected", &self.ws_sink.is_some())
316 .field("buffered_messages", &self.disconnect_buffer.len())
317 .field("current_epoch", &self.current_epoch)
318 .finish_non_exhaustive()
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::super::source_config::ReconnectConfig;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Two-column schema (`id: Int64`, `value: Utf8`) shared by the tests.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Client-mode JSON config whose only varying part is the disconnect-buffer capacity.
    fn config_with_buffer(buffer_on_disconnect: Option<usize>) -> WebSocketSinkConfig {
        WebSocketSinkConfig {
            mode: SinkMode::Client {
                url: "ws://localhost:9090".into(),
                reconnect: ReconnectConfig::default(),
                buffer_on_disconnect,
                batch_interval: None,
                batch_max_size: None,
            },
            format: super::super::sink_config::SinkFormat::Json,
            auth: None,
        }
    }

    /// Default test config with a 1 MiB disconnect buffer.
    fn test_config() -> WebSocketSinkConfig {
        config_with_buffer(Some(1_048_576))
    }

    #[test]
    fn test_new() {
        let sink = WebSocketSinkClient::new(test_schema(), test_config(), None);
        assert_eq!(sink.state(), ConnectorState::Created);
        assert!(sink.ws_sink.is_none());
    }

    #[test]
    fn test_schema_returned() {
        let schema = test_schema();
        let sink = WebSocketSinkClient::new(schema.clone(), test_config(), None);
        assert_eq!(sink.schema(), schema);
    }

    #[test]
    fn test_buffer_message() {
        let mut sink = WebSocketSinkClient::new(test_schema(), test_config(), None);
        for word in ["hello", "world"] {
            sink.buffer_message(word.into());
        }
        assert_eq!(sink.disconnect_buffer.len(), 2);
        assert_eq!(sink.buffered_bytes, 10);
    }

    #[test]
    fn test_buffer_eviction() {
        // Capacity of 10 bytes holds exactly two 5-byte messages.
        let mut sink = WebSocketSinkClient::new(test_schema(), config_with_buffer(Some(10)), None);

        for word in ["12345", "67890", "abcde"] {
            sink.buffer_message(word.into());
        }

        // The third message pushes out the oldest one.
        let kept: Vec<&str> = sink.disconnect_buffer.iter().map(String::as_str).collect();
        assert_eq!(kept, ["67890", "abcde"]);
    }

    #[test]
    fn test_buffer_disabled() {
        // No capacity configured: buffering is off and messages are discarded.
        let mut sink = WebSocketSinkClient::new(test_schema(), config_with_buffer(None), None);
        sink.buffer_message("hello".into());
        assert!(sink.disconnect_buffer.is_empty());
    }

    #[test]
    fn test_capabilities() {
        let sink = WebSocketSinkClient::new(test_schema(), test_config(), None);
        let caps = sink.capabilities();
        assert!(!caps.exactly_once);
    }
}