1use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
10use arrow_array::RecordBatch;
11use arrow_schema::SchemaRef;
12use async_trait::async_trait;
13use bytes::Bytes;
14use futures_util::{SinkExt, StreamExt};
15use tokio::net::TcpListener;
16use tracing::{debug, info, warn};
17
18use crate::config::{ConnectorConfig, ConnectorState};
19use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
20use crate::error::ConnectorError;
21use crate::health::HealthStatus;
22use crate::metrics::ConnectorMetrics;
23
24use super::fanout::FanoutManager;
25use super::protocol::{ClientMessage, ServerMessage};
26use super::serializer::BatchSerializer;
27use super::sink_config::{SinkMode, SlowClientPolicy, WebSocketSinkConfig};
28use super::sink_metrics::WebSocketSinkMetrics;
29
30pub struct WebSocketSinkServer {
35 config: WebSocketSinkConfig,
37 schema: SchemaRef,
39 serializer: BatchSerializer,
41 fanout: Arc<FanoutManager>,
43 state: ConnectorState,
45 metrics: Arc<WebSocketSinkMetrics>,
47 current_epoch: u64,
49 shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
51 acceptor_handle: Option<tokio::task::JoinHandle<()>>,
53 sequence: Arc<AtomicU64>,
55}
56
57impl WebSocketSinkServer {
58 #[must_use]
60 pub fn new(schema: SchemaRef, config: WebSocketSinkConfig) -> Self {
61 let serializer = BatchSerializer::new(config.format.clone());
62
63 let (buffer_capacity, policy, replay_size) = match &config.mode {
64 SinkMode::Server {
65 per_client_buffer,
66 slow_client_policy,
67 replay_buffer_size,
68 ..
69 } => {
70 let msg_capacity = (*per_client_buffer / 256).max(1);
72 (
73 msg_capacity,
74 slow_client_policy.clone(),
75 *replay_buffer_size,
76 )
77 }
78 SinkMode::Client { .. } => (1024, SlowClientPolicy::DropOldest, None),
79 };
80
81 let fanout = Arc::new(FanoutManager::new(policy, buffer_capacity, replay_size));
82
83 Self {
84 config,
85 schema,
86 serializer,
87 fanout,
88 state: ConnectorState::Created,
89 metrics: Arc::new(WebSocketSinkMetrics::new()),
90 current_epoch: 0,
91 shutdown_tx: None,
92 acceptor_handle: None,
93 sequence: Arc::new(AtomicU64::new(0)),
94 }
95 }
96
97 #[must_use]
99 pub fn state(&self) -> ConnectorState {
100 self.state
101 }
102
103 #[must_use]
105 pub fn connected_clients(&self) -> usize {
106 self.fanout.client_count()
107 }
108
109 #[must_use]
111 pub fn fanout(&self) -> &Arc<FanoutManager> {
112 &self.fanout
113 }
114}
115
116#[async_trait]
117#[allow(clippy::too_many_lines)]
118impl SinkConnector for WebSocketSinkServer {
119 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
120 self.state = ConnectorState::Initializing;
121
122 if !config.properties().is_empty() {
124 self.config = WebSocketSinkConfig::from_config(config)?;
125 }
126
127 let (bind_address, max_connections, _path, ping_interval, ping_timeout) = match &self
128 .config
129 .mode
130 {
131 SinkMode::Server {
132 bind_address,
133 max_connections,
134 path,
135 ping_interval,
136 ping_timeout,
137 ..
138 } => (
139 bind_address.clone(),
140 *max_connections,
141 path.clone(),
142 *ping_interval,
143 *ping_timeout,
144 ),
145 SinkMode::Client { .. } => {
146 return Err(ConnectorError::ConfigurationError(
147 "WebSocketSinkServer is for server mode; use WebSocketSinkClient for client mode".into(),
148 ));
149 }
150 };
151
152 info!(
153 bind = %bind_address,
154 max_connections,
155 format = ?self.config.format,
156 "opening WebSocket sink server"
157 );
158
159 let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
160 ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
161 })?;
162
163 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
164 let fanout = Arc::clone(&self.fanout);
165 let metrics = Arc::clone(&self.metrics);
166
167 let handle = tokio::spawn(async move {
168 let mut shutdown_rx = shutdown_rx;
169
170 loop {
171 tokio::select! {
172 accept_result = listener.accept() => {
173 match accept_result {
174 Ok((stream, addr)) => {
175 if fanout.client_count() >= max_connections {
176 warn!(addr = %addr, "rejecting: max_connections exceeded");
177 drop(stream);
178 continue;
179 }
180
181 let _ = stream.set_nodelay(true);
182 let fanout = Arc::clone(&fanout);
183 let metrics = metrics.clone();
184 let mut client_shutdown = shutdown_rx.clone();
185 let client_ping_interval = ping_interval;
186 let client_ping_timeout = ping_timeout;
187
188 tokio::spawn(async move {
189 let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
192 ws_config.max_message_size = Some(1024 * 1024);
193 ws_config.max_frame_size = Some(1024 * 1024);
194 let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
195 Ok(ws) => ws,
196 Err(e) => {
197 warn!(addr = %addr, error = %e, "handshake failed");
198 return;
199 }
200 };
201
202 let (mut write, mut read) = ws_stream.split();
203
204 let sub_id = format!("sub_{}", addr.port());
206 let filter = None;
207
208 let (filter, last_seq) = match tokio::time::timeout(
210 std::time::Duration::from_secs(5),
211 read.next(),
212 )
213 .await
214 {
215 Ok(Some(Ok(tungstenite::Message::Text(text)))) => {
216 match serde_json::from_str::<ClientMessage>(text.as_ref()) {
217 Ok(ClientMessage::Subscribe {
218 filter,
219 last_sequence,
220 ..
221 }) => (filter, last_sequence),
222 _ => (None, None),
223 }
224 }
225 Ok(Some(Err(e))) => {
226 warn!(addr = %addr, error = %e, "client read error during subscribe, rejecting");
227 return;
228 }
229 _ => (filter, None),
230 };
231
232 let (client_id, rx) =
234 fanout.add_client(sub_id.clone(), filter, None);
235
236 metrics.record_connect();
237
238 let confirm = ServerMessage::Subscribed {
240 subscription_id: sub_id.clone(),
241 };
242 if let Ok(json) = serde_json::to_string(&confirm) {
243 let _ = write
244 .send(tungstenite::Message::Text(json.into()))
245 .await;
246 }
247
248 if let Some(seq) = last_seq {
250 let replay_msgs = fanout.replay_from(seq);
251 metrics.record_replay();
252 for (_seq, data) in replay_msgs {
253 if write
254 .send(tungstenite::Message::Text(
255 String::from_utf8_lossy(&data).into_owned().into(),
256 ))
257 .await
258 .is_err()
259 {
260 break;
261 }
262 }
263 }
264
265 let mut ping_ticker = tokio::time::interval(client_ping_interval);
267 ping_ticker.tick().await; let mut awaiting_pong = false;
269 let mut last_ping_sent = tokio::time::Instant::now();
270
271 loop {
272 tokio::select! {
273 Some(data) = rx.recv() => {
274 let text = String::from_utf8_lossy(&data).into_owned();
275 if write.send(tungstenite::Message::Text(text.into())).await.is_err() {
276 break;
277 }
278 metrics.record_send(data.len() as u64);
279 }
280 msg = read.next() => {
281 match msg {
282 Some(Ok(tungstenite::Message::Close(_))) | None => break,
283 Some(Ok(tungstenite::Message::Pong(_))) => {
284 awaiting_pong = false;
285 }
286 Some(Ok(tungstenite::Message::Text(text))) => {
287 if let Ok(ClientMessage::Unsubscribe { .. }) =
288 serde_json::from_str::<ClientMessage>(text.as_ref())
289 {
290 break;
291 }
292 }
293 Some(Err(e)) => {
294 warn!(addr = %addr, error = %e, "client read error, disconnecting");
295 break;
296 }
297 _ => {}
298 }
299 }
300 _ = ping_ticker.tick() => {
301 if awaiting_pong && last_ping_sent.elapsed() > client_ping_timeout {
302 debug!(addr = %addr, "ping timeout — disconnecting");
303 metrics.record_ping_timeout();
304 break;
305 }
306 if write.send(tungstenite::Message::Ping(bytes::Bytes::new())).await.is_err() {
307 break;
308 }
309 awaiting_pong = true;
310 last_ping_sent = tokio::time::Instant::now();
311 }
312 _ = client_shutdown.changed() => break,
313 }
314 }
315
316 fanout.remove_client(client_id);
317 metrics.record_disconnect();
318 debug!(addr = %addr, "sink client disconnected");
319 });
320 }
321 Err(e) => {
322 warn!(error = %e, "accept error");
323 }
324 }
325 }
326 _ = shutdown_rx.changed() => {
327 info!("sink server acceptor shutting down");
328 break;
329 }
330 }
331 }
332 });
333
334 self.shutdown_tx = Some(shutdown_tx);
335 self.acceptor_handle = Some(handle);
336 self.state = ConnectorState::Running;
337
338 info!(bind = %bind_address, "WebSocket sink server started");
339 Ok(())
340 }
341
342 #[allow(clippy::cast_possible_truncation)]
343 async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
344 if self.state != ConnectorState::Running {
345 return Err(ConnectorError::InvalidState {
346 expected: "Running".into(),
347 actual: self.state.to_string(),
348 });
349 }
350
351 if self.fanout.client_count() == 0 {
352 return Ok(WriteResult::new(0, 0));
354 }
355
356 let json = self.serializer.serialize_to_json(batch)?;
358 let seq = self.sequence.fetch_add(1, Ordering::Relaxed) + 1;
359
360 let msg = ServerMessage::Data {
361 subscription_id: String::new(), data: json,
363 sequence: seq,
364 watermark: None,
365 };
366
367 let serialized = serde_json::to_vec(&msg)
368 .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;
369
370 let bytes_len = serialized.len() as u64;
371 let data = Bytes::from(serialized);
372
373 let result = self.fanout.broadcast(data);
374
375 self.metrics.record_send(bytes_len);
376 if result.dropped > 0 {
377 for _ in 0..result.dropped {
378 self.metrics.record_drop();
379 }
380 }
381
382 debug!(
383 records = batch.num_rows(),
384 sent = result.sent,
385 dropped = result.dropped,
386 sequence = result.sequence,
387 "broadcast batch to WebSocket clients"
388 );
389
390 Ok(WriteResult::new(batch.num_rows(), bytes_len))
391 }
392
393 fn schema(&self) -> SchemaRef {
394 self.schema.clone()
395 }
396
397 fn health_check(&self) -> HealthStatus {
398 match self.state {
399 ConnectorState::Running => HealthStatus::Healthy,
400 ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
401 ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
402 ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
403 ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
404 ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
405 }
406 }
407
408 fn metrics(&self) -> ConnectorMetrics {
409 self.metrics.to_connector_metrics()
410 }
411
412 fn capabilities(&self) -> SinkConnectorCapabilities {
413 SinkConnectorCapabilities::default()
414 }
416
417 async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
418 self.current_epoch = epoch;
419 Ok(())
420 }
421
422 async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
423 Ok(())
424 }
425
426 async fn close(&mut self) -> Result<(), ConnectorError> {
427 info!("closing WebSocket sink server");
428
429 if let Some(tx) = self.shutdown_tx.take() {
430 let _ = tx.send(true);
431 }
432
433 if let Some(handle) = self.acceptor_handle.take() {
434 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
435 }
436
437 self.state = ConnectorState::Closed;
438 info!("WebSocket sink server closed");
439 Ok(())
440 }
441}
442
443impl std::fmt::Debug for WebSocketSinkServer {
444 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445 f.debug_struct("WebSocketSinkServer")
446 .field("state", &self.state)
447 .field("connected_clients", &self.connected_clients())
448 .field("format", &self.config.format)
449 .field("current_epoch", &self.current_epoch)
450 .finish_non_exhaustive()
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::super::sink_config::SinkFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Two non-nullable columns: `id: Int64` and `value: Utf8`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Server-mode JSON config on loopback port 0, with no replay and no auth.
    fn test_config() -> WebSocketSinkConfig {
        let mode = SinkMode::Server {
            bind_address: "127.0.0.1:0".into(),
            path: None,
            max_connections: 100,
            per_client_buffer: 262_144,
            slow_client_policy: SlowClientPolicy::DropOldest,
            ping_interval: std::time::Duration::from_secs(30),
            ping_timeout: std::time::Duration::from_secs(10),
            enable_subscription_filter: false,
            replay_buffer_size: None,
        };
        WebSocketSinkConfig {
            mode,
            format: SinkFormat::Json,
            auth: None,
        }
    }

    /// A freshly constructed sink that has never been opened.
    fn unopened_sink() -> WebSocketSinkServer {
        WebSocketSinkServer::new(test_schema(), test_config())
    }

    #[test]
    fn test_new() {
        let server = unopened_sink();
        assert_eq!(server.state(), ConnectorState::Created);
        assert_eq!(server.connected_clients(), 0);
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let server = WebSocketSinkServer::new(Arc::clone(&expected), test_config());
        assert_eq!(server.schema(), expected);
    }

    #[test]
    fn test_capabilities() {
        let caps = unopened_sink().capabilities();
        assert!(!caps.exactly_once);
        assert!(!caps.upsert);
    }

    #[test]
    fn test_health_created() {
        assert_eq!(unopened_sink().health_check(), HealthStatus::Unknown);
    }

    #[test]
    fn test_metrics_initial() {
        let snapshot = unopened_sink().metrics();
        assert_eq!(snapshot.records_total, 0);
    }
}