1use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use bytes::Bytes;
15use futures_util::{SinkExt, StreamExt};
16use tokio::net::TcpListener;
17use tracing::{debug, info, warn};
18
19use crate::config::{ConnectorConfig, ConnectorState};
20use crate::connector::{SinkConnector, SinkConnectorCapabilities, WriteResult};
21use crate::error::ConnectorError;
22use crate::health::HealthStatus;
23use crate::metrics::ConnectorMetrics;
24
25use super::fanout::FanoutManager;
26use super::protocol::{ClientMessage, ServerMessage};
27use super::serializer::BatchSerializer;
28use super::sink_config::{SinkMode, SlowClientPolicy, WebSocketSinkConfig};
29use super::sink_metrics::WebSocketSinkMetrics;
30
/// WebSocket sink that serves serialized Arrow record batches to subscribing
/// clients.
///
/// Batches are serialized once per `write_batch` call and broadcast through a
/// shared `FanoutManager`; a background acceptor task (spawned in `open`)
/// handles incoming connections and per-client delivery.
pub struct WebSocketSinkServer {
    /// Sink configuration (mode, format, auth); may be rebuilt in `open` when
    /// the supplied `ConnectorConfig` carries properties.
    config: WebSocketSinkConfig,
    /// Arrow schema of the batches this sink serializes.
    schema: SchemaRef,
    /// Converts record batches into the configured wire format.
    serializer: BatchSerializer,
    /// Broadcasts serialized messages to all connected clients, applying the
    /// slow-client policy and (optionally) retaining a replay buffer.
    fanout: Arc<FanoutManager>,
    /// Connector lifecycle state.
    state: ConnectorState,
    /// Sink-specific metrics (optionally registered with Prometheus).
    metrics: Arc<WebSocketSinkMetrics>,
    /// Epoch recorded by the most recent `begin_epoch` call.
    current_epoch: u64,
    /// Signals shutdown to the acceptor and per-client tasks; set in `open`,
    /// taken and triggered in `close`.
    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
    /// Join handle for the acceptor task; awaited (with a timeout) in `close`.
    acceptor_handle: Option<tokio::task::JoinHandle<()>>,
    /// Monotonic counter for outgoing broadcast sequence numbers.
    sequence: Arc<AtomicU64>,
}
57
58impl WebSocketSinkServer {
59 #[must_use]
61 pub fn new(
62 schema: SchemaRef,
63 config: WebSocketSinkConfig,
64 registry: Option<&prometheus::Registry>,
65 ) -> Self {
66 let serializer = BatchSerializer::new(config.format.clone());
67
68 let (buffer_capacity, policy, replay_size) = match &config.mode {
69 SinkMode::Server {
70 per_client_buffer,
71 slow_client_policy,
72 replay_buffer_size,
73 ..
74 } => {
75 let msg_capacity = (*per_client_buffer / 256).max(1);
77 (
78 msg_capacity,
79 slow_client_policy.clone(),
80 *replay_buffer_size,
81 )
82 }
83 SinkMode::Client { .. } => (1024, SlowClientPolicy::DropOldest, None),
84 };
85
86 let fanout = Arc::new(FanoutManager::new(policy, buffer_capacity, replay_size));
87
88 Self {
89 config,
90 schema,
91 serializer,
92 fanout,
93 state: ConnectorState::Created,
94 metrics: Arc::new(WebSocketSinkMetrics::new(registry)),
95 current_epoch: 0,
96 shutdown_tx: None,
97 acceptor_handle: None,
98 sequence: Arc::new(AtomicU64::new(0)),
99 }
100 }
101
102 #[must_use]
104 pub fn state(&self) -> ConnectorState {
105 self.state
106 }
107
108 #[must_use]
110 pub fn connected_clients(&self) -> usize {
111 self.fanout.client_count()
112 }
113
114 #[must_use]
116 pub fn fanout(&self) -> &Arc<FanoutManager> {
117 &self.fanout
118 }
119}
120
#[async_trait]
#[allow(clippy::too_many_lines)]
impl SinkConnector for WebSocketSinkServer {
    /// Binds the configured TCP address and spawns the background acceptor
    /// task that performs WebSocket handshakes and serves connected clients.
    ///
    /// If `config` carries any properties, the sink's own configuration is
    /// rebuilt from it first. Returns `ConfigurationError` when the sink is
    /// configured in client mode, or `ConnectionFailed` if the bind fails.
    async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
        self.state = ConnectorState::Initializing;

        // A non-empty property set overrides the config given at construction.
        if !config.properties().is_empty() {
            self.config = WebSocketSinkConfig::from_config(config)?;
        }

