Skip to main content

laminar_db/subscription/
registry.rs

1//! Per-name replayable broadcast log for SUBSCRIBE. Send and subscribe
2//! are serialized by one mutex per stream so the (replay snapshot, live
3//! receiver) pair an `AS OF EPOCH n` consumer gets has no gap.
4
5#![allow(clippy::disallowed_types)] // cold path
6
7use std::collections::{HashMap, VecDeque};
8use std::sync::Arc;
9
10use arrow_array::RecordBatch;
11use parking_lot::{Mutex, RwLock};
12use tokio::sync::broadcast;
13
/// Slot count of each stream's live `tokio::sync::broadcast` channel.
///
/// Per tokio's broadcast semantics, a receiver that falls more than this many
/// messages behind loses the oldest ones and observes `RecvError::Lagged`.
const BROADCAST_CAPACITY: usize = 256;
15
/// One message on a stream: either a data batch or a barrier.
///
/// Barriers delimit epochs in the retained log: `SubscribeStart::AsOfEpoch(n)`
/// replays from just after the barrier whose `epoch == n`.
#[derive(Clone, Debug)]
pub(crate) enum MvUpdate {
    /// A batch of rows emitted on the stream.
    Batch(RecordBatch),
    /// Barrier tagged with an epoch and a checkpoint id; sent to every stream
    /// at once by `SubscriptionRegistry::broadcast_barrier`.
    Barrier { epoch: u64, checkpoint_id: u64 },
}
21
/// Where a new subscriber should start reading.
#[derive(Clone, Copy, Debug)]
pub enum SubscribeStart {
    /// See only what's sent after subscribing. Never fails.
    Tail,
    /// Replay everything emitted strictly after the barrier with `epoch == n`,
    /// then continue with live messages. Fails with `ReplayPruned` when that
    /// barrier is not present in the retained log.
    AsOfEpoch(u64),
}
30
/// Error for an `AsOfEpoch(n)` subscribe when no barrier with epoch `n` is in
/// the retained log: it was evicted, retention is disabled, or that barrier
/// has not been emitted yet.
#[derive(Debug)]
pub(crate) struct ReplayPruned {
    /// Earliest barrier epoch still in the log, or `0` if none retained.
    /// NOTE: a retained barrier with epoch `0` also reports `0`.
    pub(crate) earliest_retained: u64,
}
36
/// One stream's replayable log together with its live broadcast channel.
struct StreamLog {
    /// Single lock over both the retained buffer and the sender, so that
    /// "append + broadcast" and "snapshot + subscribe" never interleave.
    inner: Mutex<StreamLogInner>,
}

