laminar_connectors/websocket/fanout.rs

//! Per-client fan-out manager for WebSocket sink server mode.
//!
//! Manages per-client state, bounded send buffers, and slow client
//! eviction. Each connected client gets its own ring-buffer channel
//! so that a slow client cannot block or affect other clients.
//!
//! The [`RingSender`]/[`RingReceiver`] pair implements true `DropOldest`
//! semantics: when the buffer is full, the oldest message is evicted
//! to make room for the new one.

use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use bytes::Bytes;
use parking_lot::{Mutex, RwLock};
use tokio::sync::Notify;
use tracing::{debug, warn};

use super::sink_config::SlowClientPolicy;

// ── Ring-buffer channel ─────────────────────────────────────────────

/// Shared state behind a [`RingSender`]/[`RingReceiver`] pair.
///
/// Guarded by a mutex; `closed` is set when either half is dropped so
/// the other side can observe the disconnect.
struct RingInner<T> {
    // Queued messages, oldest at the front.
    buffer: VecDeque<T>,
    // Maximum number of queued messages (always >= 1, see `ring_channel`).
    capacity: usize,
    // True once the sender or receiver has been dropped.
    closed: bool,
}

31/// Sender half of a bounded ring-buffer channel.
32///
33/// When the buffer is full:
34/// - `DropOldest`: evicts the oldest entry and pushes the new one.
35/// - `DropNewest`: discards the new entry.
36///
37/// Uses `parking_lot::Mutex` for the shared buffer (one writer from
38/// `broadcast()`, one reader from the per-client tokio task).
39pub struct RingSender<T> {
40    inner: Arc<Mutex<RingInner<T>>>,
41    notify: Arc<Notify>,
42}
43
44/// Receiver half of a bounded ring-buffer channel.
45pub struct RingReceiver<T> {
46    inner: Arc<Mutex<RingInner<T>>>,
47    notify: Arc<Notify>,
48}
49
/// Result of a send attempt on a ring channel.
///
/// Returned by [`RingSender::send`]; callers use it to update per-client
/// drop counters and to detect disconnected receivers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RingSendResult {
    /// Message was enqueued.
    Sent,
    /// Buffer was full; oldest message evicted to make room.
    Evicted,
    /// Buffer was full; incoming message dropped (`DropNewest` policy).
    Dropped,
    /// Receiver was dropped.
    Closed,
}

62/// Creates a ring-buffer channel pair with the given capacity.
63#[must_use]
64pub fn ring_channel<T>(capacity: usize) -> (RingSender<T>, RingReceiver<T>) {
65    let cap = capacity.max(1);
66    let inner = Arc::new(Mutex::new(RingInner {
67        buffer: VecDeque::with_capacity(cap),
68        capacity: cap,
69        closed: false,
70    }));
71    let notify = Arc::new(Notify::new());
72    (
73        RingSender {
74            inner: Arc::clone(&inner),
75            notify: Arc::clone(&notify),
76        },
77        RingReceiver { inner, notify },
78    )
79}
80
81impl<T> RingSender<T> {
82    /// Sends a value, applying the given policy when the buffer is full.
83    #[must_use]
84    pub fn send(&self, value: T, drop_oldest: bool) -> RingSendResult {
85        let mut guard = self.inner.lock();
86        if guard.closed {
87            return RingSendResult::Closed;
88        }
89        if guard.buffer.len() >= guard.capacity {
90            if drop_oldest {
91                guard.buffer.pop_front();
92                guard.buffer.push_back(value);
93                self.notify.notify_one();
94                return RingSendResult::Evicted;
95            }
96            return RingSendResult::Dropped;
97        }
98        guard.buffer.push_back(value);
99        self.notify.notify_one();
100        RingSendResult::Sent
101    }
102}
103
104impl<T> Drop for RingSender<T> {
105    fn drop(&mut self) {
106        self.inner.lock().closed = true;
107        self.notify.notify_one();
108    }
109}
110
111impl<T> RingReceiver<T> {
112    /// Receives the next value, waiting asynchronously if the buffer is empty.
113    /// Returns `None` if the sender is dropped and the buffer is empty.
114    pub async fn recv(&self) -> Option<T> {
115        loop {
116            {
117                let mut guard = self.inner.lock();
118                if let Some(item) = guard.buffer.pop_front() {
119                    return Some(item);
120                }
121                if guard.closed {
122                    return None;
123                }
124            }
125            self.notify.notified().await;
126        }
127    }
128}
129
130impl<T> Drop for RingReceiver<T> {
131    fn drop(&mut self) {
132        self.inner.lock().closed = true;
133    }
134}
135
// ── Fan-out types ───────────────────────────────────────────────────

