1use std::sync::Arc;
13
14use arrow_schema::SchemaRef;
15use async_trait::async_trait;
16use crossfire::{mpsc, AsyncRx, MAsyncTx, TryRecvError, TrySendError};
17use futures_util::{SinkExt, StreamExt};
18use tokio::sync::Notify;
19use tracing::{debug, info, warn};
20
21use crate::checkpoint::SourceCheckpoint;
22use crate::config::{ConnectorConfig, ConnectorState};
23use crate::connector::{SourceBatch, SourceConnector};
24use crate::error::ConnectorError;
25use crate::health::HealthStatus;
26use crate::metrics::ConnectorMetrics;
27
28use crate::schema::json::decoder::JsonDecoderConfig;
29
30use super::backpressure::WsBackpressure;
31use super::checkpoint::WebSocketSourceCheckpoint;
32use super::connection::ConnectionManager;
33use super::metrics::WebSocketSourceMetrics;
34use super::parser::MessageParser;
35use super::source_config::{ReconnectConfig, SourceMode, WebSocketSourceConfig};
36
/// Item handed from the background reader task to `poll_batch` over the bounded channel.
enum WsMessage {
    /// The connection ended; the string is a human-readable reason used for logging.
    Disconnected(String),
    /// Raw bytes of one Text or Binary frame that passed the size check.
    Data(Vec<u8>),
}
44
/// Client-mode WebSocket source connector.
///
/// `open` spawns a background reader task that owns the socket and pushes frames into a
/// bounded channel; `poll_batch` drains that channel without awaiting and parses the
/// payloads into Arrow batches.
pub struct WebSocketSource {
    /// Connector configuration; replaced in `open` when the connector config has properties.
    config: WebSocketSourceConfig,
    /// Output schema; replaced in `open` when a SQL-defined Arrow schema is supplied.
    schema: SchemaRef,
    /// Turns buffered raw payloads into a record batch.
    parser: MessageParser,
    /// Lifecycle state reported by `state()` and `health_check()`.
    state: ConnectorState,
    /// Message / reconnect / parse-error counters.
    metrics: WebSocketSourceMetrics,
    /// Best-effort checkpoint data; the watermark is refreshed on every non-empty poll.
    checkpoint_state: WebSocketSourceCheckpoint,
    /// Receiving end of the reader channel; `None` before `open` and after `close`.
    rx: Option<AsyncRx<mpsc::Array<WsMessage>>>,
    /// Signals the reader task to stop; `None` when no reader is running.
    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
    /// Join handle of the reader task spawned by `open`.
    reader_handle: Option<tokio::task::JoinHandle<()>>,
    /// Payloads drained from `rx` and waiting to be parsed.
    message_buffer: Vec<Vec<u8>>,
    /// Cap on records per poll, combined with the caller's `max_records`.
    max_batch_size: usize,
    /// Notified by the reader after it enqueues data.
    data_ready: Arc<Notify>,
}
78
79impl WebSocketSource {
80 #[must_use]
82 pub fn new(
83 schema: SchemaRef,
84 config: WebSocketSourceConfig,
85 registry: Option<&prometheus::Registry>,
86 ) -> Self {
87 let parser = MessageParser::new(
88 schema.clone(),
89 config.format.clone(),
90 JsonDecoderConfig::default(),
91 );
92
93 Self {
94 config,
95 schema,
96 parser,
97 state: ConnectorState::Created,
98 metrics: WebSocketSourceMetrics::new(registry),
99 checkpoint_state: WebSocketSourceCheckpoint::default(),
100 rx: None,
101 shutdown_tx: None,
102 reader_handle: None,
103 message_buffer: Vec::new(),
104 max_batch_size: 1000,
105 data_ready: Arc::new(Notify::new()),
106 }
107 }
108
109 #[must_use]
111 pub fn state(&self) -> ConnectorState {
112 self.state
113 }
114
115 #[allow(
118 clippy::too_many_arguments,
119 clippy::too_many_lines,
120 clippy::unused_self
121 )]
122 fn spawn_reader(
123 &self,
124 urls: Vec<String>,
125 subscribe_message: Option<String>,
126 reconnect: ReconnectConfig,
127 max_message_size: usize,
128 on_backpressure: WsBackpressure,
129 tx: MAsyncTx<mpsc::Array<WsMessage>>,
130 mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
131 data_ready: Arc<Notify>,
132 ) -> tokio::task::JoinHandle<()> {
133 tokio::spawn(async move {
134 let mut conn_mgr = ConnectionManager::new(urls, reconnect);
135
136 'outer: loop {
137 if *shutdown_rx.borrow() {
139 break;
140 }
141
142 let url = conn_mgr.current_url().to_string();
143 info!(url = %url, "connecting to WebSocket server");
144
145 let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
147 ws_config.max_message_size = Some(max_message_size);
148 ws_config.max_frame_size = Some(max_message_size);
149 let ws_stream = match tokio_tungstenite::connect_async_with_config(
150 &url,
151 Some(ws_config),
152 true, )
154 .await
155 {
156 Ok((stream, _response)) => {
157 conn_mgr.reset();
158 info!(url = %url, "WebSocket connection established");
159 stream
160 }
161 Err(e) => {
162 warn!(url = %url, error = %e, "WebSocket connection failed");
163 if let Some(delay) = conn_mgr.next_backoff() {
164 tokio::select! {
165 () = tokio::time::sleep(delay) => continue,
166 _ = shutdown_rx.changed() => break,
167 }
168 } else {
169 let _ = tx
170 .send(WsMessage::Disconnected(format!(
171 "connection failed, no more retries: {e}"
172 )))
173 .await;
174 break;
175 }
176 }
177 };
178
179 let (mut write, mut read) = ws_stream.split();
180
181 if let Some(ref msg) = subscribe_message {
183 if let Err(e) = write
184 .send(tungstenite::Message::Text(msg.clone().into()))
185 .await
186 {
187 warn!(error = %e, "failed to send subscription message");
188 if let Some(delay) = conn_mgr.next_backoff() {
189 tokio::select! {
190 () = tokio::time::sleep(delay) => continue,
191 _ = shutdown_rx.changed() => break,
192 }
193 }
194 continue;
195 }
196 debug!("subscription message sent");
197 }
198
199 loop {
201 tokio::select! {
202 msg = read.next() => {
203 match msg {
204 Some(Ok(tungstenite::Message::Text(text))) => {
205 let payload = text.as_bytes().to_vec();
206 if payload.len() > max_message_size {
207 warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
208 continue;
209 }
210 if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
211 break 'outer;
212 }
213 }
214 Some(Ok(tungstenite::Message::Binary(data))) => {
215 let payload = data.to_vec();
216 if payload.len() > max_message_size {
217 warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
218 continue;
219 }
220 if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
221 break 'outer;
222 }
223 }
224 Some(Ok(tungstenite::Message::Ping(data))) => {
225 let _ = write.send(tungstenite::Message::Pong(data)).await;
226 }
227 Some(Ok(tungstenite::Message::Close(_))) => {
228 info!(url = %url, "server sent Close frame");
229 break;
230 }
231 Some(Ok(_)) => {} Some(Err(e)) => {
233 warn!(url = %url, error = %e, "WebSocket read error");
234 break;
235 }
236 None => {
237 info!(url = %url, "WebSocket stream ended");
238 break;
239 }
240 }
241 }
242 _ = shutdown_rx.changed() => {
243 debug!("shutdown signal received in reader");
244 let _ = write.send(tungstenite::Message::Close(None)).await;
245 break 'outer;
246 }
247 }
248 }
249
250 let _ = tx
252 .send(WsMessage::Disconnected(format!("disconnected from {url}")))
253 .await;
254
255 if let Some(delay) = conn_mgr.next_backoff() {
256 tokio::select! {
257 () = tokio::time::sleep(delay) => {},
258 _ = shutdown_rx.changed() => break,
259 }
260 } else {
261 break;
262 }
263 }
264 })
265 }
266}
267
268async fn send_with_backpressure(
274 tx: &MAsyncTx<mpsc::Array<WsMessage>>,
275 msg: WsMessage,
276 strategy: &WsBackpressure,
277 data_ready: &Notify,
278) -> Result<(), ()> {
279 let result = match strategy {
280 WsBackpressure::Block => tx.send(msg).await.map_err(|_| ()),
281 WsBackpressure::DropNewest => match tx.try_send(msg) {
282 Ok(()) | Err(TrySendError::Full(_)) => Ok(()),
283 Err(TrySendError::Disconnected(_)) => Err(()),
284 },
285 WsBackpressure::DropOldest
287 | WsBackpressure::Buffer { .. }
288 | WsBackpressure::Sample { .. } => match tx.try_send(msg) {
289 Ok(()) | Err(TrySendError::Full(_)) => Ok(()),
290 Err(TrySendError::Disconnected(_)) => Err(()),
291 },
292 };
293 if result.is_ok() {
294 data_ready.notify_one();
295 }
296 result
297}
298
299#[async_trait]
300impl SourceConnector for WebSocketSource {
301 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
302 self.state = ConnectorState::Initializing;
303
304 if !config.properties().is_empty() {
306 self.config = WebSocketSourceConfig::from_config(config)?;
307 }
308
309 if let Some(schema) = config.arrow_schema() {
311 info!(
312 fields = schema.fields().len(),
313 "using SQL-defined schema for deserialization"
314 );
315 self.schema = schema;
316 let decoder_config = JsonDecoderConfig::from_connector_config(config);
317 self.parser = MessageParser::new(
318 self.schema.clone(),
319 self.config.format.clone(),
320 decoder_config,
321 );
322 }
323
324 let mode = &self.config.mode;
325 let (urls, subscribe_message, reconnect, ping_interval, ping_timeout) = match mode {
326 SourceMode::Client {
327 urls,
328 subscribe_message,
329 reconnect,
330 ping_interval,
331 ping_timeout,
332 } => (
333 urls.clone(),
334 subscribe_message.clone(),
335 reconnect.clone(),
336 *ping_interval,
337 *ping_timeout,
338 ),
339 SourceMode::Server { .. } => {
340 return Err(ConnectorError::ConfigurationError(
341 "WebSocketSource is for client mode; use WebSocketSourceServer for server mode"
342 .into(),
343 ));
344 }
345 };
346
347 if urls.is_empty() {
348 return Err(ConnectorError::ConfigurationError(
349 "at least one WebSocket URL is required".into(),
350 ));
351 }
352
353 let _ = (ping_interval, ping_timeout);
357
358 info!(
359 urls = ?urls,
360 format = ?self.config.format,
361 backpressure = ?self.config.on_backpressure,
362 "opening WebSocket source connector (client mode)"
363 );
364
365 if matches!(
366 self.config.on_backpressure,
367 WsBackpressure::DropOldest
368 | WsBackpressure::Buffer { .. }
369 | WsBackpressure::Sample { .. }
370 ) {
371 warn!(
372 strategy = ?self.config.on_backpressure,
373 "backpressure strategy not implemented, falling back to DropNewest"
374 );
375 }
376
377 let channel_capacity = 10_000;
379 let (tx, rx) = mpsc::bounded_async::<WsMessage>(channel_capacity);
380
381 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
383
384 let handle = self.spawn_reader(
386 urls.clone(),
387 subscribe_message,
388 reconnect.clone(),
389 self.config.max_message_size,
390 self.config.on_backpressure.clone(),
391 tx,
392 shutdown_rx,
393 Arc::clone(&self.data_ready),
394 );
395
396 self.rx = Some(rx);
397 self.shutdown_tx = Some(shutdown_tx);
398 self.reader_handle = Some(handle);
399 self.state = ConnectorState::Running;
400
401 info!("WebSocket source connector opened successfully");
402 Ok(())
403 }
404
405 #[allow(clippy::cast_possible_truncation)]
406 async fn poll_batch(
407 &mut self,
408 max_records: usize,
409 ) -> Result<Option<SourceBatch>, ConnectorError> {
410 if self.state != ConnectorState::Running {
411 return Err(ConnectorError::InvalidState {
412 expected: "Running".into(),
413 actual: self.state.to_string(),
414 });
415 }
416
417 let rx = self
418 .rx
419 .as_mut()
420 .ok_or_else(|| ConnectorError::InvalidState {
421 expected: "channel initialized".into(),
422 actual: "channel is None".into(),
423 })?;
424
425 let limit = max_records.min(self.max_batch_size);
426
427 while self.message_buffer.len() < limit {
430 match rx.try_recv() {
431 Ok(WsMessage::Data(payload)) => {
432 self.metrics.record_message(payload.len() as u64);
433 self.message_buffer.push(payload);
434 }
435 Ok(WsMessage::Disconnected(reason)) => {
436 self.metrics.record_reconnect();
437 warn!(reason = %reason, "WebSocket disconnected");
438 break;
439 }
440 Err(TryRecvError::Empty) => break,
441 Err(TryRecvError::Disconnected) => {
442 if self.message_buffer.is_empty() {
444 self.state = ConnectorState::Failed;
445 return Err(ConnectorError::ReadError(
446 "WebSocket reader task terminated".into(),
447 ));
448 }
449 break;
451 }
452 }
453 }
454
455 if self.message_buffer.is_empty() {
456 return Ok(None);
457 }
458
459 let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
461 let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
462 self.metrics.record_parse_error();
463 })?;
464
465 let num_rows = batch.num_rows();
466 self.message_buffer.clear();
467
468 self.checkpoint_state.watermark = std::time::SystemTime::now()
474 .duration_since(std::time::UNIX_EPOCH)
475 .unwrap_or_default()
476 .as_millis() as i64;
477
478 debug!(records = num_rows, "polled batch from WebSocket");
479 Ok(Some(SourceBatch::new(batch)))
480 }
481
482 fn schema(&self) -> SchemaRef {
483 self.schema.clone()
484 }
485
486 fn checkpoint(&self) -> SourceCheckpoint {
487 self.checkpoint_state.to_source_checkpoint(0)
491 }
492
493 async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
494 info!(
495 epoch = checkpoint.epoch(),
496 "restoring WebSocket source from checkpoint (best-effort)"
497 );
498 self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
499
500 warn!(
502 last_sequence = ?self.checkpoint_state.last_sequence,
503 last_event_time = ?self.checkpoint_state.last_event_time,
504 "WebSocket source restored; data gap expected (non-replayable transport)"
505 );
506
507 Ok(())
508 }
509
510 fn health_check(&self) -> HealthStatus {
511 match self.state {
512 ConnectorState::Running => HealthStatus::Healthy,
513 ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
514 ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
515 ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
516 ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
517 ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
518 }
519 }
520
521 fn metrics(&self) -> ConnectorMetrics {
522 self.metrics.to_connector_metrics()
523 }
524
525 fn data_ready_notify(&self) -> Option<Arc<Notify>> {
526 Some(Arc::clone(&self.data_ready))
527 }
528
529 fn supports_replay(&self) -> bool {
530 false
531 }
532
533 async fn close(&mut self) -> Result<(), ConnectorError> {
534 info!("closing WebSocket source connector");
535
536 if let Some(tx) = self.shutdown_tx.take() {
538 let _ = tx.send(true);
539 }
540
541 if let Some(handle) = self.reader_handle.take() {
543 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
544 }
545
546 self.rx = None;
547 self.state = ConnectorState::Closed;
548 info!("WebSocket source connector closed");
549 Ok(())
550 }
551}
552
553impl std::fmt::Debug for WebSocketSource {
554 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 f.debug_struct("WebSocketSource")
556 .field("state", &self.state)
557 .field("mode", &"client")
558 .field("format", &self.config.format)
559 .finish_non_exhaustive()
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Two nullable UTF-8 columns: `id` and `value`.
    fn test_schema() -> SchemaRef {
        let fields: Vec<Field> = ["id", "value"]
            .into_iter()
            .map(|name| Field::new(name, DataType::Utf8, true))
            .collect();
        Arc::new(Schema::new(fields))
    }

    /// Client-mode config pointing at a local URL; none of these tests open a connection.
    fn test_config() -> WebSocketSourceConfig {
        let mode = SourceMode::Client {
            urls: vec!["ws://localhost:9090".into()],
            subscribe_message: None,
            reconnect: ReconnectConfig::default(),
            ping_interval: std::time::Duration::from_secs(30),
            ping_timeout: std::time::Duration::from_secs(10),
        };
        WebSocketSourceConfig {
            mode,
            format: MessageFormat::Json,
            on_backpressure: WsBackpressure::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    /// A freshly constructed source built from the default fixtures.
    fn fresh_source() -> WebSocketSource {
        WebSocketSource::new(test_schema(), test_config(), None)
    }

    #[test]
    fn test_new_defaults() {
        assert_eq!(fresh_source().state(), ConnectorState::Created);
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let source = WebSocketSource::new(Arc::clone(&expected), test_config(), None);
        assert_eq!(source.schema(), expected);
    }

    #[test]
    fn test_checkpoint_empty() {
        let checkpoint = fresh_source().checkpoint();
        assert!(!checkpoint.is_empty());
    }

    #[test]
    fn test_health_check_created() {
        assert_eq!(fresh_source().health_check(), HealthStatus::Unknown);
    }

    #[test]
    fn test_metrics_initial() {
        let snapshot = fresh_source().metrics();
        assert_eq!(snapshot.records_total, 0);
        assert_eq!(snapshot.bytes_total, 0);
    }
}
628}