        // Extract server-mode parameters; this connector only runs as a server.
        let (bind_address, max_connections, _path, ping_interval, ping_timeout) = match &self
            .config
            .mode
        {
            SinkMode::Server {
                bind_address,
                max_connections,
                path,
                ping_interval,
                ping_timeout,
                ..
            } => (
                bind_address.clone(),
                *max_connections,
                path.clone(),
                *ping_interval,
                *ping_timeout,
            ),
            SinkMode::Client { .. } => {
                return Err(ConnectorError::ConfigurationError(
                    "WebSocketSinkServer is for server mode; use WebSocketSinkClient for client mode".into(),
                ));
            }
        };

        info!(
            bind = %bind_address,
            max_connections,
            format = ?self.config.format,
            "opening WebSocket sink server"
        );

        let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
            ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
        })?;

        // Watch channel fans one shutdown signal out to the acceptor and to
        // every per-client task (each holds its own clone of the receiver).
        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let fanout = Arc::clone(&self.fanout);
        let metrics = Arc::clone(&self.metrics);

        // Acceptor task: loops until shutdown, spawning one task per client.
        let handle = tokio::spawn(async move {
            let mut shutdown_rx = shutdown_rx;

            loop {
                tokio::select! {
                    accept_result = listener.accept() => {
                        match accept_result {
                            Ok((stream, addr)) => {
                                // Enforce the connection cap before the (more
                                // expensive) WebSocket handshake.
                                if fanout.client_count() >= max_connections {
                                    warn!(addr = %addr, "rejecting: max_connections exceeded");
                                    drop(stream);
                                    continue;
                                }

                                // Low-latency streaming: disable Nagle (best effort).
                                let _ = stream.set_nodelay(true);
                                let fanout = Arc::clone(&fanout);
                                let metrics = metrics.clone();
                                let mut client_shutdown = shutdown_rx.clone();
                                let client_ping_interval = ping_interval;
                                let client_ping_timeout = ping_timeout;

                                // Per-client task: handshake, optional subscribe
                                // message, optional replay, then the main serve loop.
                                tokio::spawn(async move {
                                    // Cap inbound message/frame sizes at 1 MiB; clients
                                    // are only expected to send small control messages.
                                    let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
                                    ws_config.max_message_size = Some(1024 * 1024);
                                    ws_config.max_frame_size = Some(1024 * 1024);
                                    let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
                                        Ok(ws) => ws,
                                        Err(e) => {
                                            warn!(addr = %addr, error = %e, "handshake failed");
                                            return;
                                        }
                                    };

                                    let (mut write, mut read) = ws_stream.split();

                                    // Subscription id derived from the client's source
                                    // port; unique per live connection from one host.
                                    let sub_id = format!("sub_{}", addr.port());
                                    let filter = None;

                                    // Give the client 5s to send an optional Subscribe
                                    // message carrying a filter and/or a replay cursor.
                                    // A timeout or a non-Subscribe first message falls
                                    // back to an unfiltered, live-only subscription.
                                    let (filter, last_seq) = match tokio::time::timeout(
                                        std::time::Duration::from_secs(5),
                                        read.next(),
                                    )
                                    .await
                                    {
                                        Ok(Some(Ok(tungstenite::Message::Text(text)))) => {
                                            match serde_json::from_str::<ClientMessage>(text.as_ref()) {
                                                Ok(ClientMessage::Subscribe {
                                                    filter,
                                                    last_sequence,
                                                    ..
                                                }) => (filter, last_sequence),
                                                _ => (None, None),
                                            }
                                        }
                                        Ok(Some(Err(e))) => {
                                            warn!(addr = %addr, error = %e, "client read error during subscribe, rejecting");
                                            return;
                                        }
                                        _ => (filter, None),
                                    };

                                    // Register with the fan-out before sending anything,
                                    // so broadcasts published after this point are seen.
                                    let (client_id, rx) =
                                        fanout.add_client(sub_id.clone(), filter, None);

                                    metrics.record_connect();

                                    // Confirm the subscription (best effort; send errors
                                    // are ignored here and surface in the main loop).
                                    let confirm = ServerMessage::Subscribed {
                                        subscription_id: sub_id.clone(),
                                    };
                                    if let Ok(json) = serde_json::to_string(&confirm) {
                                        let _ = write
                                            .send(tungstenite::Message::Text(json.into()))
                                            .await;
                                    }

                                    // Replay buffered messages after the client's cursor,
                                    // if it asked for them.
                                    if let Some(seq) = last_seq {
                                        let replay_msgs = fanout.replay_from(seq);
                                        metrics.record_replay();
                                        for (_seq, data) in replay_msgs {
                                            if write
                                                .send(tungstenite::Message::Text(
                                                    String::from_utf8_lossy(&data).into_owned().into(),
                                                ))
                                                .await
                                                .is_err()
                                            {
                                                break;
                                            }
                                        }
                                    }

                                    // Keepalive state: ping on an interval, disconnect if
                                    // a pong is still outstanding past the timeout. A
                                    // tokio interval's first tick completes immediately,
                                    // so consume it up front to avoid an instant ping.
                                    let mut ping_ticker = tokio::time::interval(client_ping_interval);
                                    ping_ticker.tick().await;
                                    let mut awaiting_pong = false;
                                    let mut last_ping_sent = tokio::time::Instant::now();

                                    // Main loop: forward broadcasts, react to client
                                    // control frames, run keepalive, honor shutdown.
                                    loop {
                                        tokio::select! {
                                            Some(data) = rx.recv() => {
                                                let text = String::from_utf8_lossy(&data).into_owned();
                                                if write.send(tungstenite::Message::Text(text.into())).await.is_err() {
                                                    break;
                                                }
                                                metrics.record_send(data.len() as u64);
                                            }
                                            msg = read.next() => {
                                                match msg {
                                                    Some(Ok(tungstenite::Message::Close(_))) | None => break,
                                                    Some(Ok(tungstenite::Message::Pong(_))) => {
                                                        awaiting_pong = false;
                                                    }
                                                    Some(Ok(tungstenite::Message::Text(text))) => {
                                                        // Only Unsubscribe is honored mid-stream;
                                                        // other text messages are ignored.
                                                        if let Ok(ClientMessage::Unsubscribe { .. }) =
                                                            serde_json::from_str::<ClientMessage>(text.as_ref())
                                                        {
                                                            break;
                                                        }
                                                    }
                                                    Some(Err(e)) => {
                                                        warn!(addr = %addr, error = %e, "client read error, disconnecting");
                                                        break;
                                                    }
                                                    _ => {}
                                                }
                                            }
                                            _ = ping_ticker.tick() => {
                                                if awaiting_pong && last_ping_sent.elapsed() > client_ping_timeout {
                                                    debug!(addr = %addr, "ping timeout — disconnecting");
                                                    metrics.record_ping_timeout();
                                                    break;
                                                }
                                                if write.send(tungstenite::Message::Ping(bytes::Bytes::new())).await.is_err() {
                                                    break;
                                                }
                                                awaiting_pong = true;
                                                last_ping_sent = tokio::time::Instant::now();
                                            }
                                            _ = client_shutdown.changed() => break,
                                        }
                                    }

                                    // Always deregister so per-client buffers are freed.
                                    fanout.remove_client(client_id);
                                    metrics.record_disconnect();
                                    debug!(addr = %addr, "sink client disconnected");
                                });
                            }
                            Err(e) => {
                                warn!(error = %e, "accept error");
                            }
                        }
                    }
                    _ = shutdown_rx.changed() => {
                        info!("sink server acceptor shutting down");
                        break;
                    }
                }
            }
        });

        self.shutdown_tx = Some(shutdown_tx);
        self.acceptor_handle = Some(handle);
        self.state = ConnectorState::Running;

        info!(bind = %bind_address, "WebSocket sink server started");
        Ok(())
    }

    /// Serializes `batch` once and broadcasts it to all connected clients.
    ///
    /// Returns a `WriteResult` covering every row of the batch even if some
    /// clients dropped the message. Errors if the connector is not `Running`
    /// or if serialization fails.
    #[allow(clippy::cast_possible_truncation)]
    async fn write_batch(&mut self, batch: &RecordBatch) -> Result<WriteResult, ConnectorError> {
        if self.state != ConnectorState::Running {
            return Err(ConnectorError::InvalidState {
                expected: "Running".into(),
                actual: self.state.to_string(),
            });
        }

        // No clients: skip serialization entirely. Note the sequence counter
        // does not advance for skipped batches.
        if self.fanout.client_count() == 0 {
            return Ok(WriteResult::new(0, 0));
        }

        let json = self.serializer.serialize_to_json(batch)?;
        // 1-based, monotonically increasing message sequence.
        let seq = self.sequence.fetch_add(1, Ordering::Relaxed) + 1;

        let msg = ServerMessage::Data {
            // NOTE(review): subscription_id is left empty on broadcast data —
            // presumably clients correlate via their own subscription; confirm.
            subscription_id: String::new(),
            data: json,
            sequence: seq,
            watermark: None,
        };

        let serialized = serde_json::to_vec(&msg)
            .map_err(|e| ConnectorError::Serde(crate::error::SerdeError::Json(e.to_string())))?;

        let bytes_len = serialized.len() as u64;
        let data = Bytes::from(serialized);

        let result = self.fanout.broadcast(data);

        self.metrics.record_send(bytes_len);
        if result.dropped > 0 {
            for _ in 0..result.dropped {
                self.metrics.record_drop();
            }
        }

        debug!(
            records = batch.num_rows(),
            sent = result.sent,
            dropped = result.dropped,
            sequence = result.sequence,
            "broadcast batch to WebSocket clients"
        );

        Ok(WriteResult::new(batch.num_rows(), bytes_len))
    }

    /// Arrow schema this sink serializes batches against.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Maps the connector lifecycle state onto a health status.
    fn health_check(&self) -> HealthStatus {
        match self.state {
            ConnectorState::Running => HealthStatus::Healthy,
            ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
            ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
            ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
            ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
            ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
        }
    }

    /// Snapshot of this sink's metrics in the generic connector format.
    fn metrics(&self) -> ConnectorMetrics {
        self.metrics.to_connector_metrics()
    }

    /// Capability advertisement. The `Duration` argument's meaning is defined
    /// by `SinkConnectorCapabilities::new` — TODO confirm (likely a flush or
    /// commit interval); not visible from this file.
    fn capabilities(&self) -> SinkConnectorCapabilities {
        SinkConnectorCapabilities::new(Duration::from_secs(10))
    }

    /// Records the active epoch; no other work is done at epoch start.
    async fn begin_epoch(&mut self, epoch: u64) -> Result<(), ConnectorError> {
        self.current_epoch = epoch;
        Ok(())
    }

    /// Delivery is best-effort broadcast, so epoch commit is a no-op.
    async fn commit_epoch(&mut self, _epoch: u64) -> Result<(), ConnectorError> {
        Ok(())
    }

    /// Signals shutdown to the acceptor (and, through the watch channel, to
    /// every client task), waits up to 5s for the acceptor to exit, then
    /// marks the sink `Closed`.
    async fn close(&mut self) -> Result<(), ConnectorError> {
        info!("closing WebSocket sink server");

        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(true);
        }

        if let Some(handle) = self.acceptor_handle.take() {
            let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
        }

        self.state = ConnectorState::Closed;
        info!("WebSocket sink server closed");
        Ok(())
    }
}
447
448impl std::fmt::Debug for WebSocketSinkServer {
449 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450 f.debug_struct("WebSocketSinkServer")
451 .field("state", &self.state)
452 .field("connected_clients", &self.connected_clients())
453 .field("format", &self.config.format)
454 .field("current_epoch", &self.current_epoch)
455 .finish_non_exhaustive()
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::super::sink_config::SinkFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Builds the two-column (`id`, `value`) schema shared by every test.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Utf8, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Server-mode config bound to an ephemeral localhost port, JSON output.
    fn test_config() -> WebSocketSinkConfig {
        let mode = SinkMode::Server {
            bind_address: "127.0.0.1:0".into(),
            path: None,
            max_connections: 100,
            per_client_buffer: 262_144,
            slow_client_policy: SlowClientPolicy::DropOldest,
            ping_interval: std::time::Duration::from_secs(30),
            ping_timeout: std::time::Duration::from_secs(10),
            enable_subscription_filter: false,
            replay_buffer_size: None,
        };
        WebSocketSinkConfig {
            mode,
            format: SinkFormat::Json,
            auth: None,
        }
    }

    #[test]
    fn test_new() {
        let sink = WebSocketSinkServer::new(test_schema(), test_config(), None);
        assert_eq!(ConnectorState::Created, sink.state());
        assert_eq!(0, sink.connected_clients());
    }

    #[test]
    fn test_schema_returned() {
        let schema = test_schema();
        let sink = WebSocketSinkServer::new(Arc::clone(&schema), test_config(), None);
        assert_eq!(schema, sink.schema());
    }

    #[test]
    fn test_capabilities() {
        let sink = WebSocketSinkServer::new(test_schema(), test_config(), None);
        let caps = sink.capabilities();
        // Broadcast delivery is best-effort: neither exactly-once nor upsert.
        assert!(!caps.exactly_once);
        assert!(!caps.upsert);
    }

    #[test]
    fn test_health_created() {
        let sink = WebSocketSinkServer::new(test_schema(), test_config(), None);
        assert_eq!(HealthStatus::Unknown, sink.health_check());
    }

    #[test]
    fn test_metrics_initial() {
        let sink = WebSocketSinkServer::new(test_schema(), test_config(), None);
        assert_eq!(0, sink.metrics().records_total);
    }
}