laminar_connectors/websocket/
source_server.rs1use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use futures_util::StreamExt;
13use tokio::net::TcpListener;
14use tokio::sync::{mpsc, Notify};
15use tracing::{debug, info, warn};
16
17use crate::checkpoint::SourceCheckpoint;
18use crate::config::{ConnectorConfig, ConnectorState};
19use crate::connector::{SourceBatch, SourceConnector};
20use crate::error::ConnectorError;
21use crate::health::HealthStatus;
22use crate::metrics::ConnectorMetrics;
23
24use crate::schema::json::decoder::JsonDecoderConfig;
25
26use super::checkpoint::WebSocketSourceCheckpoint;
27use super::metrics::WebSocketSourceMetrics;
28use super::parser::MessageParser;
29use super::source_config::{SourceMode, WebSocketSourceConfig};
30
31pub struct WebSocketSourceServer {
37 config: WebSocketSourceConfig,
39 schema: SchemaRef,
41 parser: MessageParser,
43 state: ConnectorState,
45 metrics: WebSocketSourceMetrics,
47 checkpoint_state: WebSocketSourceCheckpoint,
49 rx: Option<mpsc::Receiver<Vec<u8>>>,
51 shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
53 acceptor_handle: Option<tokio::task::JoinHandle<()>>,
55 message_buffer: Vec<Vec<u8>>,
57 connected_clients: Arc<AtomicU64>,
59 max_batch_size: usize,
61 data_ready: Arc<Notify>,
63}
64
65impl WebSocketSourceServer {
66 #[must_use]
68 pub fn new(schema: SchemaRef, config: WebSocketSourceConfig) -> Self {
69 let parser = MessageParser::new(
70 schema.clone(),
71 config.format.clone(),
72 JsonDecoderConfig::default(),
73 );
74
75 Self {
76 config,
77 schema,
78 parser,
79 state: ConnectorState::Created,
80 metrics: WebSocketSourceMetrics::new(),
81 checkpoint_state: WebSocketSourceCheckpoint::default(),
82 rx: None,
83 shutdown_tx: None,
84 acceptor_handle: None,
85 message_buffer: Vec::new(),
86 connected_clients: Arc::new(AtomicU64::new(0)),
87 max_batch_size: 1000,
88 data_ready: Arc::new(Notify::new()),
89 }
90 }
91
92 #[must_use]
94 pub fn state(&self) -> ConnectorState {
95 self.state
96 }
97
98 #[must_use]
100 pub fn connected_clients(&self) -> u64 {
101 self.connected_clients.load(Ordering::Relaxed)
102 }
103}
104
105#[async_trait]
106#[allow(clippy::too_many_lines)]
107impl SourceConnector for WebSocketSourceServer {
108 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
109 self.state = ConnectorState::Initializing;
110
111 if !config.properties().is_empty() {
113 self.config = WebSocketSourceConfig::from_config(config)?;
114 }
115
116 let (bind_address, max_connections, _path) = match &self.config.mode {
117 SourceMode::Server {
118 bind_address,
119 max_connections,
120 path,
121 } => (bind_address.clone(), *max_connections, path.clone()),
122 SourceMode::Client { .. } => {
123 return Err(ConnectorError::ConfigurationError(
124 "WebSocketSourceServer is for server mode; use WebSocketSource for client mode"
125 .into(),
126 ));
127 }
128 };
129
130 info!(
131 bind = %bind_address,
132 max_connections,
133 format = ?self.config.format,
134 "opening WebSocket source connector (server mode)"
135 );
136
137 let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
138 ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
139 })?;
140
141 let channel_capacity = 10_000;
142 let (tx, rx) = mpsc::channel(channel_capacity);
143 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
144
145 let connected = Arc::clone(&self.connected_clients);
146 let max_msg_size = self.config.max_message_size;
147 let data_ready = Arc::clone(&self.data_ready);
148
149 let handle = tokio::spawn(async move {
150 let mut shutdown_rx = shutdown_rx;
151
152 loop {
153 tokio::select! {
154 accept_result = listener.accept() => {
155 match accept_result {
156 Ok((stream, addr)) => {
157 let current = connected.load(Ordering::Relaxed);
158 if current >= max_connections as u64 {
159 warn!(
160 current_connections = current,
161 max = max_connections,
162 addr = %addr,
163 "rejecting connection: max_connections exceeded"
164 );
165 drop(stream);
166 continue;
167 }
168
169 let _ = stream.set_nodelay(true);
171
172 let tx = tx.clone();
173 let connected = Arc::clone(&connected);
174 let mut client_shutdown = shutdown_rx.clone();
175 let data_ready = Arc::clone(&data_ready);
176
177 connected.fetch_add(1, Ordering::Relaxed);
178 debug!(addr = %addr, "accepted WebSocket client");
179
180 tokio::spawn(async move {
181 let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
182 ws_config.max_message_size = Some(max_msg_size);
183 ws_config.max_frame_size = Some(max_msg_size);
184 let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
185 Ok(ws) => ws,
186 Err(e) => {
187 warn!(addr = %addr, error = %e, "WebSocket handshake failed");
188 connected.fetch_sub(1, Ordering::Relaxed);
189 return;
190 }
191 };
192
193 let (_write, mut read) = ws_stream.split();
194
195 loop {
196 tokio::select! {
197 msg = read.next() => {
198 match msg {
199 Some(Ok(tungstenite::Message::Text(text))) => {
200 let payload = text.as_bytes().to_vec();
201 if payload.len() > max_msg_size {
202 warn!(
203 size = payload.len(),
204 max = max_msg_size,
205 addr = %addr,
206 "dropping oversized text message"
207 );
208 } else if tx.send(payload).await.is_err() {
209 break;
210 } else {
211 data_ready.notify_one();
212 }
213 }
214 Some(Ok(tungstenite::Message::Binary(data))) => {
215 let payload = data.to_vec();
216 if payload.len() > max_msg_size {
217 warn!(
218 size = payload.len(),
219 max = max_msg_size,
220 addr = %addr,
221 "dropping oversized binary message"
222 );
223 } else if tx.send(payload).await.is_err() {
224 break;
225 } else {
226 data_ready.notify_one();
227 }
228 }
229 Some(Ok(tungstenite::Message::Close(_))) | None => break,
230 Some(Ok(_)) => {} Some(Err(e)) => {
232 debug!(addr = %addr, error = %e, "client read error");
233 break;
234 }
235 }
236 }
237 _ = client_shutdown.changed() => break,
238 }
239 }
240
241 connected.fetch_sub(1, Ordering::Relaxed);
242 debug!(addr = %addr, "client disconnected");
243 });
244 }
245 Err(e) => {
246 warn!(error = %e, "accept error");
247 }
248 }
249 }
250 _ = shutdown_rx.changed() => {
251 info!("acceptor shutting down");
252 break;
253 }
254 }
255 }
256 });
257
258 self.rx = Some(rx);
259 self.shutdown_tx = Some(shutdown_tx);
260 self.acceptor_handle = Some(handle);
261 self.state = ConnectorState::Running;
262
263 info!(bind = %bind_address, "WebSocket source server started");
264 Ok(())
265 }
266
267 #[allow(clippy::cast_possible_truncation)]
268 async fn poll_batch(
269 &mut self,
270 max_records: usize,
271 ) -> Result<Option<SourceBatch>, ConnectorError> {
272 if self.state != ConnectorState::Running {
273 return Err(ConnectorError::InvalidState {
274 expected: "Running".into(),
275 actual: self.state.to_string(),
276 });
277 }
278
279 let rx = self
280 .rx
281 .as_mut()
282 .ok_or_else(|| ConnectorError::InvalidState {
283 expected: "channel initialized".into(),
284 actual: "channel is None".into(),
285 })?;
286
287 let limit = max_records.min(self.max_batch_size);
288
289 while self.message_buffer.len() < limit {
292 match rx.try_recv() {
293 Ok(payload) => {
294 self.metrics.record_message(payload.len() as u64);
295 self.message_buffer.push(payload);
296 }
297 Err(mpsc::error::TryRecvError::Empty) => break,
298 Err(mpsc::error::TryRecvError::Disconnected) => {
299 if self.message_buffer.is_empty() {
300 self.state = ConnectorState::Failed;
301 return Err(ConnectorError::ReadError(
302 "WebSocket source server acceptor terminated".into(),
303 ));
304 }
305 break;
310 }
311 }
312 }
313
314 self.metrics
315 .set_connected_clients(self.connected_clients.load(Ordering::Relaxed));
316
317 if self.message_buffer.is_empty() {
318 return Ok(None);
319 }
320
321 let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
322 let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
323 self.metrics.record_parse_error();
324 })?;
325
326 let num_rows = batch.num_rows();
327 self.message_buffer.clear();
328
329 if let Some(ref field) = self.config.event_time_field {
331 if let Some(max_ts) = super::parser::extract_max_event_time(&batch, field) {
332 self.checkpoint_state.watermark = max_ts;
333 }
334 } else {
335 self.checkpoint_state.watermark = std::time::SystemTime::now()
336 .duration_since(std::time::UNIX_EPOCH)
337 .unwrap_or_default()
338 .as_millis() as i64;
339 }
340
341 debug!(
342 records = num_rows,
343 clients = self.connected_clients(),
344 "polled batch from WebSocket server source"
345 );
346 Ok(Some(SourceBatch::new(batch)))
347 }
348
349 fn schema(&self) -> SchemaRef {
350 self.schema.clone()
351 }
352
353 fn checkpoint(&self) -> SourceCheckpoint {
354 self.checkpoint_state.to_source_checkpoint(0)
355 }
356
357 async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
358 info!(
359 epoch = checkpoint.epoch(),
360 "restoring WebSocket source server from checkpoint (best-effort)"
361 );
362 self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
363 warn!("WebSocket source server restored; data gap expected (non-replayable)");
364 Ok(())
365 }
366
367 fn health_check(&self) -> HealthStatus {
368 match self.state {
369 ConnectorState::Running => HealthStatus::Healthy,
370 ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
371 ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
372 ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
373 ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
374 ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
375 }
376 }
377
378 fn metrics(&self) -> ConnectorMetrics {
379 self.metrics.to_connector_metrics()
380 }
381
382 fn data_ready_notify(&self) -> Option<Arc<Notify>> {
383 Some(Arc::clone(&self.data_ready))
384 }
385
386 fn supports_replay(&self) -> bool {
387 false
388 }
389
390 async fn close(&mut self) -> Result<(), ConnectorError> {
391 info!("closing WebSocket source server");
392
393 if let Some(tx) = self.shutdown_tx.take() {
394 let _ = tx.send(true);
395 }
396
397 if let Some(handle) = self.acceptor_handle.take() {
398 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
399 }
400
401 self.rx = None;
402 self.state = ConnectorState::Closed;
403 info!("WebSocket source server closed");
404 Ok(())
405 }
406}
407
408impl std::fmt::Debug for WebSocketSourceServer {
409 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410 f.debug_struct("WebSocketSourceServer")
411 .field("state", &self.state)
412 .field("connected_clients", &self.connected_clients())
413 .finish_non_exhaustive()
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Two nullable UTF-8 columns: `id` and `value`.
    fn test_schema() -> SchemaRef {
        let fields: Vec<Field> = ["id", "value"]
            .into_iter()
            .map(|name| Field::new(name, DataType::Utf8, true))
            .collect();
        Arc::new(Schema::new(fields))
    }

    /// Server-mode config on an ephemeral loopback port.
    fn test_config() -> WebSocketSourceConfig {
        WebSocketSourceConfig {
            mode: SourceMode::Server {
                bind_address: "127.0.0.1:0".into(),
                max_connections: 100,
                path: None,
            },
            format: MessageFormat::Json,
            on_backpressure: super::super::backpressure::BackpressureStrategy::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    /// A freshly constructed, never-opened server.
    fn new_server() -> WebSocketSourceServer {
        WebSocketSourceServer::new(test_schema(), test_config())
    }

    #[test]
    fn test_new() {
        let server = new_server();
        assert_eq!(server.state(), ConnectorState::Created);
        assert_eq!(server.connected_clients(), 0);
    }

    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let server = WebSocketSourceServer::new(Arc::clone(&expected), test_config());
        assert_eq!(server.schema(), expected);
    }

    #[test]
    fn test_health_created() {
        assert_eq!(new_server().health_check(), HealthStatus::Unknown);
    }
}