laminar_connectors/websocket/
source_server.rs1use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use crossfire::{mpsc, AsyncRx, TryRecvError};
13use futures_util::StreamExt;
14use tokio::net::TcpListener;
15use tokio::sync::Notify;
16use tracing::{debug, info, warn};
17
18use crate::checkpoint::SourceCheckpoint;
19use crate::config::{ConnectorConfig, ConnectorState};
20use crate::connector::{SourceBatch, SourceConnector};
21use crate::error::ConnectorError;
22use crate::health::HealthStatus;
23use crate::metrics::ConnectorMetrics;
24
25use crate::schema::json::decoder::JsonDecoderConfig;
26
27use super::checkpoint::WebSocketSourceCheckpoint;
28use super::metrics::WebSocketSourceMetrics;
29use super::parser::MessageParser;
30use super::source_config::{SourceMode, WebSocketSourceConfig};
31
32pub struct WebSocketSourceServer {
38 config: WebSocketSourceConfig,
40 schema: SchemaRef,
42 parser: MessageParser,
44 state: ConnectorState,
46 metrics: WebSocketSourceMetrics,
48 checkpoint_state: WebSocketSourceCheckpoint,
50 rx: Option<AsyncRx<mpsc::Array<Vec<u8>>>>,
52 shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
54 acceptor_handle: Option<tokio::task::JoinHandle<()>>,
56 message_buffer: Vec<Vec<u8>>,
58 connected_clients: Arc<AtomicU64>,
60 max_batch_size: usize,
62 data_ready: Arc<Notify>,
64}
65
66impl WebSocketSourceServer {
67 #[must_use]
69 pub fn new(
70 schema: SchemaRef,
71 config: WebSocketSourceConfig,
72 registry: Option<&prometheus::Registry>,
73 ) -> Self {
74 let parser = MessageParser::new(
75 schema.clone(),
76 config.format.clone(),
77 JsonDecoderConfig::default(),
78 );
79
80 Self {
81 config,
82 schema,
83 parser,
84 state: ConnectorState::Created,
85 metrics: WebSocketSourceMetrics::new(registry),
86 checkpoint_state: WebSocketSourceCheckpoint::default(),
87 rx: None,
88 shutdown_tx: None,
89 acceptor_handle: None,
90 message_buffer: Vec::new(),
91 connected_clients: Arc::new(AtomicU64::new(0)),
92 max_batch_size: 1000,
93 data_ready: Arc::new(Notify::new()),
94 }
95 }
96
97 #[must_use]
99 pub fn state(&self) -> ConnectorState {
100 self.state
101 }
102
103 #[must_use]
105 pub fn connected_clients(&self) -> u64 {
106 self.connected_clients.load(Ordering::Relaxed)
107 }
108}
109
110#[async_trait]
111#[allow(clippy::too_many_lines)]
112impl SourceConnector for WebSocketSourceServer {
113 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
114 self.state = ConnectorState::Initializing;
115
116 if !config.properties().is_empty() {
118 self.config = WebSocketSourceConfig::from_config(config)?;
119 }
120
121 let (bind_address, max_connections, _path) = match &self.config.mode {
122 SourceMode::Server {
123 bind_address,
124 max_connections,
125 path,
126 } => (bind_address.clone(), *max_connections, path.clone()),
127 SourceMode::Client { .. } => {
128 return Err(ConnectorError::ConfigurationError(
129 "WebSocketSourceServer is for server mode; use WebSocketSource for client mode"
130 .into(),
131 ));
132 }
133 };
134
135 info!(
136 bind = %bind_address,
137 max_connections,
138 format = ?self.config.format,
139 "opening WebSocket source connector (server mode)"
140 );
141
142 let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
143 ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
144 })?;
145
146 let channel_capacity = 10_000;
147 let (tx, rx) = mpsc::bounded_async::<Vec<u8>>(channel_capacity);
148 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
149
150 let connected = Arc::clone(&self.connected_clients);
151 let max_msg_size = self.config.max_message_size;
152 let data_ready = Arc::clone(&self.data_ready);
153
154 let handle = tokio::spawn(async move {
155 let mut shutdown_rx = shutdown_rx;
156
157 loop {
158 tokio::select! {
159 accept_result = listener.accept() => {
160 match accept_result {
161 Ok((stream, addr)) => {
162 let current = connected.load(Ordering::Relaxed);
163 if current >= max_connections as u64 {
164 warn!(
165 current_connections = current,
166 max = max_connections,
167 addr = %addr,
168 "rejecting connection: max_connections exceeded"
169 );
170 drop(stream);
171 continue;
172 }
173
174 let _ = stream.set_nodelay(true);
176
177 let tx = tx.clone();
178 let connected = Arc::clone(&connected);
179 let mut client_shutdown = shutdown_rx.clone();
180 let data_ready = Arc::clone(&data_ready);
181
182 connected.fetch_add(1, Ordering::Relaxed);
183 debug!(addr = %addr, "accepted WebSocket client");
184
185 tokio::spawn(async move {
186 let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
187 ws_config.max_message_size = Some(max_msg_size);
188 ws_config.max_frame_size = Some(max_msg_size);
189 let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
190 Ok(ws) => ws,
191 Err(e) => {
192 warn!(addr = %addr, error = %e, "WebSocket handshake failed");
193 connected.fetch_sub(1, Ordering::Relaxed);
194 return;
195 }
196 };
197
198 let (_write, mut read) = ws_stream.split();
199
200 loop {
201 tokio::select! {
202 msg = read.next() => {
203 match msg {
204 Some(Ok(tungstenite::Message::Text(text))) => {
205 let payload = text.as_bytes().to_vec();
206 if payload.len() > max_msg_size {
207 warn!(
208 size = payload.len(),
209 max = max_msg_size,
210 addr = %addr,
211 "dropping oversized text message"
212 );
213 } else if tx.send(payload).await.is_err() {
214 break;
215 } else {
216 data_ready.notify_one();
217 }
218 }
219 Some(Ok(tungstenite::Message::Binary(data))) => {
220 let payload = data.to_vec();
221 if payload.len() > max_msg_size {
222 warn!(
223 size = payload.len(),
224 max = max_msg_size,
225 addr = %addr,
226 "dropping oversized binary message"
227 );
228 } else if tx.send(payload).await.is_err() {
229 break;
230 } else {
231 data_ready.notify_one();
232 }
233 }
234 Some(Ok(tungstenite::Message::Close(_))) | None => break,
235 Some(Ok(_)) => {} Some(Err(e)) => {
237 debug!(addr = %addr, error = %e, "client read error");
238 break;
239 }
240 }
241 }
242 _ = client_shutdown.changed() => break,
243 }
244 }
245
246 connected.fetch_sub(1, Ordering::Relaxed);
247 debug!(addr = %addr, "client disconnected");
248 });
249 }
250 Err(e) => {
251 warn!(error = %e, "accept error");
252 }
253 }
254 }
255 _ = shutdown_rx.changed() => {
256 info!("acceptor shutting down");
257 break;
258 }
259 }
260 }
261 });
262
263 self.rx = Some(rx);
264 self.shutdown_tx = Some(shutdown_tx);
265 self.acceptor_handle = Some(handle);
266 self.state = ConnectorState::Running;
267
268 info!(bind = %bind_address, "WebSocket source server started");
269 Ok(())
270 }
271
272 #[allow(clippy::cast_possible_truncation)]
273 async fn poll_batch(
274 &mut self,
275 max_records: usize,
276 ) -> Result<Option<SourceBatch>, ConnectorError> {
277 if self.state != ConnectorState::Running {
278 return Err(ConnectorError::InvalidState {
279 expected: "Running".into(),
280 actual: self.state.to_string(),
281 });
282 }
283
284 let rx = self
285 .rx
286 .as_mut()
287 .ok_or_else(|| ConnectorError::InvalidState {
288 expected: "channel initialized".into(),
289 actual: "channel is None".into(),
290 })?;
291
292 let limit = max_records.min(self.max_batch_size);
293
294 while self.message_buffer.len() < limit {
297 match rx.try_recv() {
298 Ok(payload) => {
299 self.metrics.record_message(payload.len() as u64);
300 self.message_buffer.push(payload);
301 }
302 Err(TryRecvError::Empty) => break,
303 Err(TryRecvError::Disconnected) => {
304 if self.message_buffer.is_empty() {
305 self.state = ConnectorState::Failed;
306 return Err(ConnectorError::ReadError(
307 "WebSocket source server acceptor terminated".into(),
308 ));
309 }
310 break;
315 }
316 }
317 }
318
319 self.metrics
320 .set_connected_clients(self.connected_clients.load(Ordering::Relaxed));
321
322 if self.message_buffer.is_empty() {
323 return Ok(None);
324 }
325
326 let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
327 let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
328 self.metrics.record_parse_error();
329 })?;
330
331 let num_rows = batch.num_rows();
332 self.message_buffer.clear();
333
334 if let Some(ref field) = self.config.event_time_field {
336 if let Some(max_ts) = super::parser::extract_max_event_time(&batch, field)? {
337 self.checkpoint_state.watermark = max_ts;
338 }
339 } else {
340 self.checkpoint_state.watermark = std::time::SystemTime::now()
341 .duration_since(std::time::UNIX_EPOCH)
342 .unwrap_or_default()
343 .as_millis() as i64;
344 }
345
346 debug!(
347 records = num_rows,
348 clients = self.connected_clients(),
349 "polled batch from WebSocket server source"
350 );
351 Ok(Some(SourceBatch::new(batch)))
352 }
353
354 fn schema(&self) -> SchemaRef {
355 self.schema.clone()
356 }
357
358 fn checkpoint(&self) -> SourceCheckpoint {
359 self.checkpoint_state.to_source_checkpoint(0)
360 }
361
362 async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
363 info!(
364 epoch = checkpoint.epoch(),
365 "restoring WebSocket source server from checkpoint (best-effort)"
366 );
367 self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
368 warn!("WebSocket source server restored; data gap expected (non-replayable)");
369 Ok(())
370 }
371
372 fn health_check(&self) -> HealthStatus {
373 match self.state {
374 ConnectorState::Running => HealthStatus::Healthy,
375 ConnectorState::Created | ConnectorState::Initializing => HealthStatus::Unknown,
376 ConnectorState::Paused => HealthStatus::Degraded("connector paused".into()),
377 ConnectorState::Recovering => HealthStatus::Degraded("recovering".into()),
378 ConnectorState::Closed => HealthStatus::Unhealthy("closed".into()),
379 ConnectorState::Failed => HealthStatus::Unhealthy("failed".into()),
380 }
381 }
382
383 fn metrics(&self) -> ConnectorMetrics {
384 self.metrics.to_connector_metrics()
385 }
386
387 fn data_ready_notify(&self) -> Option<Arc<Notify>> {
388 Some(Arc::clone(&self.data_ready))
389 }
390
391 fn supports_replay(&self) -> bool {
392 false
393 }
394
395 async fn close(&mut self) -> Result<(), ConnectorError> {
396 info!("closing WebSocket source server");
397
398 if let Some(tx) = self.shutdown_tx.take() {
399 let _ = tx.send(true);
400 }
401
402 if let Some(handle) = self.acceptor_handle.take() {
403 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
404 }
405
406 self.rx = None;
407 self.state = ConnectorState::Closed;
408 info!("WebSocket source server closed");
409 Ok(())
410 }
411}
412
413impl std::fmt::Debug for WebSocketSourceServer {
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415 f.debug_struct("WebSocketSourceServer")
416 .field("state", &self.state)
417 .field("connected_clients", &self.connected_clients())
418 .finish_non_exhaustive()
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::super::backpressure::WsBackpressure;
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Schema with two nullable UTF-8 columns, `id` and `value`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Server-mode JSON configuration bound to an ephemeral local port.
    fn test_config() -> WebSocketSourceConfig {
        let mode = SourceMode::Server {
            bind_address: "127.0.0.1:0".into(),
            max_connections: 100,
            path: None,
        };

        WebSocketSourceConfig {
            mode,
            format: MessageFormat::Json,
            on_backpressure: WsBackpressure::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    /// Fresh, unopened server built from the test fixtures.
    fn new_server(schema: SchemaRef) -> WebSocketSourceServer {
        WebSocketSourceServer::new(schema, test_config(), None)
    }

    /// A new server starts in `Created` with no connected clients.
    #[test]
    fn test_new() {
        let server = new_server(test_schema());
        assert_eq!(server.state(), ConnectorState::Created);
        assert_eq!(server.connected_clients(), 0);
    }

    /// `schema()` hands back the schema passed to the constructor.
    #[test]
    fn test_schema_returned() {
        let schema = test_schema();
        let server = new_server(schema.clone());
        assert_eq!(server.schema(), schema);
    }

    /// Health is `Unknown` before the server has been opened.
    #[test]
    fn test_health_created() {
        let server = new_server(test_schema());
        assert_eq!(server.health_check(), HealthStatus::Unknown);
    }
}