1use std::sync::Arc;
13
14use arrow_schema::SchemaRef;
15use async_trait::async_trait;
16use futures_util::{SinkExt, StreamExt};
17use tokio::sync::{mpsc, Notify};
18use tracing::{debug, info, warn};
19
20use crate::checkpoint::SourceCheckpoint;
21use crate::config::{ConnectorConfig, ConnectorState};
22use crate::connector::{SourceBatch, SourceConnector};
23use crate::error::ConnectorError;
24use crate::health::HealthStatus;
25use crate::metrics::ConnectorMetrics;
26
27use crate::schema::json::decoder::JsonDecoderConfig;
28
29use super::backpressure::BackpressureStrategy;
30use super::checkpoint::WebSocketSourceCheckpoint;
31use super::connection::ConnectionManager;
32use super::metrics::WebSocketSourceMetrics;
33use super::parser::MessageParser;
34use super::source_config::{ReconnectConfig, SourceMode, WebSocketSourceConfig};
35
/// Events handed from the background reader task to the connector over the
/// internal channel.
enum WsMessage {
    /// Raw payload bytes of one received text or binary WebSocket message.
    Data(Vec<u8>),
    /// The connection was lost or could not be established; carries a
    /// human-readable reason used for logging.
    Disconnected(String),
}
43
44pub struct WebSocketSource {
52 config: WebSocketSourceConfig,
54 schema: SchemaRef,
56 parser: MessageParser,
58 state: ConnectorState,
60 metrics: WebSocketSourceMetrics,
62 checkpoint_state: WebSocketSourceCheckpoint,
64 rx: Option<mpsc::Receiver<WsMessage>>,
66 shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
68 reader_handle: Option<tokio::task::JoinHandle<()>>,
70 message_buffer: Vec<Vec<u8>>,
72 max_batch_size: usize,
74 data_ready: Arc<Notify>,
76}
77
78impl WebSocketSource {
79 #[must_use]
86 pub fn new(schema: SchemaRef, config: WebSocketSourceConfig) -> Self {
87 let parser = MessageParser::new(
88 schema.clone(),
89 config.format.clone(),
90 JsonDecoderConfig::default(),
91 );
92
93 Self {
94 config,
95 schema,
96 parser,
97 state: ConnectorState::Created,
98 metrics: WebSocketSourceMetrics::new(),
99 checkpoint_state: WebSocketSourceCheckpoint::default(),
100 rx: None,
101 shutdown_tx: None,
102 reader_handle: None,
103 message_buffer: Vec::new(),
104 max_batch_size: 1000,
105 data_ready: Arc::new(Notify::new()),
106 }
107 }
108
109 #[must_use]
111 pub fn state(&self) -> ConnectorState {
112 self.state
113 }
114
115 #[allow(
118 clippy::too_many_arguments,
119 clippy::too_many_lines,
120 clippy::unused_self
121 )]
122 fn spawn_reader(
123 &self,
124 urls: Vec<String>,
125 subscribe_message: Option<String>,
126 reconnect: ReconnectConfig,
127 max_message_size: usize,
128 on_backpressure: BackpressureStrategy,
129 tx: mpsc::Sender<WsMessage>,
130 mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
131 data_ready: Arc<Notify>,
132 ) -> tokio::task::JoinHandle<()> {
133 tokio::spawn(async move {
134 let mut conn_mgr = ConnectionManager::new(urls, reconnect);
135
136 'outer: loop {
137 if *shutdown_rx.borrow() {
139 break;
140 }
141
142 let url = conn_mgr.current_url().to_string();
143 info!(url = %url, "connecting to WebSocket server");
144
145 let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
147 ws_config.max_message_size = Some(max_message_size);
148 ws_config.max_frame_size = Some(max_message_size);
149 let ws_stream = match tokio_tungstenite::connect_async_with_config(
150 &url,
151 Some(ws_config),
152 true, )
154 .await
155 {
156 Ok((stream, _response)) => {
157 conn_mgr.reset();
158 info!(url = %url, "WebSocket connection established");
159 stream
160 }
161 Err(e) => {
162 warn!(url = %url, error = %e, "WebSocket connection failed");
163 if let Some(delay) = conn_mgr.next_backoff() {
164 tokio::select! {
165 () = tokio::time::sleep(delay) => continue,
166 _ = shutdown_rx.changed() => break,
167 }
168 } else {
169 let _ = tx
170 .send(WsMessage::Disconnected(format!(
171 "connection failed, no more retries: {e}"
172 )))
173 .await;
174 break;
175 }
176 }
177 };
178
179 let (mut write, mut read) = ws_stream.split();
180
181 if let Some(ref msg) = subscribe_message {
183 if let Err(e) = write
184 .send(tungstenite::Message::Text(msg.clone().into()))
185 .await
186 {
187 warn!(error = %e, "failed to send subscription message");
188 if let Some(delay) = conn_mgr.next_backoff() {
189 tokio::select! {
190 () = tokio::time::sleep(delay) => continue,
191 _ = shutdown_rx.changed() => break,
192 }
193 }
194 continue;
195 }
196 debug!("subscription message sent");
197 }
198
199 loop {
201 tokio::select! {
202 msg = read.next() => {
203 match msg {
204 Some(Ok(tungstenite::Message::Text(text))) => {
205 let payload = text.as_bytes().to_vec();
206 if payload.len() > max_message_size {
207 warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
208 continue;
209 }
210 if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
211 break 'outer;
212 }
213 }
214 Some(Ok(tungstenite::Message::Binary(data))) => {
215 let payload = data.to_vec();
216 if payload.len() > max_message_size {
217 warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
218 continue;
219 }
220 if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
221 break 'outer;
222 }
223 }
224 Some(Ok(tungstenite::Message::Ping(data))) => {
225 let _ = write.send(tungstenite::Message::Pong(data)).await;
226 }
227 Some(Ok(tungstenite::Message::Close(_))) => {
228 info!(url = %url, "server sent Close frame");
229 break;
230 }
231 Some(Ok(_)) => {} Some(Err(e)) => {
233 warn!(url = %url, error = %e, "WebSocket read error");
234 break;
235 }
236 None => {
237 info!(url = %url, "WebSocket stream ended");
238 break;
239 }
240 }
241 }
242 _ = shutdown_rx.changed() => {
243 debug!("shutdown signal received in reader");
244 let _ = write.send(tungstenite::Message::Close(None)).await;
245 break 'outer;
246 }
247 }
248 }
249
250 let _ = tx
252 .send(WsMessage::Disconnected(format!("disconnected from {url}")))
253 .await;
254
255 if let Some(delay) = conn_mgr.next_backoff() {
256 tokio::select! {
257 () = tokio::time::sleep(delay) => {},
258 _ = shutdown_rx.changed() => break,
259 }
260 } else {
261 break;
262 }
263 }
264 })
265 }
266}
267
268async fn send_with_backpressure(
274 tx: &mpsc::Sender<WsMessage>,
275 msg: WsMessage,
276 strategy: &BackpressureStrategy,
277 data_ready: &Notify,
278) -> Result<(), ()> {
279 let result = match strategy {
280 BackpressureStrategy::Block => tx.send(msg).await.map_err(|_| ()),
281 BackpressureStrategy::DropNewest => match tx.try_send(msg) {
282 Ok(()) | Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(()),
283 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => Err(()),
284 },
285 BackpressureStrategy::DropOldest
286 | BackpressureStrategy::Buffer { .. }
287 | BackpressureStrategy::Sample { .. } => {
288 tracing::debug!("backpressure strategy not fully implemented, using DropNewest");
293 match tx.try_send(msg) {
294 Ok(()) | Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(()),
295 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => Err(()),
296 }
297 }
298 };
299 if result.is_ok() {
300 data_ready.notify_one();
301 }
302 result
303}
304
305#[async_trait]
306impl SourceConnector for WebSocketSource {
307 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
308 self.state = ConnectorState::Initializing;
309
310 if !config.properties().is_empty() {
312 self.config = WebSocketSourceConfig::from_config(config)?;
313 }
314
315 if let Some(schema) = config.arrow_schema() {
317 info!(
318 fields = schema.fields().len(),
319 "using SQL-defined schema for deserialization"
320 );
321 self.schema = schema;
322 let decoder_config = JsonDecoderConfig::from_connector_config(config);
323 self.parser = MessageParser::new(
324 self.schema.clone(),
325 self.config.format.clone(),
326 decoder_config,
327 );
328 }
329
330 let mode = &self.config.mode;
331 let (urls, subscribe_message, reconnect, ping_interval, ping_timeout) = match mode {
332 SourceMode::Client {
333 urls,
334 subscribe_message,
335 reconnect,
336 ping_interval,
337 ping_timeout,
338 } => (
339 urls.clone(),
340 subscribe_message.clone(),
341 reconnect.clone(),
342 *ping_interval,
343 *ping_timeout,
344 ),
345 SourceMode::Server { .. } => {
346 return Err(ConnectorError::ConfigurationError(
347 "WebSocketSource is for client mode; use WebSocketSourceServer for server mode"
348 .into(),
349 ));
350 }
351 };
352
353 if urls.is_empty() {
354 return Err(ConnectorError::ConfigurationError(
355 "at least one WebSocket URL is required".into(),
356 ));
357 }
358
359 let _ = (ping_interval, ping_timeout);
363
364 info!(
365 urls = ?urls,
366 format = ?self.config.format,
367 backpressure = ?self.config.on_backpressure,
368 "opening WebSocket source connector (client mode)"
369 );
370
371 let channel_capacity = 10_000;
373 let (tx, rx) = mpsc::channel(channel_capacity);
374
375 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
377
378 let handle = self.spawn_reader(
380 urls.clone(),
381 subscribe_message,
382 reconnect.clone(),
383 self.config.max_message_size,
384 self.config.on_backpressure.clone(),
385 tx,
386 shutdown_rx,
387 Arc::clone(&self.data_ready),
388 );
389
390 self.rx = Some(rx);
391 self.shutdown_tx = Some(shutdown_tx);
392 self.reader_handle = Some(handle);
393 self.state = ConnectorState::Running;
394
395 info!("WebSocket source connector opened successfully");
396 Ok(())
397 }
398
399 #[allow(clippy::cast_possible_truncation)]
400 async fn poll_batch(
401 &mut self,
402 max_records: usize,
403 ) -> Result<Option<SourceBatch>, ConnectorError> {
404 if self.state != ConnectorState::Running {
405 return Err(ConnectorError::InvalidState {
406 expected: "Running".into(),
407 actual: self.state.to_string(),
408 });
409 }
410
411 let rx = self
412 .rx
413 .as_mut()
414 .ok_or_else(|| ConnectorError::InvalidState {
415 expected: "channel initialized".into(),
416 actual: "channel is None".into(),
417 })?;
418
419 let limit = max_records.min(self.max_batch_size);
420
421 while self.message_buffer.len() < limit {
424 match rx.try_recv() {
425 Ok(WsMessage::Data(payload)) => {
426 self.metrics.record_message(payload.len() as u64);
427 self.message_buffer.push(payload);
428 }
429 Ok(WsMessage::Disconnected(reason)) => {
430 self.metrics.record_reconnect();
431 warn!(reason = %reason, "WebSocket disconnected");
432 break;
433 }
434 Err(mpsc::error::TryRecvError::Empty) => break,
435 Err(mpsc::error::TryRecvError::Disconnected) => {
436 if self.message_buffer.is_empty() {
438 self.state = ConnectorState::Failed;
439 return Err(ConnectorError::ReadError(
440 "WebSocket reader task terminated".into(),
441 ));
442 }
443 break;
445 }
446 }
447 }
448
449 if self.message_buffer.is_empty() {
450 return Ok(None);
451 }
452
453 let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
455 let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
456 self.metrics.record_parse_error();
457 })?;
458
459 let num_rows = batch.num_rows();
460 self.message_buffer.clear();
461
462 self.checkpoint_state.watermark = std::time::SystemTime::now()
468 .duration_since(std::time::UNIX_EPOCH)
469 .unwrap_or_default()
470 .as_millis() as i64;
471
472 debug!(records = num_rows, "polled batch from WebSocket");
473 Ok(Some(SourceBatch::new(batch)))
474 }
475
476 fn schema(&self) -> SchemaRef {
477 self.schema.clone()
478 }
479
480 fn checkpoint(&self) -> SourceCheckpoint {
481 self.checkpoint_state.to_source_checkpoint(0)
485 }
486
487 async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
488 info!(
489 epoch = checkpoint.epoch(),
490 "restoring WebSocket source from checkpoint (best-effort)"
491 );
492 self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
493
494 warn!(
496 last_sequence = ?self.checkpoint_state.last_sequence,
497 last_event_time = ?self.checkpoint_state.last_event_time,
498 "WebSocket source restored; data gap expected (non-replayable transport)"
499 );
500
501 Ok(())
502 }
503
504 fn health_check(&self) -> HealthStatus {
505 match self.state {
506 ConnectorState::Running => HealthStatus::Healthy,
507 ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
508 ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
509 ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
510 ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
511 ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
512 }
513 }
514
515 fn metrics(&self) -> ConnectorMetrics {
516 self.metrics.to_connector_metrics()
517 }
518
519 fn data_ready_notify(&self) -> Option<Arc<Notify>> {
520 Some(Arc::clone(&self.data_ready))
521 }
522
523 fn supports_replay(&self) -> bool {
524 false
525 }
526
527 async fn close(&mut self) -> Result<(), ConnectorError> {
528 info!("closing WebSocket source connector");
529
530 if let Some(tx) = self.shutdown_tx.take() {
532 let _ = tx.send(true);
533 }
534
535 if let Some(handle) = self.reader_handle.take() {
537 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
538 }
539
540 self.rx = None;
541 self.state = ConnectorState::Closed;
542 info!("WebSocket source connector closed");
543 Ok(())
544 }
545}
546
547impl std::fmt::Debug for WebSocketSource {
548 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549 f.debug_struct("WebSocketSource")
550 .field("state", &self.state)
551 .field("mode", &"client")
552 .field("format", &self.config.format)
553 .finish_non_exhaustive()
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Schema with two nullable UTF-8 columns, `id` and `value`.
    fn test_schema() -> SchemaRef {
        let columns = vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(columns))
    }

    /// Client-mode JSON configuration pointing at a local endpoint.
    fn test_config() -> WebSocketSourceConfig {
        let mode = SourceMode::Client {
            urls: vec!["ws://localhost:9090".into()],
            subscribe_message: None,
            reconnect: ReconnectConfig::default(),
            ping_interval: std::time::Duration::from_secs(30),
            ping_timeout: std::time::Duration::from_secs(10),
        };

        WebSocketSourceConfig {
            mode,
            format: MessageFormat::Json,
            on_backpressure: BackpressureStrategy::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    /// Shorthand for a freshly constructed, unopened source.
    fn fresh_source() -> WebSocketSource {
        WebSocketSource::new(test_schema(), test_config())
    }

    #[test]
    fn test_new_defaults() {
        assert_eq!(fresh_source().state(), ConnectorState::Created);
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let source = WebSocketSource::new(Arc::clone(&expected), test_config());
        assert_eq!(source.schema(), expected);
    }

    #[test]
    fn test_checkpoint_empty() {
        // Even a never-opened source is expected to produce a checkpoint
        // that reports itself as non-empty.
        let checkpoint = fresh_source().checkpoint();
        assert!(!checkpoint.is_empty());
    }

    #[test]
    fn test_health_check_created() {
        assert_eq!(fresh_source().health_check(), HealthStatus::Unknown);
    }

    #[test]
    fn test_metrics_initial() {
        let snapshot = fresh_source().metrics();
        assert_eq!(snapshot.records_total, 0);
        assert_eq!(snapshot.bytes_total, 0);
    }
}