1use std::collections::HashMap;
12use std::collections::VecDeque;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15
16use bytes::Bytes;
17use parking_lot::{Mutex, RwLock};
18use tokio::sync::Notify;
19use tracing::{debug, warn};
20
21use super::sink_config::SlowClientPolicy;
22
/// Shared state behind a ring channel: a bounded FIFO plus a closed flag.
/// Both endpoints hold it behind an `Arc<Mutex<_>>`.
struct RingInner<T> {
    /// FIFO storage; `send` keeps its length at or below `capacity`.
    buffer: VecDeque<T>,
    /// Maximum number of buffered elements (clamped to >= 1 by `ring_channel`).
    capacity: usize,
    /// Set when either endpoint is dropped; sends then fail and drained
    /// receives return `None`.
    closed: bool,
}
30
/// Sending half of a bounded ring channel created by [`ring_channel`].
///
/// Dropping the sender marks the channel closed and wakes a pending receiver.
pub struct RingSender<T> {
    /// State shared with the paired [`RingReceiver`].
    inner: Arc<Mutex<RingInner<T>>>,
    /// Wakes the receiver after a successful push (or on close).
    notify: Arc<Notify>,
}
43
/// Receiving half of a bounded ring channel created by [`ring_channel`].
///
/// Dropping the receiver marks the channel closed, so subsequent sends fail.
pub struct RingReceiver<T> {
    /// State shared with the paired [`RingSender`].
    inner: Arc<Mutex<RingInner<T>>>,
    /// Awaited in `recv` while the buffer is empty.
    notify: Arc<Notify>,
}
49
/// Outcome of a [`RingSender::send`] call.
pub enum RingSendResult {
    /// The value was enqueued; the buffer had room.
    Sent,
    /// The buffer was full; the oldest element was evicted to make room
    /// and the new value was enqueued.
    Evicted,
    /// The buffer was full and the new value was discarded.
    Dropped,
    /// The channel is closed; the value was discarded.
    Closed,
}
61
62#[must_use]
64pub fn ring_channel<T>(capacity: usize) -> (RingSender<T>, RingReceiver<T>) {
65 let cap = capacity.max(1);
66 let inner = Arc::new(Mutex::new(RingInner {
67 buffer: VecDeque::with_capacity(cap),
68 capacity: cap,
69 closed: false,
70 }));
71 let notify = Arc::new(Notify::new());
72 (
73 RingSender {
74 inner: Arc::clone(&inner),
75 notify: Arc::clone(¬ify),
76 },
77 RingReceiver { inner, notify },
78 )
79}
80
81impl<T> RingSender<T> {
82 #[must_use]
84 pub fn send(&self, value: T, drop_oldest: bool) -> RingSendResult {
85 let mut guard = self.inner.lock();
86 if guard.closed {
87 return RingSendResult::Closed;
88 }
89 if guard.buffer.len() >= guard.capacity {
90 if drop_oldest {
91 guard.buffer.pop_front();
92 guard.buffer.push_back(value);
93 self.notify.notify_one();
94 return RingSendResult::Evicted;
95 }
96 return RingSendResult::Dropped;
97 }
98 guard.buffer.push_back(value);
99 self.notify.notify_one();
100 RingSendResult::Sent
101 }
102}
103
104impl<T> Drop for RingSender<T> {
105 fn drop(&mut self) {
106 self.inner.lock().closed = true;
107 self.notify.notify_one();
108 }
109}
110
111impl<T> RingReceiver<T> {
112 pub async fn recv(&self) -> Option<T> {
115 loop {
116 {
117 let mut guard = self.inner.lock();
118 if let Some(item) = guard.buffer.pop_front() {
119 return Some(item);
120 }
121 if guard.closed {
122 return None;
123 }
124 }
125 self.notify.notified().await;
126 }
127 }
128}
129
130impl<T> Drop for RingReceiver<T> {
131 fn drop(&mut self) {
132 self.inner.lock().closed = true;
133 }
134}
135
136pub type ClientId = u64;
140
/// Per-client fanout state held by [`FanoutManager`].
#[derive(Debug)]
pub struct ClientState {
    /// Sending half of this client's ring channel.
    pub tx: RingSender<Bytes>,
    /// Optional filter expression; not interpreted in this module —
    /// presumably applied by callers. TODO(review): confirm semantics.
    pub filter: Option<String>,
    /// Identifier of the subscription this client belongs to.
    pub subscription_id: String,
    /// Optional per-client output format.
    pub format: Option<super::sink_config::SinkFormat>,
    /// Running count of messages dropped/evicted for this client,
    /// maintained by `FanoutManager::broadcast`.
    pub messages_dropped: AtomicU64,
}
155
156impl std::fmt::Debug for RingSender<Bytes> {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 let guard = self.inner.lock();
159 f.debug_struct("RingSender")
160 .field("buffered", &guard.buffer.len())
161 .field("capacity", &guard.capacity)
162 .finish()
163 }
164}
165
/// Bounded history of broadcast payloads keyed by sequence number, used to
/// replay recent traffic (see `FanoutManager::replay_from`).
#[derive(Debug)]
pub struct ReplayBuffer {
    /// `(sequence, payload)` pairs in insertion order.
    buffer: VecDeque<(u64, Bytes)>,
    /// Maximum retained entries; the oldest is evicted beyond this.
    max_size: usize,
}
174
175impl ReplayBuffer {
176 #[must_use]
178 pub fn new(max_size: usize) -> Self {
179 Self {
180 buffer: VecDeque::with_capacity(max_size.min(10_000)),
181 max_size,
182 }
183 }
184
185 pub fn push(&mut self, sequence: u64, data: Bytes) {
187 if self.buffer.len() >= self.max_size {
188 self.buffer.pop_front();
189 }
190 self.buffer.push_back((sequence, data));
191 }
192
193 #[must_use]
195 pub fn replay_from(&self, from_sequence: u64) -> Vec<(u64, Bytes)> {
196 self.buffer
197 .iter()
198 .filter(|(seq, _)| *seq > from_sequence)
199 .cloned()
200 .collect()
201 }
202
203 #[must_use]
205 pub fn oldest_sequence(&self) -> Option<u64> {
206 self.buffer.front().map(|(seq, _)| *seq)
207 }
208
209 #[must_use]
211 pub fn len(&self) -> usize {
212 self.buffer.len()
213 }
214
215 #[must_use]
217 pub fn is_empty(&self) -> bool {
218 self.buffer.is_empty()
219 }
220}
221
/// Fans broadcast payloads out to many clients, each behind its own bounded
/// ring channel, applying a configurable slow-client policy.
pub struct FanoutManager {
    /// Registered clients keyed by id; shared via `clients()`.
    clients: Arc<RwLock<HashMap<ClientId, ClientState>>>,
    /// What to do when a client's ring buffer is full.
    policy: SlowClientPolicy,
    /// Per-client ring buffer capacity (clamped to >= 1 in `new`).
    buffer_capacity: usize,
    /// Next client id to hand out (starts at 1).
    next_id: AtomicU64,
    /// Monotonic broadcast sequence counter (0 before any broadcast).
    sequence: AtomicU64,
    /// Optional bounded history of recent broadcasts for replay.
    replay_buffer: Option<parking_lot::Mutex<ReplayBuffer>>,
}
241
242impl FanoutManager {
243 #[must_use]
251 pub fn new(
252 policy: SlowClientPolicy,
253 buffer_capacity: usize,
254 replay_buffer_size: Option<usize>,
255 ) -> Self {
256 let replay_buffer =
257 replay_buffer_size.map(|size| parking_lot::Mutex::new(ReplayBuffer::new(size)));
258 Self {
259 clients: Arc::new(RwLock::new(HashMap::new())),
260 policy,
261 buffer_capacity: buffer_capacity.max(1),
262 next_id: AtomicU64::new(1),
263 sequence: AtomicU64::new(0),
264 replay_buffer,
265 }
266 }
267
268 pub fn add_client(
270 &self,
271 subscription_id: String,
272 filter: Option<String>,
273 format: Option<super::sink_config::SinkFormat>,
274 ) -> (ClientId, RingReceiver<Bytes>) {
275 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
276 let (tx, rx) = ring_channel(self.buffer_capacity);
277
278 let state = ClientState {
279 tx,
280 filter,
281 subscription_id,
282 format,
283 messages_dropped: AtomicU64::new(0),
284 };
285
286 self.clients.write().insert(id, state);
287 debug!(client_id = id, "client registered");
288 (id, rx)
289 }
290
291 pub fn remove_client(&self, id: ClientId) -> bool {
293 let removed = self.clients.write().remove(&id).is_some();
294 if removed {
295 debug!(client_id = id, "client removed");
296 }
297 removed
298 }
299
300 #[must_use]
302 pub fn client_count(&self) -> usize {
303 self.clients.read().len()
304 }
305
306 #[must_use]
308 pub fn current_sequence(&self) -> u64 {
309 self.sequence.load(Ordering::Relaxed)
310 }
311
312 #[must_use]
314 pub fn clients(&self) -> Arc<RwLock<HashMap<ClientId, ClientState>>> {
315 Arc::clone(&self.clients)
316 }
317
318 #[allow(clippy::needless_pass_by_value)]
323 pub fn broadcast(&self, data: Bytes) -> BroadcastResult {
324 let seq = self.sequence.fetch_add(1, Ordering::Relaxed) + 1;
325
326 if let Some(ref replay) = self.replay_buffer {
328 replay.lock().push(seq, data.clone());
329 }
330
331 let clients = self.clients.read();
332 let mut sent = 0u64;
333 let mut dropped = 0u64;
334 let mut disconnected: Vec<ClientId> = Vec::new();
335
336 let drop_oldest = matches!(self.policy, SlowClientPolicy::DropOldest);
337
338 for (&id, state) in clients.iter() {
339 match state.tx.send(data.clone(), drop_oldest) {
340 RingSendResult::Sent => {
341 sent += 1;
342 }
343 RingSendResult::Evicted => {
344 state.messages_dropped.fetch_add(1, Ordering::Relaxed);
345 sent += 1; dropped += 1; }
348 RingSendResult::Dropped => match &self.policy {
349 SlowClientPolicy::DropNewest | SlowClientPolicy::DropOldest => {
350 state.messages_dropped.fetch_add(1, Ordering::Relaxed);
351 dropped += 1;
352 }
353 SlowClientPolicy::Disconnect { .. } => {
354 disconnected.push(id);
355 }
356 SlowClientPolicy::WarnThenDisconnect { .. } => {
357 let total_drops =
358 state.messages_dropped.fetch_add(1, Ordering::Relaxed) + 1;
359 if total_drops > self.buffer_capacity as u64 {
360 disconnected.push(id);
361 } else {
362 dropped += 1;
363 }
364 }
365 },
366 RingSendResult::Closed => {
367 disconnected.push(id);
368 }
369 }
370 }
371 drop(clients);
372
373 if !disconnected.is_empty() {
375 let mut clients = self.clients.write();
376 for id in &disconnected {
377 clients.remove(id);
378 warn!(client_id = id, "slow/disconnected client removed");
379 }
380 }
381
382 BroadcastResult {
383 sequence: seq,
384 sent,
385 dropped,
386 disconnected: disconnected.len() as u64,
387 }
388 }
389
390 #[must_use]
392 pub fn replay_from(&self, from_sequence: u64) -> Vec<(u64, Bytes)> {
393 match &self.replay_buffer {
394 Some(replay) => replay.lock().replay_from(from_sequence),
395 None => Vec::new(),
396 }
397 }
398}
399
400impl std::fmt::Debug for FanoutManager {
401 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402 f.debug_struct("FanoutManager")
403 .field("clients", &self.client_count())
404 .field("sequence", &self.current_sequence())
405 .field("buffer_capacity", &self.buffer_capacity)
406 .finish_non_exhaustive()
407 }
408}
409
/// Summary of a single [`FanoutManager::broadcast`] call.
#[derive(Debug, Clone)]
pub struct BroadcastResult {
    /// Sequence number assigned to this broadcast (starts at 1).
    pub sequence: u64,
    /// Clients that received the payload (includes evicted deliveries).
    pub sent: u64,
    /// Messages dropped across all clients (includes evictions).
    pub dropped: u64,
    /// Clients removed during this broadcast.
    pub disconnected: u64,
}
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replay_buffer() {
        // Three pushes within capacity: everything is retained.
        let mut history = ReplayBuffer::new(3);
        history.push(1, Bytes::from("a"));
        history.push(2, Bytes::from("b"));
        history.push(3, Bytes::from("c"));

        assert_eq!(history.len(), 3);
        assert_eq!(history.oldest_sequence(), Some(1));

        // Replay is exclusive of the starting sequence.
        let replayed = history.replay_from(1);
        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].0, 2);
        assert_eq!(replayed[1].0, 3);
    }

    #[test]
    fn test_replay_buffer_eviction() {
        let mut history = ReplayBuffer::new(2);
        history.push(1, Bytes::from("a"));
        history.push(2, Bytes::from("b"));
        history.push(3, Bytes::from("c"));

        // Capacity 2: the first entry was evicted.
        assert_eq!(history.len(), 2);
        assert_eq!(history.oldest_sequence(), Some(2));

        let replayed = history.replay_from(0);
        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].0, 2);
    }

    #[test]
    fn test_fanout_add_remove_client() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, None);

        let (first, _rx1) = fanout.add_client("sub1".into(), None, None);
        let (second, _rx2) = fanout.add_client("sub2".into(), None, None);
        assert_eq!(fanout.client_count(), 2);

        assert!(fanout.remove_client(first));
        assert_eq!(fanout.client_count(), 1);

        // Removing the same id twice fails the second time.
        assert!(!fanout.remove_client(first));
        assert!(fanout.remove_client(second));
        assert_eq!(fanout.client_count(), 0);
    }

    #[tokio::test]
    async fn test_fanout_broadcast() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, None);
        let (_id1, rx1) = fanout.add_client("sub1".into(), None, None);
        let (_id2, rx2) = fanout.add_client("sub2".into(), None, None);

        let outcome = fanout.broadcast(Bytes::from("hello"));
        assert_eq!(outcome.sent, 2);
        assert_eq!(outcome.dropped, 0);
        assert_eq!(outcome.sequence, 1);

        assert_eq!(rx1.recv().await.unwrap().as_ref(), b"hello");
        assert_eq!(rx2.recv().await.unwrap().as_ref(), b"hello");
    }

    #[test]
    fn test_fanout_slow_client_drop() {
        // Capacity 2 with DropNewest: the third broadcast is rejected.
        let fanout = FanoutManager::new(SlowClientPolicy::DropNewest, 2, None);
        let (_id, _rx) = fanout.add_client("sub1".into(), None, None);

        fanout.broadcast(Bytes::from("a"));
        fanout.broadcast(Bytes::from("b"));

        let outcome = fanout.broadcast(Bytes::from("c"));
        assert_eq!(outcome.dropped, 1);
    }

    #[test]
    fn test_fanout_disconnect_policy() {
        let fanout =
            FanoutManager::new(SlowClientPolicy::Disconnect { threshold_pct: 80 }, 2, None);
        let (_id, _rx) = fanout.add_client("sub1".into(), None, None);

        fanout.broadcast(Bytes::from("a"));
        fanout.broadcast(Bytes::from("b"));

        // Buffer full: the Disconnect policy removes the client.
        let outcome = fanout.broadcast(Bytes::from("c"));
        assert_eq!(outcome.disconnected, 1);
        assert_eq!(fanout.client_count(), 0);
    }

    #[test]
    fn test_fanout_with_replay() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, Some(5));

        fanout.broadcast(Bytes::from("a"));
        fanout.broadcast(Bytes::from("b"));
        fanout.broadcast(Bytes::from("c"));

        let replayed = fanout.replay_from(1);
        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].1.as_ref(), b"b");
        assert_eq!(replayed[1].1.as_ref(), b"c");
    }

    #[tokio::test]
    async fn test_fanout_closed_client_removed() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, None);
        let (_id, rx) = fanout.add_client("sub1".into(), None, None);

        // Dropping the receiver closes the channel from the client side.
        drop(rx);

        let outcome = fanout.broadcast(Bytes::from("hello"));
        assert_eq!(outcome.disconnected, 1);
        assert_eq!(fanout.client_count(), 0);
    }

    #[tokio::test]
    async fn test_fanout_drop_oldest_evicts() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 2, None);
        let (_id, rx) = fanout.add_client("sub1".into(), None, None);

        fanout.broadcast(Bytes::from("a"));
        fanout.broadcast(Bytes::from("b"));

        // Full buffer + DropOldest: "a" is evicted, "c" still delivered,
        // and the eviction counts as both sent and dropped.
        let outcome = fanout.broadcast(Bytes::from("c"));
        assert_eq!(outcome.sent, 1);
        assert_eq!(outcome.dropped, 1);

        assert_eq!(rx.recv().await.unwrap().as_ref(), b"b");
        assert_eq!(rx.recv().await.unwrap().as_ref(), b"c");
    }

    #[tokio::test]
    async fn test_ring_channel_basic() {
        let (tx, rx) = ring_channel::<Bytes>(4);
        let _ = tx.send(Bytes::from("hello"), false);
        assert_eq!(rx.recv().await.unwrap().as_ref(), b"hello");
    }

    #[tokio::test]
    async fn test_ring_channel_sender_dropped() {
        let (tx, rx) = ring_channel::<Bytes>(4);
        let _ = tx.send(Bytes::from("last"), false);
        drop(tx);

        // Buffered items drain first, then the channel reports closed.
        assert_eq!(rx.recv().await.unwrap().as_ref(), b"last");
        assert!(rx.recv().await.is_none());
    }

    #[test]
    fn test_fanout_sequence_increments() {
        let fanout = FanoutManager::new(SlowClientPolicy::DropOldest, 10, None);

        let first = fanout.broadcast(Bytes::from("a"));
        let second = fanout.broadcast(Bytes::from("b"));
        let third = fanout.broadcast(Bytes::from("c"));

        assert_eq!(first.sequence, 1);
        assert_eq!(second.sequence, 2);
        assert_eq!(third.sequence, 3);
    }
}