/// Unique identifier for a connected WebSocket client.
///
/// Allocated from the fan-out manager's ID counter, starting at 1.
pub type ClientId = u64;

141/// Per-client state within the fan-out manager.
142#[derive(Debug)]
143pub struct ClientState {
144    /// Ring-buffer sender for this client.
145    pub tx: RingSender<Bytes>,
146    /// Client's subscription filter expression (if any).
147    pub filter: Option<String>,
148    /// Subscription ID assigned to this client.
149    pub subscription_id: String,
150    /// Desired output format (reserved for per-client format negotiation).
151    pub format: Option<super::sink_config::SinkFormat>,
152    /// Number of messages dropped for this client.
153    pub messages_dropped: AtomicU64,
154}
155
156impl std::fmt::Debug for RingSender<Bytes> {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        let guard = self.inner.lock();
159        f.debug_struct("RingSender")
160            .field("buffered", &guard.buffer.len())
161            .field("capacity", &guard.capacity)
162            .finish()
163    }
164}
165
166/// Circular replay buffer for client resume support.
167#[derive(Debug)]
168pub struct ReplayBuffer {
169    /// Ring buffer of `(sequence, serialized_message)`.
170    buffer: VecDeque<(u64, Bytes)>,
171    /// Maximum number of entries.
172    max_size: usize,
173}
174
175impl ReplayBuffer {
176    /// Creates a new replay buffer with the given capacity.
177    #[must_use]
178    pub fn new(max_size: usize) -> Self {
179        Self {
180            buffer: VecDeque::with_capacity(max_size.min(10_000)),
181            max_size,
182        }
183    }
184
185    /// Appends a message to the replay buffer, evicting the oldest if full.
186    pub fn push(&mut self, sequence: u64, data: Bytes) {
187        if self.buffer.len() >= self.max_size {
188            self.buffer.pop_front();
189        }
190        self.buffer.push_back((sequence, data));
191    }
192
193    /// Returns all messages with sequence number > `from_sequence`.
194    #[must_use]
195    pub fn replay_from(&self, from_sequence: u64) -> Vec<(u64, Bytes)> {
196        self.buffer
197            .iter()
198            .filter(|(seq, _)| *seq > from_sequence)
199            .cloned()
200            .collect()
201    }
202
203    /// Returns the lowest sequence number in the buffer, if any.
204    #[must_use]
205    pub fn oldest_sequence(&self) -> Option<u64> {
206        self.buffer.front().map(|(seq, _)| *seq)
207    }
208
209    /// Returns the number of entries in the buffer.
210    #[must_use]
211    pub fn len(&self) -> usize {
212        self.buffer.len()
213    }
214
215    /// Returns whether the buffer is empty.
216    #[must_use]
217    pub fn is_empty(&self) -> bool {
218        self.buffer.is_empty()
219    }
220}
221
222/// Fan-out manager that distributes messages to connected WebSocket clients.
223///
224/// Each client gets an independent bounded channel. The fan-out loop
225/// serializes once, then attempts to send to every client. If a client's
226/// channel is full, the configured [`SlowClientPolicy`] is applied.
227pub struct FanoutManager {
228    /// Connected clients keyed by ID.
229    clients: Arc<RwLock<HashMap<ClientId, ClientState>>>,
230    /// Slow client eviction policy.
231    policy: SlowClientPolicy,
232    /// Per-client send buffer capacity (in messages).
233    buffer_capacity: usize,
234    /// Next client ID.
235    next_id: AtomicU64,
236    /// Global sequence counter for messages.
237    sequence: AtomicU64,
238    /// Optional replay buffer.
239    replay_buffer: Option<parking_lot::Mutex<ReplayBuffer>>,
240}
241
242impl FanoutManager {
243    /// Creates a new fan-out manager.
244    ///
245    /// # Arguments
246    ///
247    /// * `policy` - Slow client eviction policy.
248    /// * `buffer_capacity` - Max queued messages per client.
249    /// * `replay_buffer_size` - If `Some`, enables a replay buffer of this size.
250    #[must_use]
251    pub fn new(
252        policy: SlowClientPolicy,
253        buffer_capacity: usize,
254        replay_buffer_size: Option<usize>,
255    ) -> Self {
256        let replay_buffer =
257            replay_buffer_size.map(|size| parking_lot::Mutex::new(ReplayBuffer::new(size)));
258        Self {
259            clients: Arc::new(RwLock::new(HashMap::new())),
260            policy,
261            buffer_capacity: buffer_capacity.max(1),
262            next_id: AtomicU64::new(1),
263            sequence: AtomicU64::new(0),
264            replay_buffer,
265        }
266    }
267
268    /// Registers a new client and returns its ID and receive channel.
269    pub fn add_client(
270        &self,
271        subscription_id: String,
272        filter: Option<String>,
273        format: Option<super::sink_config::SinkFormat>,
274    ) -> (ClientId, RingReceiver<Bytes>) {
275        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
276        let (tx, rx) = ring_channel(self.buffer_capacity);
277
278        let state = ClientState {
279            tx,
280            filter,
281            subscription_id,
282            format,
283            messages_dropped: AtomicU64::new(0),
284        };
285
286        self.clients.write().insert(id, state);
287        debug!(client_id = id, "client registered");
288        (id, rx)
289    }
290
291    /// Removes a client by ID, returning whether it existed.
292    pub fn remove_client(&self, id: ClientId) -> bool {
293        let removed = self.clients.write().remove(&id).is_some();
294        if removed {
295            debug!(client_id = id, "client removed");
296        }
297        removed
298    }
299
300    /// Returns the number of connected clients.
301    #[must_use]
302    pub fn client_count(&self) -> usize {
303        self.clients.read().len()
304    }
305
306    /// Returns the current global sequence number.
307    #[must_use]
308    pub fn current_sequence(&self) -> u64 {
309        self.sequence.load(Ordering::Relaxed)
310    }
311
312    /// Returns a clone of the clients map handle for external use.
313    #[must_use]
314    pub fn clients(&self) -> Arc<RwLock<HashMap<ClientId, ClientState>>> {
315        Arc::clone(&self.clients)
316    }
317
318    /// Broadcasts serialized data to all connected clients.
319    ///
320    /// Applies the slow client policy for clients that can't keep up.
321    /// Returns the number of clients that received the message successfully.
322    #[allow(clippy::needless_pass_by_value)]
323    pub fn broadcast(&self, data: Bytes) -> BroadcastResult {
324        let seq = self.sequence.fetch_add(1, Ordering::Relaxed) + 1;
325
326        // Store in replay buffer if enabled.
327        if let Some(ref replay) = self.replay_buffer {
328            replay.lock().push(seq, data.clone());
329        }
330
331        let clients = self.clients.read();
332        let mut sent = 0u64;
333        let mut dropped = 0u64;
334        let mut disconnected: Vec<ClientId> = Vec::new();
335
336        let drop_oldest = matches!(self.policy, SlowClientPolicy::DropOldest);
337
338        for (&id, state) in clients.iter() {
339            match state.tx.send(data.clone(), drop_oldest) {
340                RingSendResult::Sent => {
341                    sent += 1;
342                }
343                RingSendResult::Evicted => {
344                    state.messages_dropped.fetch_add(1, Ordering::Relaxed);
345                    sent += 1; // new message was enqueued after eviction
346                    dropped += 1; // but the old message was lost
347                }
348                RingSendResult::Dropped => match &self.policy {
349                    SlowClientPolicy::DropNewest | SlowClientPolicy::DropOldest => {
350                        state.messages_dropped.fetch_add(1, Ordering::Relaxed);
351                        dropped += 1;
352                    }
353                    SlowClientPolicy::Disconnect { .. } => {
354                        disconnected.push(id);
355                    }
356                    SlowClientPolicy::WarnThenDisconnect { .. } => {
357                        let total_drops =
358                            state.messages_dropped.fetch_add(1, Ordering::Relaxed) + 1;
359                        if total_drops > self.buffer_capacity as u64 {
360                            disconnected.push(id);
361                        } else {
362                            dropped += 1;
363                        }
364                    }
365                },
366                RingSendResult::Closed => {
367                    disconnected.push(id);
368                }
369            }
370        }
371        drop(clients);
372
373        // Remove disconnected clients outside the read lock.
374        if !disconnected.is_empty() {
375            let mut clients = self.clients.write();
376            for id in &disconnected {
377                clients.remove(id);
378                warn!(client_id = id, "slow/disconnected client removed");
379            }
380        }
381
382        BroadcastResult {
383            sequence: seq,
384            sent,
385            dropped,
386            disconnected: disconnected.len() as u64,
387        }
388    }
389
390    /// Returns replay messages for a client resuming from a given sequence.
391    #[must_use]
392    pub fn replay_from(&self, from_sequence: u64) -> Vec<(u64, Bytes)> {
393        match &self.replay_buffer {
394            Some(replay) => replay.lock().replay_from(from_sequence),
395            None => Vec::new(),
396        }
397    }
398}
399
400impl std::fmt::Debug for FanoutManager {
401    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402        f.debug_struct("FanoutManager")
403            .field("clients", &self.client_count())
404            .field("sequence", &self.current_sequence())
405            .field("buffer_capacity", &self.buffer_capacity)
406            .finish_non_exhaustive()
407    }
408}
409
/// Result of a broadcast operation.
///
/// All counters refer to a single `broadcast()` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct BroadcastResult {
    /// Sequence number assigned to this message.
    pub sequence: u64,
    /// Number of clients that received the message (including those that
    /// received it after an eviction).
    pub sent: u64,
    /// Number of clients where a message was dropped or evicted.
    pub dropped: u64,
    /// Number of clients that were disconnected.
    pub disconnected: u64,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replay_buffer() {
        let mut replay = ReplayBuffer::new(3);
        for (seq, payload) in [(1u64, "a"), (2, "b"), (3, "c")] {
            replay.push(seq, Bytes::from(payload));
        }

        assert_eq!(replay.len(), 3);
        assert_eq!(replay.oldest_sequence(), Some(1));

        let resumed = replay.replay_from(1);
        assert_eq!(resumed.len(), 2);
        assert_eq!(resumed[0].0, 2);
        assert_eq!(resumed[1].0, 3);
    }

    #[test]
    fn test_replay_buffer_eviction() {
        let mut replay = ReplayBuffer::new(2);
        replay.push(1, Bytes::from("a"));
        replay.push(2, Bytes::from("b"));
        // Third push exceeds capacity and evicts sequence 1.
        replay.push(3, Bytes::from("c"));

        assert_eq!(replay.len(), 2);
        assert_eq!(replay.oldest_sequence(), Some(2));

        let resumed = replay.replay_from(0);
        assert_eq!(resumed.len(), 2);
        assert_eq!(resumed[0].0, 2);
    }

    #[test]
    fn test_fanout_add_remove_client() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, None);

        let (first, _rx1) = fanout.add_client("sub1".into(), None, None);
        let (second, _rx2) = fanout.add_client("sub2".into(), None, None);
        assert_eq!(fanout.client_count(), 2);

        assert!(fanout.remove_client(first));
        assert_eq!(fanout.client_count(), 1);

        // Removing the same client twice reports failure.
        assert!(!fanout.remove_client(first));
        assert!(fanout.remove_client(second));
        assert_eq!(fanout.client_count(), 0);
    }

    #[tokio::test]
    async fn test_fanout_broadcast() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, None);
        let (_id1, rx1) = fanout.add_client("sub1".into(), None, None);
        let (_id2, rx2) = fanout.add_client("sub2".into(), None, None);

        let result = fanout.broadcast(Bytes::from("hello"));
        assert_eq!(result.sequence, 1);
        assert_eq!(result.sent, 2);
        assert_eq!(result.dropped, 0);

        assert_eq!(rx1.recv().await.unwrap().as_ref(), b"hello");
        assert_eq!(rx2.recv().await.unwrap().as_ref(), b"hello");
    }

    #[test]
    fn test_fanout_slow_client_drop() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropNewest, 2, None);
        let (_id, _rx) = fanout.add_client("sub1".into(), None, None);

        // Fill the two-slot buffer.
        fanout.broadcast(Bytes::from("a"));
        fanout.broadcast(Bytes::from("b"));

        // Third message is discarded under DropNewest.
        let result = fanout.broadcast(Bytes::from("c"));
        assert_eq!(result.dropped, 1);
    }

    #[test]
    fn test_fanout_disconnect_policy() {
        let fanout =
            FanoutManager::new(SlowClientPolicy::Disconnect { threshold_pct: 80 }, 2, None);
        let (_id, _rx) = fanout.add_client("sub1".into(), None, None);

        fanout.broadcast(Bytes::from("a"));
        fanout.broadcast(Bytes::from("b"));

        // Buffer full → the slow client is disconnected.
        let result = fanout.broadcast(Bytes::from("c"));
        assert_eq!(result.disconnected, 1);
        assert_eq!(fanout.client_count(), 0);
    }

    #[test]
    fn test_fanout_with_replay() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, Some(5));

        for payload in ["a", "b", "c"] {
            fanout.broadcast(Bytes::from(payload));
        }

        let resumed = fanout.replay_from(1);
        assert_eq!(resumed.len(), 2);
        assert_eq!(resumed[0].1.as_ref(), b"b");
        assert_eq!(resumed[1].1.as_ref(), b"c");
    }

    #[tokio::test]
    async fn test_fanout_closed_client_removed() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, None);
        let (_id, rx) = fanout.add_client("sub1".into(), None, None);

        // Dropping the receiver simulates a client disconnect.
        drop(rx);

        let result = fanout.broadcast(Bytes::from("hello"));
        assert_eq!(result.disconnected, 1);
        assert_eq!(fanout.client_count(), 0);
    }

    #[tokio::test]
    async fn test_fanout_drop_oldest_evicts() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 2, None);
        let (_id, rx) = fanout.add_client("sub1".into(), None, None);

        // Fill the two-slot buffer.
        fanout.broadcast(Bytes::from("a"));
        fanout.broadcast(Bytes::from("b"));

        // Third broadcast evicts "a" and enqueues "c".
        let result = fanout.broadcast(Bytes::from("c"));
        assert_eq!(result.sent, 1); // new message was enqueued
        assert_eq!(result.dropped, 1); // old message was evicted

        // The receiver sees "b" then "c" — never "a".
        assert_eq!(rx.recv().await.unwrap().as_ref(), b"b");
        assert_eq!(rx.recv().await.unwrap().as_ref(), b"c");
    }

    #[tokio::test]
    async fn test_ring_channel_basic() {
        let (tx, rx) = ring_channel::<Bytes>(4);
        let _ = tx.send(Bytes::from("hello"), false);
        assert_eq!(rx.recv().await.unwrap().as_ref(), b"hello");
    }

    #[tokio::test]
    async fn test_ring_channel_sender_dropped() {
        let (tx, rx) = ring_channel::<Bytes>(4);
        let _ = tx.send(Bytes::from("last"), false);
        drop(tx);

        // The buffered message is still delivered; then the channel
        // reports closed.
        assert_eq!(rx.recv().await.unwrap().as_ref(), b"last");
        assert!(rx.recv().await.is_none());
    }

    #[test]
    fn test_fanout_sequence_increments() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, None);

        let sequences: Vec<u64> = ["a", "b", "c"]
            .into_iter()
            .map(|payload| fanout.broadcast(Bytes::from(payload)).sequence)
            .collect();

        assert_eq!(sequences, vec![1, 2, 3]);
    }
}