1use std::sync::Arc;
13
14use arrow_schema::SchemaRef;
15use async_trait::async_trait;
16use crossfire::{mpsc, AsyncRx, MAsyncTx, TryRecvError, TrySendError};
17use futures_util::{SinkExt, StreamExt};
18use tokio::sync::Notify;
19use tracing::{debug, info, warn};
20
21use crate::checkpoint::SourceCheckpoint;
22use crate::config::{ConnectorConfig, ConnectorState};
23use crate::connector::{SourceBatch, SourceConnector};
24use crate::error::ConnectorError;
25
26use crate::schema::json::decoder::JsonDecoderConfig;
27
28use super::backpressure::WsBackpressure;
29use super::checkpoint::WebSocketSourceCheckpoint;
30use super::connection::ConnectionManager;
31use super::metrics::WebSocketSourceMetrics;
32use super::parser::MessageParser;
33use super::source_config::{ReconnectConfig, SourceMode, WebSocketSourceConfig};
34
/// Message passed from the background reader task to the polling side of the source.
enum WsMessage {
    /// Raw payload bytes of one received text or binary WebSocket frame.
    Data(Vec<u8>),
    /// The connection ended, or connecting failed with no retries left; carries a
    /// human-readable reason string.
    Disconnected(String),
}
42
/// WebSocket source connector operating in client mode.
///
/// A background task (spawned in `open`) reads frames from the server and pushes
/// them onto a bounded channel; `poll_batch` drains that channel and parses the
/// payloads into record batches.
pub struct WebSocketSource {
    /// Connector configuration; replaced in `open` when the supplied connector
    /// config carries properties.
    config: WebSocketSourceConfig,
    /// Arrow schema reported by `schema()`; overridden in `open` when the connector
    /// config provides one.
    schema: SchemaRef,
    /// Converts raw message payloads into record batches.
    parser: MessageParser,
    /// Current lifecycle state of the connector.
    state: ConnectorState,
    /// Metrics recorder (messages, reconnects, parse errors).
    metrics: WebSocketSourceMetrics,
    /// Checkpoint data; its watermark is refreshed after every successful batch.
    checkpoint_state: WebSocketSourceCheckpoint,
    /// Receiving half of the reader channel; `Some` between `open` and `close`.
    rx: Option<AsyncRx<mpsc::Array<WsMessage>>>,
    /// Watch-channel sender used to ask the reader task to stop.
    shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
    /// Join handle of the spawned reader task.
    reader_handle: Option<tokio::task::JoinHandle<()>>,
    /// Payloads drained from the channel and waiting to be parsed by `poll_batch`.
    message_buffer: Vec<Vec<u8>>,
    /// Upper bound on the number of messages gathered per `poll_batch` call.
    max_batch_size: usize,
    /// Notified by the reader task after it hands a message to the channel.
    data_ready: Arc<Notify>,
}
76
77impl WebSocketSource {
78 #[must_use]
80 pub fn new(
81 schema: SchemaRef,
82 config: WebSocketSourceConfig,
83 registry: Option<&prometheus::Registry>,
84 ) -> Self {
85 let parser = MessageParser::new(
86 schema.clone(),
87 config.format.clone(),
88 JsonDecoderConfig::default(),
89 );
90
91 Self {
92 config,
93 schema,
94 parser,
95 state: ConnectorState::Created,
96 metrics: WebSocketSourceMetrics::new(registry),
97 checkpoint_state: WebSocketSourceCheckpoint::default(),
98 rx: None,
99 shutdown_tx: None,
100 reader_handle: None,
101 message_buffer: Vec::new(),
102 max_batch_size: 1000,
103 data_ready: Arc::new(Notify::new()),
104 }
105 }
106
    /// Returns the connector's current lifecycle state.
    #[must_use]
    pub fn state(&self) -> ConnectorState {
        self.state
    }
112
113 #[allow(
116 clippy::too_many_arguments,
117 clippy::too_many_lines,
118 clippy::unused_self
119 )]
120 fn spawn_reader(
121 &self,
122 urls: Vec<String>,
123 subscribe_message: Option<String>,
124 reconnect: ReconnectConfig,
125 max_message_size: usize,
126 on_backpressure: WsBackpressure,
127 tx: MAsyncTx<mpsc::Array<WsMessage>>,
128 mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
129 data_ready: Arc<Notify>,
130 ) -> tokio::task::JoinHandle<()> {
131 tokio::spawn(async move {
132 let mut conn_mgr = ConnectionManager::new(urls, reconnect);
133
134 'outer: loop {
135 if *shutdown_rx.borrow() {
137 break;
138 }
139
140 let url = conn_mgr.current_url().to_string();
141 info!(url = %url, "connecting to WebSocket server");
142
143 let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
145 ws_config.max_message_size = Some(max_message_size);
146 ws_config.max_frame_size = Some(max_message_size);
147 let ws_stream = match tokio_tungstenite::connect_async_with_config(
148 &url,
149 Some(ws_config),
150 true, )
152 .await
153 {
154 Ok((stream, _response)) => {
155 conn_mgr.reset();
156 info!(url = %url, "WebSocket connection established");
157 stream
158 }
159 Err(e) => {
160 warn!(url = %url, error = %e, "WebSocket connection failed");
161 if let Some(delay) = conn_mgr.next_backoff() {
162 tokio::select! {
163 () = tokio::time::sleep(delay) => continue,
164 _ = shutdown_rx.changed() => break,
165 }
166 } else {
167 let _ = tx
168 .send(WsMessage::Disconnected(format!(
169 "connection failed, no more retries: {e}"
170 )))
171 .await;
172 break;
173 }
174 }
175 };
176
177 let (mut write, mut read) = ws_stream.split();
178
179 if let Some(ref msg) = subscribe_message {
181 if let Err(e) = write
182 .send(tungstenite::Message::Text(msg.clone().into()))
183 .await
184 {
185 warn!(error = %e, "failed to send subscription message");
186 if let Some(delay) = conn_mgr.next_backoff() {
187 tokio::select! {
188 () = tokio::time::sleep(delay) => continue,
189 _ = shutdown_rx.changed() => break,
190 }
191 }
192 continue;
193 }
194 debug!("subscription message sent");
195 }
196
197 loop {
199 tokio::select! {
200 msg = read.next() => {
201 match msg {
202 Some(Ok(tungstenite::Message::Text(text))) => {
203 let payload = text.as_bytes().to_vec();
204 if payload.len() > max_message_size {
205 warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
206 continue;
207 }
208 if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
209 break 'outer;
210 }
211 }
212 Some(Ok(tungstenite::Message::Binary(data))) => {
213 let payload = data.to_vec();
214 if payload.len() > max_message_size {
215 warn!(size = payload.len(), max = max_message_size, "message exceeds max size, dropping");
216 continue;
217 }
218 if send_with_backpressure(&tx, WsMessage::Data(payload), &on_backpressure, &data_ready).await.is_err() {
219 break 'outer;
220 }
221 }
222 Some(Ok(tungstenite::Message::Ping(data))) => {
223 let _ = write.send(tungstenite::Message::Pong(data)).await;
224 }
225 Some(Ok(tungstenite::Message::Close(_))) => {
226 info!(url = %url, "server sent Close frame");
227 break;
228 }
229 Some(Ok(_)) => {} Some(Err(e)) => {
231 warn!(url = %url, error = %e, "WebSocket read error");
232 break;
233 }
234 None => {
235 info!(url = %url, "WebSocket stream ended");
236 break;
237 }
238 }
239 }
240 _ = shutdown_rx.changed() => {
241 debug!("shutdown signal received in reader");
242 let _ = write.send(tungstenite::Message::Close(None)).await;
243 break 'outer;
244 }
245 }
246 }
247
248 let _ = tx
250 .send(WsMessage::Disconnected(format!("disconnected from {url}")))
251 .await;
252
253 if let Some(delay) = conn_mgr.next_backoff() {
254 tokio::select! {
255 () = tokio::time::sleep(delay) => {},
256 _ = shutdown_rx.changed() => break,
257 }
258 } else {
259 break;
260 }
261 }
262 })
263 }
264}
265
266async fn send_with_backpressure(
272 tx: &MAsyncTx<mpsc::Array<WsMessage>>,
273 msg: WsMessage,
274 strategy: &WsBackpressure,
275 data_ready: &Notify,
276) -> Result<(), ()> {
277 let result = match strategy {
278 WsBackpressure::Block => tx.send(msg).await.map_err(|_| ()),
279 WsBackpressure::DropNewest => match tx.try_send(msg) {
280 Ok(()) | Err(TrySendError::Full(_)) => Ok(()),
281 Err(TrySendError::Disconnected(_)) => Err(()),
282 },
283 WsBackpressure::DropOldest
285 | WsBackpressure::Buffer { .. }
286 | WsBackpressure::Sample { .. } => match tx.try_send(msg) {
287 Ok(()) | Err(TrySendError::Full(_)) => Ok(()),
288 Err(TrySendError::Disconnected(_)) => Err(()),
289 },
290 };
291 if result.is_ok() {
292 data_ready.notify_one();
293 }
294 result
295}
296
297#[async_trait]
298impl SourceConnector for WebSocketSource {
299 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
300 self.state = ConnectorState::Initializing;
301
302 if !config.properties().is_empty() {
304 self.config = WebSocketSourceConfig::from_config(config)?;
305 }
306
307 if let Some(schema) = config.arrow_schema() {
309 info!(
310 fields = schema.fields().len(),
311 "using SQL-defined schema for deserialization"
312 );
313 self.schema = schema;
314 let decoder_config = JsonDecoderConfig::from_connector_config(config);
315 self.parser = MessageParser::new(
316 self.schema.clone(),
317 self.config.format.clone(),
318 decoder_config,
319 );
320 }
321
322 let mode = &self.config.mode;
323 let (urls, subscribe_message, reconnect, ping_interval, ping_timeout) = match mode {
324 SourceMode::Client {
325 urls,
326 subscribe_message,
327 reconnect,
328 ping_interval,
329 ping_timeout,
330 } => (
331 urls.clone(),
332 subscribe_message.clone(),
333 reconnect.clone(),
334 *ping_interval,
335 *ping_timeout,
336 ),
337 SourceMode::Server { .. } => {
338 return Err(ConnectorError::ConfigurationError(
339 "WebSocketSource is for client mode; use WebSocketSourceServer for server mode"
340 .into(),
341 ));
342 }
343 };
344
345 if urls.is_empty() {
346 return Err(ConnectorError::ConfigurationError(
347 "at least one WebSocket URL is required".into(),
348 ));
349 }
350
351 let _ = (ping_interval, ping_timeout);
355
356 info!(
357 urls = ?urls,
358 format = ?self.config.format,
359 backpressure = ?self.config.on_backpressure,
360 "opening WebSocket source connector (client mode)"
361 );
362
363 if matches!(
364 self.config.on_backpressure,
365 WsBackpressure::DropOldest
366 | WsBackpressure::Buffer { .. }
367 | WsBackpressure::Sample { .. }
368 ) {
369 warn!(
370 strategy = ?self.config.on_backpressure,
371 "backpressure strategy not implemented, falling back to DropNewest"
372 );
373 }
374
375 let channel_capacity = 10_000;
377 let (tx, rx) = mpsc::bounded_async::<WsMessage>(channel_capacity);
378
379 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
381
382 let handle = self.spawn_reader(
384 urls.clone(),
385 subscribe_message,
386 reconnect.clone(),
387 self.config.max_message_size,
388 self.config.on_backpressure.clone(),
389 tx,
390 shutdown_rx,
391 Arc::clone(&self.data_ready),
392 );
393
394 self.rx = Some(rx);
395 self.shutdown_tx = Some(shutdown_tx);
396 self.reader_handle = Some(handle);
397 self.state = ConnectorState::Running;
398
399 info!("WebSocket source connector opened successfully");
400 Ok(())
401 }
402
403 #[allow(clippy::cast_possible_truncation)]
404 async fn poll_batch(
405 &mut self,
406 max_records: usize,
407 ) -> Result<Option<SourceBatch>, ConnectorError> {
408 if self.state != ConnectorState::Running {
409 return Err(ConnectorError::InvalidState {
410 expected: "Running".into(),
411 actual: self.state.to_string(),
412 });
413 }
414
415 let rx = self
416 .rx
417 .as_mut()
418 .ok_or_else(|| ConnectorError::InvalidState {
419 expected: "channel initialized".into(),
420 actual: "channel is None".into(),
421 })?;
422
423 let limit = max_records.min(self.max_batch_size);
424
425 while self.message_buffer.len() < limit {
428 match rx.try_recv() {
429 Ok(WsMessage::Data(payload)) => {
430 self.metrics.record_message(payload.len() as u64);
431 self.message_buffer.push(payload);
432 }
433 Ok(WsMessage::Disconnected(reason)) => {
434 self.metrics.record_reconnect();
435 warn!(reason = %reason, "WebSocket disconnected");
436 break;
437 }
438 Err(TryRecvError::Empty) => break,
439 Err(TryRecvError::Disconnected) => {
440 if self.message_buffer.is_empty() {
442 self.state = ConnectorState::Failed;
443 return Err(ConnectorError::ReadError(
444 "WebSocket reader task terminated".into(),
445 ));
446 }
447 break;
449 }
450 }
451 }
452
453 if self.message_buffer.is_empty() {
454 return Ok(None);
455 }
456
457 let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
459 let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
460 self.metrics.record_parse_error();
461 })?;
462
463 let num_rows = batch.num_rows();
464 self.message_buffer.clear();
465
466 self.checkpoint_state.watermark = std::time::SystemTime::now()
472 .duration_since(std::time::UNIX_EPOCH)
473 .unwrap_or_default()
474 .as_millis() as i64;
475
476 debug!(records = num_rows, "polled batch from WebSocket");
477 Ok(Some(SourceBatch::new(batch)))
478 }
479
480 fn schema(&self) -> SchemaRef {
481 self.schema.clone()
482 }
483
    /// Converts the current checkpoint state into a generic [`SourceCheckpoint`].
    fn checkpoint(&self) -> SourceCheckpoint {
        // NOTE(review): the literal `0` is presumably the epoch — confirm against
        // `WebSocketSourceCheckpoint::to_source_checkpoint`.
        self.checkpoint_state.to_source_checkpoint(0)
    }
