laminar_connectors/websocket/
sink.rs1use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use bytes::Bytes;
15use futures_util::{SinkExt, StreamExt};
16use tokio::net::TcpListener;
17use tracing::{debug, info, warn};
18
19use crate::config::{ConnectorConfig, ConnectorState};
20use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
21use crate::error::ConnectorError;
22
23use super::fanout::FanoutManager;
24use super::protocol::{ClientMessage, ServerMessage};
25use super::serializer::BatchSerializer;
26use super::sink_config::{SinkMode, SlowClientPolicy, WebSocketSinkConfig};
27use super::sink_metrics::WebSocketSinkMetrics;
28
29pub struct WebSocketSinkServer {
34 config: WebSocketSinkConfig,
36 schema: SchemaRef,
38 serializer: BatchSerializer,
40 fanout: Arc<FanoutManager>,
42 state: ConnectorState,
44 metrics: Arc<WebSocketSinkMetrics>,
46 current_epoch: u64,
48 shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
50 acceptor_handle: Option<tokio::task::JoinHandle<()>>,
52 sequence: Arc<AtomicU64>,
54}
55
56impl WebSocketSinkServer {
57 #[must_use]
59 pub fn new(
60 schema: SchemaRef,
61 config: WebSocketSinkConfig,
62 registry: Option<&prometheus::Registry>,
63 ) -> Self {
64 let serializer = BatchSerializer::new(config.format.clone());
65
66 let (buffer_capacity, policy, replay_size) = match &config.mode {
67 SinkMode::Server {
68 per_client_buffer,
69 slow_client_policy,
70 replay_buffer_size,
71 ..
72 } => {
73 let msg_capacity = (*per_client_buffer / 256).max(1);
75 (
76 msg_capacity,
77 slow_client_policy.clone(),
78 *replay_buffer_size,
79 )
80 }
81 SinkMode::Client { .. } => (1024, SlowClientPolicy::DropOldest, None),
82 };
83
84 let fanout = Arc::new(FanoutManager::new(policy, buffer_capacity, replay_size));
85
86 Self {
87 config,
88 schema,
89 serializer,
90 fanout,
91 state: ConnectorState::Created,
92 metrics: Arc::new(WebSocketSinkMetrics::new(registry)),
93 current_epoch: 0,
94 shutdown_tx: None,
95 acceptor_handle: None,
96 sequence: Arc::new(AtomicU64::new(0)),
97 }
98 }
99
100 #[must_use]
102 pub fn state(&self) -> ConnectorState {
103 self.state
104 }
105
106 #[must_use]
108 pub fn connected_clients(&self) -> usize {
109 self.fanout.client_count()
110 }
111
112 #[must_use]
114 pub fn fanout(&self) -> &Arc<FanoutManager> {
115 &self.fanout
116 }
117}
118
119#[async_trait]
120#[allow(clippy::too_many_lines)]
121impl SinkConnector for WebSocketSinkServer {
122 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
123 self.state = ConnectorState::Initializing;
124
125 if !config.properties().is_empty() {
127 self.config = WebSocketSinkConfig::from_config(config)?;
128 }
129
130 let (bind_address, max_connections, _path, ping_interval, ping_timeout) = match &self
131 .config
132 .mode
133 {
134 SinkMode::Server {
135 bind_address,
136 max_connections,
137 path,
138 ping_interval,
139 ping_timeout,
140 ..
141 } => (
142 bind_address.clone(),
143 *max_connections,
144 path.clone(),
145 *ping_interval,
146 *ping_timeout,
147 ),
148 SinkMode::Client { .. } => {
149 return Err(ConnectorError::ConfigurationError(
150 "WebSocketSinkServer is for server mode; use WebSocketSinkClient for client mode".into(),
151 ));
152 }
153 };
154
155 info!(
156 bind = %bind_address,
157 max_connections,
158 format = ?self.config.format,
159 "opening WebSocket sink server"
160 );
161
162 let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
163 ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
164 })?;
165
166 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
167 let fanout = Arc::clone(&self.fanout);
168 let metrics = Arc::clone(&self.metrics);
169
170 let handle = tokio::spawn(async move {
171 let mut shutdown_rx = shutdown_rx;
172
173 loop {
174 tokio::select! {
175 accept_result = listener.accept() => {
176 match accept_result {
177 Ok((stream, addr)) => {
178 if fanout.client_count() >= max_connections {
179 warn!(addr = %addr, "rejecting: max_connections exceeded");
180 drop(stream);
181 continue;
182 }
183
184 let _ = stream.set_nodelay(true);
185 let fanout = Arc::clone(&fanout);
186 let metrics = metrics.clone();
187 let mut client_shutdown = shutdown_rx.clone();
188 let client_ping_interval = ping_interval;
189 let client_ping_timeout = ping_timeout;
190
191 tokio::spawn(async move {
192 let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
195 ws_config.max_message_size = Some(1024 * 1024);
196 ws_config.max_frame_size = Some(1024 * 1024);
197 let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
198 Ok(ws) => ws,
199 Err(e) => {
200 warn!(addr = %addr, error = %e, "handshake failed");
201 return;
202 }
203 };
204
205 let (mut write, mut read) = ws_stream.split();
206
207 let sub_id = format!("sub_{}", addr.port());
209 let filter = None;
210
211 let (filter, last_seq) = match tokio::time::timeout(
213 std::time::Duration::from_secs(5),
214 read.next(),
215 )
216 .await
217 {
218 Ok(Some(Ok(tungstenite::Message::Text(text)))) => {
219 match serde_json::from_str::<ClientMessage>(text.as_ref()) {
220 Ok(ClientMessage::Subscribe {
221 filter,
222 last_sequence,
223 ..
224 }) => (filter, last_sequence),
225 _ => (None, None),
226 }
227 }
228 Ok(Some(Err(e))) => {
229 warn!(addr = %addr, error = %e, "client read error during subscribe, rejecting");
230 return;
231 }
232 _ => (filter, None),
233 };
234
235 let (client_id, rx) =
237 fanout.add_client(sub_id.clone(), filter, None);
238
239 metrics.record_connect();
240
241 let confirm = ServerMessage::Subscribed {
243 subscription_id: sub_id.clone(),
244 };
245 if let Ok(json) = serde_json::to_string(&confirm) {
246 let _ = write
247 .send(tungstenite::Message::Text(json.into()))
248 .await;
249 }
250
251 if let Some(seq) = last_seq {
253 let replay_msgs = fanout.replay_from(seq);
254 metrics.record_replay();
255 for (_seq, data) in replay_msgs {
256 if write
257 .send(tungstenite::Message::Text(
258 String::from_utf8_lossy(&data).into_owned().into(),
259 ))
260 .await
261 .is_err()
262 {
263 break;
264 }
265 }
266 }
267
268 let mut ping_ticker = tokio::time::interval(client_ping_interval);
270 ping_ticker.tick().await; let mut awaiting_pong = false;
272 let mut last_ping_sent = tokio::time::Instant::now();
273
274 loop {
275 tokio::select! {
276 Some(data) = rx.recv() => {
277 let text = String::from_utf8_lossy(&data).into_owned();
278 if write.send(tungstenite::Message::Text(text.into())).await.is_err() {
279 break;
280 }
281 metrics.record_send(data.len() as u64);
282 }
283 msg = read.next() => {
284 match msg {
285 Some(Ok(tungstenite::Message::Close(_))) | None => break,
286 Some(Ok(tungstenite::Message::Pong(_))) => {
287 awaiting_pong = false;
288 }
289 Some(Ok(tungstenite::Message::Text(text))) => {
290 if let Ok(ClientMessage::Unsubscribe { .. }) =
291 serde_json::from_str::<ClientMessage>(text.as_ref())
292 {
293 break;
294 }
295 }
296 Some(Err(e)) => {
297 warn!(addr = %addr, error = %e, "client read error, disconnecting");
298 break;
299 }
300 _ => {}
301 }
302 }
303 _ = ping_ticker.tick() => {
304 if awaiting_pong && last_ping_sent.elapsed() > client_ping_timeout {
305 debug!(addr = %addr, "ping timeout — disconnecting");
306 metrics.record_ping_timeout();
307 break;
308 }
309 if write.send(tungstenite::Message::Ping(bytes::Bytes::new())).await.is_err() {
310 break;
311 }
312 awaiting_pong = true;
313 last_ping_sent = tokio::time::Instant::now();
314 }
315 _ = client_shutdown.changed() => break,
316 }
317 }
318
319 fanout.remove_client(client_id);
320 metrics.record_disconnect();
321 debug!(addr = %addr, "sink client disconnected");
322 });
323 }
324 Err(e) => {
325 warn!(error = %e, "accept error");
326 }
327 }
328 }
329 _ = shutdown_rx.changed() => {
330 info!("sink server acceptor shutting down");
331 break;
332 }
333 }
334 }
335 });
336
337 self.shutdown_tx = Some(shutdown_tx);
338 self.acceptor_handle = Some(handle);
339 self.state = ConnectorState::Running;
340
341 info!(bind = %bind_address, "WebSocket sink server started");
342 Ok(())
343 }
344
345 #[allow(clippy::cast_possible_truncation)]
346 async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
347 if self.state != ConnectorState::Running {
348 return Err(ConnectorError::InvalidState {
349 expected: "Running".into(),
350 actual: self.state.to_string(),
351 });
352 }
353
354 if self.fanout.client_count() == 0 {
355 return Ok(WriteResult::new(0, 0));
357 }
358
359 let json = self.serializer.serialize_to_json(batch)?;
361 let seq = self.sequence.fetch_add(1, Ordering::Relaxed) + 1;
362
363 let msg = ServerMessage::Data {
364 subscription_id: String::new(), data: json,
366 sequence: seq,
367 watermark: None,
368 };
369
370 let serialized = serde_json::to_vec(&msg)
371 .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;
372
373 let bytes_len = serialized.len() as u64;
374 let data = Bytes::from(serialized);
375
376 let result = self.fanout.broadcast(data);
377
378 self.metrics.record_send(bytes_len);
379 if result.dropped > 0 {
380 for _ in 0..result.dropped {
381 self.metrics.record_drop();
382 }
383 }
384
385 debug!(
386 records = batch.num_rows(),
387 sent = result.sent,
388 dropped = result.dropped,
389 sequence = result.sequence,
390 "broadcast batch to WebSocket clients"
391 );
392
393 Ok(WriteResult::new(batch.num_rows(), bytes_len))
394 }
395
396 fn schema(&self) -> SchemaRef {
397 self.schema.clone()
398 }
399
400 fn capabilities(&self) -> SinkConnectorCapabilities {
401 SinkConnectorCapabilities::new(Duration::from_secs(10))
403 }
404
405 async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
406 self.current_epoch = epoch;
407 Ok(())
408 }
409
410 async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
411 Ok(())
412 }
413
414 async fn close(&mut self) -> Result<(), ConnectorError> {
415 info!("closing WebSocket sink server");
416
417 if let Some(tx) = self.shutdown_tx.take() {
418 let _ = tx.send(true);
419 }
420
421 if let Some(handle) = self.acceptor_handle.take() {
422 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
423 }
424
425 self.state = ConnectorState::Closed;
426 info!("WebSocket sink server closed");
427 Ok(())
428 }
429}
430
431impl std::fmt::Debug for WebSocketSinkServer {
432 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433 f.debug_struct("WebSocketSinkServer")
434 .field("state", &self.state)
435 .field("connected_clients", &self.connected_clients())
436 .field("format", &self.config.format)
437 .field("current_epoch", &self.current_epoch)
438 .finish_non_exhaustive()
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::super::sink_config::SinkFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Two-column, non-nullable schema (`id: Int64`, `value: Utf8`).
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Server-mode JSON configuration bound to port 0 on the loopback
    /// interface.
    fn test_config() -> WebSocketSinkConfig {
        let mode = SinkMode::Server {
            bind_address: "127.0.0.1:0".into(),
            path: None,
            max_connections: 100,
            per_client_buffer: 262_144,
            slow_client_policy: SlowClientPolicy::DropOldest,
            ping_interval: Duration::from_secs(30),
            ping_timeout: Duration::from_secs(10),
            enable_subscription_filter: false,
            replay_buffer_size: None,
        };

        WebSocketSinkConfig {
            mode,
            format: SinkFormat::Json,
            auth: None,
        }
    }

    /// A freshly built sink is in `Created` and has no clients.
    #[test]
    fn test_new() {
        let server = WebSocketSinkServer::new(test_schema(), test_config(), None);

        assert_eq!(server.state(), ConnectorState::Created);
        assert_eq!(server.connected_clients(), 0);
    }

    /// `schema()` returns the schema passed to the constructor.
    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let server = WebSocketSinkServer::new(expected.clone(), test_config(), None);

        assert_eq!(server.schema(), expected);
    }

    /// The sink reports neither exactly-once nor upsert support.
    #[test]
    fn test_capabilities() {
        let capabilities =
            WebSocketSinkServer::new(test_schema(), test_config(), None).capabilities();

        assert!(!capabilities.exactly_once);
        assert!(!capabilities.upsert);
    }
}