/// State guarded by `StreamLog::inner`.
struct StreamLogInner {
    /// Retained messages, oldest at the front (pushed at the back, evicted
    /// from the front).
    buf: VecDeque<MvUpdate>,
    /// Running sum of `approx_size` over the entries currently in `buf`.
    bytes: usize,
    /// Retention budget in approximate bytes; `0` disables retention.
    cap: usize,
    /// Live fan-out to subscribers.
    sender: broadcast::Sender<MvUpdate>,
}
47
48impl StreamLog {
49    fn new(cap: usize) -> Self {
50        let (sender, _) = broadcast::channel(BROADCAST_CAPACITY);
51        Self {
52            inner: Mutex::new(StreamLogInner {
53                buf: VecDeque::new(),
54                bytes: 0,
55                cap,
56                sender,
57            }),
58        }
59    }
60
61    fn send(&self, msg: MvUpdate) {
62        let mut g = self.inner.lock();
63        if g.cap > 0 {
64            g.bytes += approx_size(&msg);
65            g.buf.push_back(msg.clone());
66            evict_to_cap(&mut g);
67        }
68        let _ = g.sender.send(msg);
69    }
70
71    fn subscribe(
72        &self,
73        start: SubscribeStart,
74    ) -> Result<(Vec<MvUpdate>, broadcast::Receiver<MvUpdate>), ReplayPruned> {
75        let g = self.inner.lock();
76        let replay = match start {
77            SubscribeStart::Tail => Vec::new(),
78            SubscribeStart::AsOfEpoch(n) => slice_after_epoch(&g.buf, n)?,
79        };
80        let rx = g.sender.subscribe();
81        Ok((replay, rx))
82    }
83
84    fn set_cap(&self, cap: usize) {
85        let mut g = self.inner.lock();
86        g.cap = cap;
87        evict_to_cap(&mut g);
88    }
89
90    fn subscriber_count(&self) -> usize {
91        self.inner.lock().sender.receiver_count()
92    }
93}
94
95fn evict_to_cap(g: &mut StreamLogInner) {
96    while g.bytes > g.cap && !g.buf.is_empty() {
97        if let Some(evicted) = g.buf.pop_front() {
98            g.bytes = g.bytes.saturating_sub(approx_size(&evicted));
99        }
100    }
101}
102
103fn slice_after_epoch(buf: &VecDeque<MvUpdate>, n: u64) -> Result<Vec<MvUpdate>, ReplayPruned> {
104    let mut found_at = None;
105    let mut earliest = u64::MAX;
106    for (i, msg) in buf.iter().enumerate() {
107        if let MvUpdate::Barrier { epoch, .. } = msg {
108            earliest = earliest.min(*epoch);
109            if *epoch == n {
110                found_at = Some(i);
111            }
112        }
113    }
114    match found_at {
115        Some(i) => Ok(buf.iter().skip(i + 1).cloned().collect()),
116        None => Err(ReplayPruned {
117            earliest_retained: if earliest == u64::MAX { 0 } else { earliest },
118        }),
119    }
120}
121
122fn approx_size(msg: &MvUpdate) -> usize {
123    match msg {
124        MvUpdate::Batch(b) => b.get_array_memory_size(),
125        MvUpdate::Barrier { .. } => 16,
126    }
127}
128
/// Registry of per-name replayable broadcast logs.
///
/// The outer `RwLock` guards only the name → log map. Each log serializes its
/// own sends and subscribes with its own mutex.
pub(crate) struct SubscriptionRegistry {
    /// Stream name → shared log. Logs are `Arc`ed so a caller can keep using
    /// one after the map lock has been released.
    streams: RwLock<HashMap<String, Arc<StreamLog>>>,
}
132
133impl SubscriptionRegistry {
134    pub(crate) fn new() -> Self {
135        Self {
136            streams: RwLock::new(HashMap::new()),
137        }
138    }
139
140    /// `cap == 0` disables retention; the live broadcast still works.
141    pub(crate) fn configure(&self, name: &str, cap: usize) {
142        let log = self.get_or_create(name);
143        log.set_cap(cap);
144    }
145
146    pub(crate) fn subscribe(
147        &self,
148        name: &str,
149        start: SubscribeStart,
150    ) -> Result<(Vec<MvUpdate>, broadcast::Receiver<MvUpdate>), ReplayPruned> {
151        let log = self.get_or_create(name);
152        log.subscribe(start)
153    }
154
155    pub(crate) fn send_batch(&self, name: &str, batch: RecordBatch) {
156        if let Some(log) = self.streams.read().get(name).cloned() {
157            log.send(MvUpdate::Batch(batch));
158        }
159    }
160
161    pub(crate) fn broadcast_barrier(&self, epoch: u64, checkpoint_id: u64) {
162        let msg = MvUpdate::Barrier {
163            epoch,
164            checkpoint_id,
165        };
166        for log in self.streams.read().values() {
167            log.send(msg.clone());
168        }
169    }
170
171    pub(crate) fn drop_name(&self, name: &str) -> bool {
172        self.streams.write().remove(name).is_some()
173    }
174
175    pub(crate) fn subscriber_count(&self, name: &str) -> usize {
176        self.streams
177            .read()
178            .get(name)
179            .map_or(0, |log| log.subscriber_count())
180    }
181
182    fn get_or_create(&self, name: &str) -> Arc<StreamLog> {
183        if let Some(log) = self.streams.read().get(name) {
184            return Arc::clone(log);
185        }
186        Arc::clone(
187            self.streams
188                .write()
189                .entry(name.to_string())
190                .or_insert_with(|| Arc::new(StreamLog::new(0))),
191        )
192    }
193}
194
195impl Default for SubscriptionRegistry {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
#[cfg(test)]
mod tests {
    use std::sync::Arc as StdArc;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Builds a single-column (`id: Int64`, non-nullable) batch from `ids`.
    fn batch(ids: &[i64]) -> RecordBatch {
        let schema = StdArc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        RecordBatch::try_new(schema, vec![StdArc::new(Int64Array::from(ids.to_vec()))]).unwrap()
    }

    #[tokio::test]
    async fn round_trip() {
        let reg = SubscriptionRegistry::new();
        let (replay, mut rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        assert!(replay.is_empty());

        reg.send_batch("mv", batch(&[1, 2, 3]));
        let MvUpdate::Batch(b) = rx.recv().await.unwrap() else {
            panic!("expected batch");
        };
        assert_eq!(b.num_rows(), 3);

        reg.broadcast_barrier(7, 42);
        let MvUpdate::Barrier {
            epoch,
            checkpoint_id,
        } = rx.recv().await.unwrap()
        else {
            panic!("expected barrier");
        };
        assert_eq!((epoch, checkpoint_id), (7, 42));
    }

    #[tokio::test]
    async fn no_subscribers_is_noop() {
        let reg = SubscriptionRegistry::new();
        // Neither call creates the stream, and neither fails with no listeners.
        reg.send_batch("nobody", batch(&[1]));
        reg.broadcast_barrier(1, 1);
        assert_eq!(reg.subscriber_count("nobody"), 0);
    }

    #[tokio::test]
    async fn drop_name_closes_receivers() {
        let reg = SubscriptionRegistry::new();
        let (_, mut rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        assert!(reg.drop_name("mv"));
        assert!(matches!(
            rx.recv().await,
            Err(broadcast::error::RecvError::Closed)
        ));
    }

    #[test]
    fn subscriber_count_tracks_attach_drop() {
        let reg = SubscriptionRegistry::new();
        let (_, r1) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let (_, r2) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        assert_eq!(reg.subscriber_count("mv"), 2);
        drop(r1);
        assert_eq!(reg.subscriber_count("mv"), 1);
        drop(r2);
        assert_eq!(reg.subscriber_count("mv"), 0);
    }

    #[test]
    fn tail_subscribe_does_not_replay_history() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);
        reg.broadcast_barrier(1, 1);
        reg.send_batch("mv", batch(&[10]));
        reg.broadcast_barrier(2, 2);

        let (replay, _rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        assert!(replay.is_empty());
    }

    #[test]
    fn as_of_returns_messages_after_matching_barrier() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);
        reg.broadcast_barrier(1, 10);
        reg.send_batch("mv", batch(&[10]));
        reg.broadcast_barrier(2, 20);
        reg.send_batch("mv", batch(&[20]));
        reg.broadcast_barrier(3, 30);

        let (replay, _rx) = reg.subscribe("mv", SubscribeStart::AsOfEpoch(2)).unwrap();
        // Everything after the epoch-2 barrier: one batch, then the epoch-3 barrier.
        assert_eq!(replay.len(), 2);
        assert!(matches!(&replay[0], MvUpdate::Batch(b) if b.num_rows() == 1));
        assert!(matches!(
            &replay[1],
            MvUpdate::Barrier {
                epoch: 3,
                checkpoint_id: 30
            }
        ));
    }

    #[tokio::test]
    async fn as_of_replay_is_followed_by_live_messages() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);
        reg.broadcast_barrier(1, 10);
        reg.send_batch("mv", batch(&[1]));

        let (replay, mut rx) = reg.subscribe("mv", SubscribeStart::AsOfEpoch(1)).unwrap();
        assert_eq!(replay.len(), 1);

        // Sent after subscribing: arrives on the live receiver, not in the replay.
        reg.send_batch("mv", batch(&[2, 3]));
        let MvUpdate::Batch(b) = rx.recv().await.unwrap() else {
            panic!("expected batch");
        };
        assert_eq!(b.num_rows(), 2);
    }

    #[test]
    fn as_of_newest_barrier_replays_nothing() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);
        reg.send_batch("mv", batch(&[1]));
        reg.broadcast_barrier(4, 40);

        let (replay, _rx) = reg.subscribe("mv", SubscribeStart::AsOfEpoch(4)).unwrap();
        assert!(replay.is_empty());
    }

    #[test]
    fn as_of_below_head_is_pruned() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);
        reg.broadcast_barrier(5, 50);
        reg.send_batch("mv", batch(&[1]));
        reg.broadcast_barrier(6, 60);

        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(3))
            .unwrap_err();
        assert_eq!(err.earliest_retained, 5);
    }

    #[test]
    fn unconfigured_log_does_not_retain() {
        let reg = SubscriptionRegistry::new();
        // No configure(): cap stays at 0; live still works but nothing is replayable.
        let _ = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        reg.broadcast_barrier(1, 1);
        reg.send_batch("mv", batch(&[1]));

        // Fresh AS OF EPOCH 1 should be pruned because nothing was buffered.
        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(1))
            .unwrap_err();
        assert_eq!(err.earliest_retained, 0);
    }

    #[test]
    fn eviction_trims_to_cap() {
        let reg = SubscriptionRegistry::new();
        // The cap fits one batch, but not a batch plus a 16-byte barrier, so
        // each new entry pushes out the one before it.
        let one_batch_bytes = batch(&[0]).get_array_memory_size();
        reg.configure("mv", one_batch_bytes + 1);

        reg.broadcast_barrier(1, 1);
        for i in 0..8i64 {
            reg.send_batch("mv", batch(&[i]));
        }
        reg.broadcast_barrier(9, 9);

        // Only the epoch-9 barrier survives. The first batch evicted the
        // epoch-1 barrier, and the epoch-9 barrier evicted the last batch.
        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(1))
            .unwrap_err();
        assert_eq!(err.earliest_retained, 9);
        let (replay, _rx) = reg.subscribe("mv", SubscribeStart::AsOfEpoch(9)).unwrap();
        assert!(replay.is_empty());
    }

    #[test]
    fn single_oversize_message_is_evicted_not_retained() {
        // A batch larger than the entire cap must not stick around — better
        // to surface pruned at AS OF time than silently blow the budget.
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 8); // tiny cap, smaller than any real batch
        reg.broadcast_barrier(1, 1);
        reg.send_batch("mv", batch(&[1, 2, 3, 4, 5, 6, 7, 8]));

        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(1))
            .unwrap_err();
        assert_eq!(err.earliest_retained, 0, "buffer should be empty");
    }

    #[test]
    fn lowering_cap_trims_existing_buffer() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);
        for i in 0..4i64 {
            reg.send_batch("mv", batch(&[i]));
        }
        reg.broadcast_barrier(1, 1);

        // Drop the cap to 0; everything should be evicted.
        reg.configure("mv", 0);
        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(1))
            .unwrap_err();
        assert_eq!(err.earliest_retained, 0);
    }
}