490
491 async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
492 info!(
493 epoch = checkpoint.epoch(),
494 "restoring WebSocket source from checkpoint (best-effort)"
495 );
496 self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
497
498 warn!(
500 last_sequence = ?self.checkpoint_state.last_sequence,
501 last_event_time = ?self.checkpoint_state.last_event_time,
502 "WebSocket source restored; data gap expected (non-replayable transport)"
503 );
504
505 Ok(())
506 }
507
    /// Returns a handle to the `Notify` that the reader task signals after it has
    /// handed a message to the channel.
    fn data_ready_notify(&self) -> Option<Arc<Notify>> {
        Some(Arc::clone(&self.data_ready))
    }
511
    /// This source never reports replay support.
    fn supports_replay(&self) -> bool {
        false
    }
515
516 async fn close(&mut self) -> Result<(), ConnectorError> {
517 info!("closing WebSocket source connector");
518
519 if let Some(tx) = self.shutdown_tx.take() {
521 let _ = tx.send(true);
522 }
523
524 if let Some(handle) = self.reader_handle.take() {
526 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
527 }
528
529 self.rx = None;
530 self.state = ConnectorState::Closed;
531 info!("WebSocket source connector closed");
532 Ok(())
533 }
534}
535
536impl std::fmt::Debug for WebSocketSource {
537 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
538 f.debug_struct("WebSocketSource")
539 .field("state", &self.state)
540 .field("mode", &"client")
541 .field("format", &self.config.format)
542 .finish_non_exhaustive()
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Two nullable UTF-8 columns: `id` and `value`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Client-mode JSON configuration pointing at a local endpoint.
    fn test_config() -> WebSocketSourceConfig {
        let mode = SourceMode::Client {
            urls: vec!["ws://localhost:9090".into()],
            subscribe_message: None,
            reconnect: ReconnectConfig::default(),
            ping_interval: std::time::Duration::from_secs(30),
            ping_timeout: std::time::Duration::from_secs(10),
        };

        WebSocketSourceConfig {
            mode,
            format: MessageFormat::Json,
            on_backpressure: WsBackpressure::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    /// A freshly constructed source reports the `Created` state.
    #[test]
    fn test_new_defaults() {
        let src = WebSocketSource::new(test_schema(), test_config(), None);
        assert_eq!(src.state(), ConnectorState::Created);
    }

    /// `schema()` hands back the schema given to the constructor.
    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let src = WebSocketSource::new(Arc::clone(&expected), test_config(), None);
        assert_eq!(src.schema(), expected);
    }

    /// The checkpoint of a fresh source is reported as non-empty.
    #[test]
    fn test_checkpoint_empty() {
        let src = WebSocketSource::new(test_schema(), test_config(), None);
        let checkpoint = src.checkpoint();
        // NOTE(review): presumably the default checkpoint already serialises some
        // entries — confirm against `WebSocketSourceCheckpoint::to_source_checkpoint`.
        assert!(!checkpoint.is_empty());
    }
}