laminar_connectors/websocket/
source_server.rs1use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use crossfire::{mpsc, AsyncRx, TryRecvError};
13use futures_util::StreamExt;
14use tokio::net::TcpListener;
15use tokio::sync::Notify;
16use tracing::{debug, info, warn};
17
18use crate::checkpoint::SourceCheckpoint;
19use crate::config::{ConnectorConfig, ConnectorState};
20use crate::connector::{SourceBatch, SourceConnector};
21use crate::error::ConnectorError;
22
23use crate::schema::json::decoder::JsonDecoderConfig;
24
25use super::checkpoint::WebSocketSourceCheckpoint;
26use super::metrics::WebSocketSourceMetrics;
27use super::parser::MessageParser;
28use super::source_config::{SourceMode, WebSocketSourceConfig};
29
30pub struct WebSocketSourceServer {
36 config: WebSocketSourceConfig,
38 schema: SchemaRef,
40 parser: MessageParser,
42 state: ConnectorState,
44 metrics: WebSocketSourceMetrics,
46 checkpoint_state: WebSocketSourceCheckpoint,
48 rx: Option<AsyncRx<mpsc::Array<Vec<u8>>>>,
50 shutdown_tx: Option<tokio::sync::watch::Sender<bool>>,
52 acceptor_handle: Option<tokio::task::JoinHandle<()>>,
54 message_buffer: Vec<Vec<u8>>,
56 connected_clients: Arc<AtomicU64>,
58 max_batch_size: usize,
60 data_ready: Arc<Notify>,
62}
63
64impl WebSocketSourceServer {
65 #[must_use]
67 pub fn new(
68 schema: SchemaRef,
69 config: WebSocketSourceConfig,
70 registry: Option<&prometheus::Registry>,
71 ) -> Self {
72 let parser = MessageParser::new(
73 schema.clone(),
74 config.format.clone(),
75 JsonDecoderConfig::default(),
76 );
77
78 Self {
79 config,
80 schema,
81 parser,
82 state: ConnectorState::Created,
83 metrics: WebSocketSourceMetrics::new(registry),
84 checkpoint_state: WebSocketSourceCheckpoint::default(),
85 rx: None,
86 shutdown_tx: None,
87 acceptor_handle: None,
88 message_buffer: Vec::new(),
89 connected_clients: Arc::new(AtomicU64::new(0)),
90 max_batch_size: 1000,
91 data_ready: Arc::new(Notify::new()),
92 }
93 }
94
95 #[must_use]
97 pub fn state(&self) -> ConnectorState {
98 self.state
99 }
100
101 #[must_use]
103 pub fn connected_clients(&self) -> u64 {
104 self.connected_clients.load(Ordering::Relaxed)
105 }
106}
107
108#[async_trait]
109#[allow(clippy::too_many_lines)]
110impl SourceConnector for WebSocketSourceServer {
111 async fn open(&mut self, config: &ConnectorConfig) -> Result<(), ConnectorError> {
112 self.state = ConnectorState::Initializing;
113
114 if !config.properties().is_empty() {
116 self.config = WebSocketSourceConfig::from_config(config)?;
117 }
118
119 let (bind_address, max_connections, _path) = match &self.config.mode {
120 SourceMode::Server {
121 bind_address,
122 max_connections,
123 path,
124 } => (bind_address.clone(), *max_connections, path.clone()),
125 SourceMode::Client { .. } => {
126 return Err(ConnectorError::ConfigurationError(
127 "WebSocketSourceServer is for server mode; use WebSocketSource for client mode"
128 .into(),
129 ));
130 }
131 };
132
133 info!(
134 bind = %bind_address,
135 max_connections,
136 format = ?self.config.format,
137 "opening WebSocket source connector (server mode)"
138 );
139
140 let listener = TcpListener::bind(&bind_address).await.map_err(|e| {
141 ConnectorError::ConnectionFailed(format!("failed to bind {bind_address}: {e}"))
142 })?;
143
144 let channel_capacity = 10_000;
145 let (tx, rx) = mpsc::bounded_async::<Vec<u8>>(channel_capacity);
146 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
147
148 let connected = Arc::clone(&self.connected_clients);
149 let max_msg_size = self.config.max_message_size;
150 let data_ready = Arc::clone(&self.data_ready);
151
152 let handle = tokio::spawn(async move {
153 let mut shutdown_rx = shutdown_rx;
154
155 loop {
156 tokio::select! {
157 accept_result = listener.accept() => {
158 match accept_result {
159 Ok((stream, addr)) => {
160 let current = connected.load(Ordering::Relaxed);
161 if current >= max_connections as u64 {
162 warn!(
163 current_connections = current,
164 max = max_connections,
165 addr = %addr,
166 "rejecting connection: max_connections exceeded"
167 );
168 drop(stream);
169 continue;
170 }
171
172 let _ = stream.set_nodelay(true);
174
175 let tx = tx.clone();
176 let connected = Arc::clone(&connected);
177 let mut client_shutdown = shutdown_rx.clone();
178 let data_ready = Arc::clone(&data_ready);
179
180 connected.fetch_add(1, Ordering::Relaxed);
181 debug!(addr = %addr, "accepted WebSocket client");
182
183 tokio::spawn(async move {
184 let mut ws_config = tungstenite::protocol::WebSocketConfig::default();
185 ws_config.max_message_size = Some(max_msg_size);
186 ws_config.max_frame_size = Some(max_msg_size);
187 let ws_stream = match tokio_tungstenite::accept_async_with_config(stream, Some(ws_config)).await {
188 Ok(ws) => ws,
189 Err(e) => {
190 warn!(addr = %addr, error = %e, "WebSocket handshake failed");
191 connected.fetch_sub(1, Ordering::Relaxed);
192 return;
193 }
194 };
195
196 let (_write, mut read) = ws_stream.split();
197
198 loop {
199 tokio::select! {
200 msg = read.next() => {
201 match msg {
202 Some(Ok(tungstenite::Message::Text(text))) => {
203 let payload = text.as_bytes().to_vec();
204 if payload.len() > max_msg_size {
205 warn!(
206 size = payload.len(),
207 max = max_msg_size,
208 addr = %addr,
209 "dropping oversized text message"
210 );
211 } else if tx.send(payload).await.is_err() {
212 break;
213 } else {
214 data_ready.notify_one();
215 }
216 }
217 Some(Ok(tungstenite::Message::Binary(data))) => {
218 let payload = data.to_vec();
219 if payload.len() > max_msg_size {
220 warn!(
221 size = payload.len(),
222 max = max_msg_size,
223 addr = %addr,
224 "dropping oversized binary message"
225 );
226 } else if tx.send(payload).await.is_err() {
227 break;
228 } else {
229 data_ready.notify_one();
230 }
231 }
232 Some(Ok(tungstenite::Message::Close(_))) | None => break,
233 Some(Ok(_)) => {} Some(Err(e)) => {
235 debug!(addr = %addr, error = %e, "client read error");
236 break;
237 }
238 }
239 }
240 _ = client_shutdown.changed() => break,
241 }
242 }
243
244 connected.fetch_sub(1, Ordering::Relaxed);
245 debug!(addr = %addr, "client disconnected");
246 });
247 }
248 Err(e) => {
249 warn!(error = %e, "accept error");
250 }
251 }
252 }
253 _ = shutdown_rx.changed() => {
254 info!("acceptor shutting down");
255 break;
256 }
257 }
258 }
259 });
260
261 self.rx = Some(rx);
262 self.shutdown_tx = Some(shutdown_tx);
263 self.acceptor_handle = Some(handle);
264 self.state = ConnectorState::Running;
265
266 info!(bind = %bind_address, "WebSocket source server started");
267 Ok(())
268 }
269
270 #[allow(clippy::cast_possible_truncation)]
271 async fn poll_batch(
272 &mut self,
273 max_records: usize,
274 ) -> Result<Option<SourceBatch>, ConnectorError> {
275 if self.state != ConnectorState::Running {
276 return Err(ConnectorError::InvalidState {
277 expected: "Running".into(),
278 actual: self.state.to_string(),
279 });
280 }
281
282 let rx = self
283 .rx
284 .as_mut()
285 .ok_or_else(|| ConnectorError::InvalidState {
286 expected: "channel initialized".into(),
287 actual: "channel is None".into(),
288 })?;
289
290 let limit = max_records.min(self.max_batch_size);
291
292 while self.message_buffer.len() < limit {
295 match rx.try_recv() {
296 Ok(payload) => {
297 self.metrics.record_message(payload.len() as u64);
298 self.message_buffer.push(payload);
299 }
300 Err(TryRecvError::Empty) => break,
301 Err(TryRecvError::Disconnected) => {
302 if self.message_buffer.is_empty() {
303 self.state = ConnectorState::Failed;
304 return Err(ConnectorError::ReadError(
305 "WebSocket source server acceptor terminated".into(),
306 ));
307 }
308 break;
313 }
314 }
315 }
316
317 self.metrics
318 .set_connected_clients(self.connected_clients.load(Ordering::Relaxed));
319
320 if self.message_buffer.is_empty() {
321 return Ok(None);
322 }
323
324 let refs: Vec<&[u8]> = self.message_buffer.iter().map(Vec::as_slice).collect();
325 let batch = self.parser.parse_batch(&refs).inspect_err(|_e| {
326 self.metrics.record_parse_error();
327 })?;
328
329 let num_rows = batch.num_rows();
330 self.message_buffer.clear();
331
332 if let Some(ref field) = self.config.event_time_field {
334 if let Some(max_ts) = super::parser::extract_max_event_time(&batch, field)? {
335 self.checkpoint_state.watermark = max_ts;
336 }
337 } else {
338 self.checkpoint_state.watermark = std::time::SystemTime::now()
339 .duration_since(std::time::UNIX_EPOCH)
340 .unwrap_or_default()
341 .as_millis() as i64;
342 }
343
344 debug!(
345 records = num_rows,
346 clients = self.connected_clients(),
347 "polled batch from WebSocket server source"
348 );
349 Ok(Some(SourceBatch::new(batch)))
350 }
351
352 fn schema(&self) -> SchemaRef {
353 self.schema.clone()
354 }
355
356 fn checkpoint(&self) -> SourceCheckpoint {
357 self.checkpoint_state.to_source_checkpoint(0)
358 }
359
360 async fn restore(&mut self, checkpoint: &SourceCheckpoint) -> Result<(), ConnectorError> {
361 info!(
362 epoch = checkpoint.epoch(),
363 "restoring WebSocket source server from checkpoint (best-effort)"
364 );
365 self.checkpoint_state = WebSocketSourceCheckpoint::from_source_checkpoint(checkpoint);
366 warn!("WebSocket source server restored; data gap expected (non-replayable)");
367 Ok(())
368 }
369
370 fn data_ready_notify(&self) -> Option<Arc<Notify>> {
371 Some(Arc::clone(&self.data_ready))
372 }
373
374 fn supports_replay(&self) -> bool {
375 false
376 }
377
378 async fn close(&mut self) -> Result<(), ConnectorError> {
379 info!("closing WebSocket source server");
380
381 if let Some(tx) = self.shutdown_tx.take() {
382 let _ = tx.send(true);
383 }
384
385 if let Some(handle) = self.acceptor_handle.take() {
386 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
387 }
388
389 self.rx = None;
390 self.state = ConnectorState::Closed;
391 info!("WebSocket source server closed");
392 Ok(())
393 }
394}
395
396impl std::fmt::Debug for WebSocketSourceServer {
397 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398 f.debug_struct("WebSocketSourceServer")
399 .field("state", &self.state)
400 .field("connected_clients", &self.connected_clients())
401 .finish_non_exhaustive()
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::super::backpressure::WsBackpressure;
    use super::super::source_config::MessageFormat;
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Schema with two nullable UTF-8 columns, `id` and `value`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("value", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Server-mode JSON configuration bound to an ephemeral local port.
    fn test_config() -> WebSocketSourceConfig {
        let mode = SourceMode::Server {
            bind_address: "127.0.0.1:0".into(),
            max_connections: 100,
            path: None,
        };

        WebSocketSourceConfig {
            mode,
            format: MessageFormat::Json,
            on_backpressure: WsBackpressure::Block,
            event_time_field: None,
            event_time_format: None,
            max_message_size: 64 * 1024 * 1024,
            auth: None,
        }
    }

    /// A freshly built server is `Created` and has no clients.
    #[test]
    fn test_new() {
        let source = WebSocketSourceServer::new(test_schema(), test_config(), None);
        assert_eq!(source.state(), ConnectorState::Created);
        assert_eq!(source.connected_clients(), 0);
    }

    /// `schema()` returns the schema given to the constructor.
    #[test]
    fn test_schema_returned() {
        let expected = test_schema();
        let source = WebSocketSourceServer::new(Arc::clone(&expected), test_config(), None);
        assert_eq!(source.schema(), expected);
    }
}