1use parking_lot::RwLock;
28use rustc_hash::FxHashMap;
29use std::sync::atomic::{AtomicU64, Ordering};
30use std::time::{Duration, Instant};
31
32use tokio::sync::broadcast;
33
34use crate::subscription::event::ChangeEvent;
35
/// Identifier of a single subscription tracked by the registry.
///
/// A thin, `Copy` wrapper around the raw `u64` value. The registry in this
/// file allocates these from an atomic counter that starts at 1.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct SubscriptionId(pub u64);
45
46impl std::fmt::Display for SubscriptionId {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 write!(f, "sub-{}", self.0)
49 }
50}
51
/// Lifecycle state recorded for a subscription.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SubscriptionState {
    /// State given to every newly created subscription. Only subscriptions
    /// in this state have their sender returned by
    /// `SubscriptionRegistry::get_senders_for_source`.
    Active,
    /// Set by `SubscriptionRegistry::pause`; reverted to `Active` by
    /// `SubscriptionRegistry::resume`.
    Paused,
    /// Never assigned in this file: `SubscriptionRegistry::cancel` removes
    /// the entry instead of marking it.
    Cancelled,
}
66
/// Strategy selected for a subscription whose consumer falls behind.
///
/// The registry in this file only stores the chosen value inside
/// `SubscriptionConfig`; it never acts on it.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names alone — confirm against the code that delivers events.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackpressureStrategy {
    /// Presumably discards the oldest buffered events first.
    DropOldest,
    /// Presumably discards newly arriving events.
    DropNewest,
    /// Presumably makes the producer wait for the consumer.
    Block,
    /// Presumably delivers a sample of events; the meaning of the `usize`
    /// payload is not visible in this file.
    Sample(usize),
}
83
84#[derive(Debug, Clone)]
90pub struct SubscriptionConfig {
91 pub buffer_size: usize,
93 pub backpressure: BackpressureStrategy,
95 pub filter: Option<String>,
97 pub send_snapshot: bool,
99 pub max_batch_size: usize,
101 pub max_batch_delay_us: u64,
103}
104
105impl Default for SubscriptionConfig {
106 fn default() -> Self {
107 Self {
108 buffer_size: 1024,
109 backpressure: BackpressureStrategy::DropOldest,
110 filter: None,
111 send_snapshot: false,
112 max_batch_size: 64,
113 max_batch_delay_us: 100,
114 }
115 }
116}
117
118#[derive(Debug)]
124pub struct SubscriptionEntry {
125 pub id: SubscriptionId,
127 pub source_name: String,
129 pub source_id: u32,
131 pub state: SubscriptionState,
133 pub config: SubscriptionConfig,
135 pub sender: broadcast::Sender<ChangeEvent>,
137 pub created_at: Instant,
139 pub events_delivered: AtomicU64,
141 pub events_dropped: AtomicU64,
143 pub current_lag: u64,
145}
146
147#[derive(Debug, Clone)]
153pub struct SubscriptionMetrics {
154 pub id: SubscriptionId,
156 pub source_name: String,
158 pub state: SubscriptionState,
160 pub events_delivered: u64,
162 pub events_dropped: u64,
164 pub current_lag: u64,
166 pub age: Duration,
168}
169
170pub struct SubscriptionRegistry {
191 subscriptions: RwLock<FxHashMap<SubscriptionId, SubscriptionEntry>>,
193 by_source: RwLock<FxHashMap<u32, Vec<SubscriptionId>>>,
195 by_name: RwLock<FxHashMap<String, Vec<SubscriptionId>>>,
197 next_id: AtomicU64,
199 version: AtomicU64,
202}
203
204#[allow(clippy::missing_panics_doc)] impl SubscriptionRegistry {
206 #[must_use]
208 pub fn new() -> Self {
209 Self {
210 subscriptions: RwLock::new(FxHashMap::default()),
211 by_source: RwLock::new(FxHashMap::default()),
212 by_name: RwLock::new(FxHashMap::default()),
213 next_id: AtomicU64::new(1),
214 version: AtomicU64::new(0),
215 }
216 }
217
218 pub fn create(
229 &self,
230 source_name: String,
231 source_id: u32,
232 config: SubscriptionConfig,
233 ) -> (SubscriptionId, broadcast::Receiver<ChangeEvent>) {
234 let id = SubscriptionId(self.next_id.fetch_add(1, Ordering::Relaxed));
235 let (tx, rx) = broadcast::channel(config.buffer_size);
236
237 let entry = SubscriptionEntry {
238 id,
239 source_name: source_name.clone(),
240 source_id,
241 state: SubscriptionState::Active,
242 config,
243 sender: tx,
244 created_at: Instant::now(),
245 events_delivered: AtomicU64::new(0),
246 events_dropped: AtomicU64::new(0),
247 current_lag: 0,
248 };
249
250 self.subscriptions.write().insert(id, entry);
252
253 self.by_source
255 .write()
256 .entry(source_id)
257 .or_default()
258 .push(id);
259
260 self.by_name
262 .write()
263 .entry(source_name)
264 .or_default()
265 .push(id);
266
267 self.version.fetch_add(1, Ordering::Release);
268 (id, rx)
269 }
270
271 pub fn pause(&self, id: SubscriptionId) -> bool {
276 let mut subs = self.subscriptions.write();
277 if let Some(entry) = subs.get_mut(&id) {
278 if entry.state == SubscriptionState::Active {
279 entry.state = SubscriptionState::Paused;
280 return true;
281 }
282 }
283 false
284 }
285
286 pub fn resume(&self, id: SubscriptionId) -> bool {
291 let mut subs = self.subscriptions.write();
292 if let Some(entry) = subs.get_mut(&id) {
293 if entry.state == SubscriptionState::Paused {
294 entry.state = SubscriptionState::Active;
295 return true;
296 }
297 }
298 false
299 }
300
301 pub fn cancel(&self, id: SubscriptionId) -> bool {
305 let entry = self.subscriptions.write().remove(&id);
306
307 if let Some(entry) = entry {
308 if let Some(ids) = self.by_source.write().get_mut(&entry.source_id) {
310 ids.retain(|&i| i != id);
311 }
312
313 if let Some(ids) = self.by_name.write().get_mut(&entry.source_name) {
315 ids.retain(|&i| i != id);
316 }
317
318 self.version.fetch_add(1, Ordering::Release);
319 true
320 } else {
321 false
322 }
323 }
324
325 #[must_use]
330 pub fn get_senders_for_source(&self, source_id: u32) -> Vec<broadcast::Sender<ChangeEvent>> {
331 let by_source = self.by_source.read();
332 let Some(ids) = by_source.get(&source_id) else {
333 return Vec::new();
334 };
335
336 let subs = self.subscriptions.read();
337 ids.iter()
338 .filter_map(|id| {
339 subs.get(id).and_then(|entry| {
340 if entry.state == SubscriptionState::Active {
341 Some(entry.sender.clone())
342 } else {
343 None
344 }
345 })
346 })
347 .collect()
348 }
349
350 #[must_use]
352 pub fn get_subscriptions_by_name(&self, name: &str) -> Vec<SubscriptionId> {
353 let by_name = self.by_name.read();
354 by_name.get(name).cloned().unwrap_or_default()
355 }
356
357 #[must_use]
362 pub fn version(&self) -> u64 {
363 self.version.load(Ordering::Acquire)
364 }
365
366 #[must_use]
368 pub fn subscription_count(&self) -> usize {
369 self.subscriptions.read().len()
370 }
371
372 #[must_use]
374 pub fn active_count(&self) -> usize {
375 self.subscriptions
376 .read()
377 .values()
378 .filter(|e| e.state == SubscriptionState::Active)
379 .count()
380 }
381
382 #[must_use]
384 pub fn metrics(&self, id: SubscriptionId) -> Option<SubscriptionMetrics> {
385 let subs = self.subscriptions.read();
386 subs.get(&id).map(|entry| SubscriptionMetrics {
387 id: entry.id,
388 source_name: entry.source_name.clone(),
389 state: entry.state,
390 events_delivered: entry.events_delivered.load(Ordering::Relaxed),
391 events_dropped: entry.events_dropped.load(Ordering::Relaxed),
392 current_lag: entry.current_lag,
393 age: entry.created_at.elapsed(),
394 })
395 }
396
397 #[must_use]
399 pub fn state(&self, id: SubscriptionId) -> Option<SubscriptionState> {
400 self.subscriptions.read().get(&id).map(|e| e.state)
401 }
402
403 pub fn record_delivery(&self, id: SubscriptionId, count: u64) {
407 if let Some(entry) = self.subscriptions.read().get(&id) {
408 entry.events_delivered.fetch_add(count, Ordering::Relaxed);
409 }
410 }
411
412 pub fn record_drop(&self, id: SubscriptionId, count: u64) {
416 if let Some(entry) = self.subscriptions.read().get(&id) {
417 entry.events_dropped.fetch_add(count, Ordering::Relaxed);
418 }
419 }
420}
421
422impl Default for SubscriptionRegistry {
423 fn default() -> Self {
424 Self::new()
425 }
426}
427
#[cfg(test)]
#[allow(clippy::cast_possible_wrap)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    /// Builds a record batch with one non-nullable `Int64` column `v`
    /// holding the values `0..n`.
    fn make_batch(n: usize) -> arrow_array::RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
        let values: Vec<i64> = (0..n as i64).collect();
        let array = Int64Array::from(values);
        arrow_array::RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
    }

    #[test]
    fn test_registry_config_default() {
        let cfg = SubscriptionConfig::default();
        assert_eq!(cfg.buffer_size, 1024);
        assert_eq!(cfg.backpressure, BackpressureStrategy::DropOldest);
        assert!(cfg.filter.is_none());
        assert!(!cfg.send_snapshot);
        assert_eq!(cfg.max_batch_size, 64);
        assert_eq!(cfg.max_batch_delay_us, 100);
    }

    #[test]
    fn test_registry_create() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());
        assert_eq!(id.0, 1);
        assert_eq!(reg.subscription_count(), 1);
        assert_eq!(reg.active_count(), 1);
    }

    #[test]
    fn test_registry_create_multiple() {
        let reg = SubscriptionRegistry::new();
        let (id1, _rx1) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());
        let (id2, _rx2) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());
        let (id3, _rx3) = reg.create("mv_trades".into(), 1, SubscriptionConfig::default());

        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_eq!(reg.subscription_count(), 3);

        let senders_0 = reg.get_senders_for_source(0);
        assert_eq!(senders_0.len(), 2);
        let senders_1 = reg.get_senders_for_source(1);
        assert_eq!(senders_1.len(), 1);
    }

    #[test]
    fn test_registry_pause_resume() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());

        assert!(reg.pause(id));
        assert_eq!(reg.state(id), Some(SubscriptionState::Paused));
        assert_eq!(reg.active_count(), 0);

        // Pausing an already paused subscription is a no-op.
        assert!(!reg.pause(id));

        assert!(reg.resume(id));
        assert_eq!(reg.state(id), Some(SubscriptionState::Active));
        assert_eq!(reg.active_count(), 1);

        // Resuming an already active subscription is a no-op.
        assert!(!reg.resume(id));
    }

    #[test]
    fn test_registry_cancel() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());
        assert_eq!(reg.subscription_count(), 1);

        assert!(reg.cancel(id));
        assert_eq!(reg.subscription_count(), 0);
        assert_eq!(reg.active_count(), 0);

        let senders = reg.get_senders_for_source(0);
        assert!(senders.is_empty());

        let by_name = reg.get_subscriptions_by_name("mv_orders");
        assert!(by_name.is_empty());
    }

    #[test]
    fn test_registry_cancel_nonexistent() {
        let reg = SubscriptionRegistry::new();
        assert!(!reg.cancel(SubscriptionId(999)));
    }

    #[test]
    fn test_registry_get_senders() {
        let reg = SubscriptionRegistry::new();
        let (_, _rx1) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());
        let (_, _rx2) = reg.create("mv_b".into(), 0, SubscriptionConfig::default());

        let senders = reg.get_senders_for_source(0);
        assert_eq!(senders.len(), 2);
    }

    #[test]
    fn test_registry_get_senders_paused_excluded() {
        let reg = SubscriptionRegistry::new();
        let (id1, _rx1) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());
        let (_, _rx2) = reg.create("mv_b".into(), 0, SubscriptionConfig::default());

        reg.pause(id1);
        let senders = reg.get_senders_for_source(0);
        assert_eq!(senders.len(), 1);
    }

    #[test]
    fn test_registry_get_senders_no_source() {
        let reg = SubscriptionRegistry::new();
        let senders = reg.get_senders_for_source(42);
        assert!(senders.is_empty());
    }

    #[test]
    fn test_registry_subscription_count() {
        let reg = SubscriptionRegistry::new();
        assert_eq!(reg.subscription_count(), 0);
        assert_eq!(reg.active_count(), 0);

        let (id1, _rx1) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());
        let (_, _rx2) = reg.create("mv_b".into(), 1, SubscriptionConfig::default());
        assert_eq!(reg.subscription_count(), 2);
        assert_eq!(reg.active_count(), 2);

        reg.pause(id1);
        assert_eq!(reg.subscription_count(), 2);
        assert_eq!(reg.active_count(), 1);
    }

    #[test]
    fn test_registry_metrics() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_orders".into(), 0, SubscriptionConfig::default());

        let m = reg.metrics(id).unwrap();
        assert_eq!(m.id, id);
        assert_eq!(m.source_name, "mv_orders");
        assert_eq!(m.state, SubscriptionState::Active);
        assert_eq!(m.events_delivered, 0);
        assert_eq!(m.events_dropped, 0);
        assert_eq!(m.current_lag, 0);

        assert!(reg.metrics(SubscriptionId(999)).is_none());
    }

    #[test]
    fn test_registry_record_delivery_and_drop() {
        let reg = SubscriptionRegistry::new();
        let (id, _rx) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());

        reg.record_delivery(id, 10);
        reg.record_delivery(id, 5);
        reg.record_drop(id, 2);

        let m = reg.metrics(id).unwrap();
        assert_eq!(m.events_delivered, 15);
        assert_eq!(m.events_dropped, 2);
    }

    #[test]
    fn test_registry_thread_safety() {
        let reg = Arc::new(SubscriptionRegistry::new());
        let mut handles = Vec::new();

        // Four threads, each creating 100 subscriptions on its own source.
        for t in 0..4u32 {
            let reg = Arc::clone(&reg);
            handles.push(std::thread::spawn(move || {
                let mut ids = Vec::new();
                for i in 0..100u32 {
                    let name = format!("mv_{t}_{i}");
                    let (id, _rx) = reg.create(name, t, SubscriptionConfig::default());
                    ids.push(id);
                }
                ids
            }));
        }

        let all_ids: Vec<Vec<SubscriptionId>> =
            handles.into_iter().map(|h| h.join().unwrap()).collect();

        assert_eq!(reg.subscription_count(), 400);

        // Every allocated ID must be unique across threads.
        let mut flat: Vec<u64> = all_ids.iter().flatten().map(|id| id.0).collect();
        flat.sort_unstable();
        flat.dedup();
        assert_eq!(flat.len(), 400);

        for t in 0..4u32 {
            let senders = reg.get_senders_for_source(t);
            assert_eq!(senders.len(), 100);
        }

        for id in &all_ids[0][..50] {
            assert!(reg.cancel(*id));
        }
        assert_eq!(reg.subscription_count(), 350);
        assert_eq!(reg.get_senders_for_source(0).len(), 50);
    }

    #[test]
    fn test_registry_with_notification_hub() {
        use crate::subscription::NotificationHub;

        let mut hub = NotificationHub::new(4, 64);
        let reg = SubscriptionRegistry::new();

        let source_id = hub.register_source().unwrap();
        let (sub_id, _rx) =
            reg.create("mv_orders".into(), source_id, SubscriptionConfig::default());

        let senders = reg.get_senders_for_source(source_id);
        assert_eq!(senders.len(), 1);

        assert!(hub.notify_source(
            source_id,
            crate::subscription::EventType::Insert,
            10,
            1000,
            0,
        ));

        let mut count = 0;
        hub.drain_notifications(|_n| count += 1);
        assert_eq!(count, 1);

        reg.cancel(sub_id);
        assert!(reg.get_senders_for_source(source_id).is_empty());
    }

    #[test]
    fn test_registry_broadcast_delivery() {
        let reg = SubscriptionRegistry::new();
        let (_, mut rx1) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());
        let (_, mut rx2) = reg.create("mv_a".into(), 0, SubscriptionConfig::default());

        let senders = reg.get_senders_for_source(0);
        assert_eq!(senders.len(), 2);

        let batch = Arc::new(make_batch(5));
        let event = ChangeEvent::insert(batch, 1000, 1);

        for sender in &senders {
            sender.send(event.clone()).unwrap();
        }

        let e1 = rx1.try_recv().unwrap();
        assert_eq!(e1.timestamp(), 1000);
        assert_eq!(e1.sequence(), Some(1));
        assert_eq!(e1.row_count(), 5);

        let e2 = rx2.try_recv().unwrap();
        assert_eq!(e2.timestamp(), 1000);
        assert_eq!(e2.sequence(), Some(1));
    }

    #[test]
    fn test_subscription_id_display() {
        let id = SubscriptionId(42);
        assert_eq!(format!("{id}"), "sub-42");
    }
}