1#![allow(clippy::disallowed_types)] use std::collections::{HashMap, VecDeque};
8use std::sync::Arc;
9
10use arrow_array::RecordBatch;
11use parking_lot::{Mutex, RwLock};
12use tokio::sync::broadcast;
13
/// Capacity passed to `broadcast::channel` when a stream's live fan-out
/// channel is created in `StreamLog::new`.
///
/// NOTE(review): per tokio's broadcast semantics a receiver that falls more
/// than this many messages behind is expected to observe a lag error rather
/// than block the sender — confirm consumers handle that.
const BROADCAST_CAPACITY: usize = 256;
15
/// A single message published on a stream, delivered both through the live
/// broadcast channel and (when retention is enabled) through replay.
#[derive(Clone, Debug)]
pub(crate) enum MvUpdate {
    /// A batch of rows published via `SubscriptionRegistry::send_batch`.
    Batch(RecordBatch),
    /// A barrier published to every stream via
    /// `SubscriptionRegistry::broadcast_barrier`. The `epoch` is what
    /// `SubscribeStart::AsOfEpoch` matches against when replaying.
    Barrier { epoch: u64, checkpoint_id: u64 },
}
21
/// Where a new subscription starts reading from.
#[derive(Clone, Copy, Debug)]
pub enum SubscribeStart {
    /// Live messages only: the replay vector returned by `subscribe` is empty.
    Tail,
    /// Replay every retained message that follows the barrier whose epoch
    /// equals the given value, then continue with live messages. Subscribing
    /// fails with `ReplayPruned` when no such barrier is retained.
    AsOfEpoch(u64),
}
30
/// Error returned when an `AsOfEpoch` subscription asks for a barrier epoch
/// that is not present in the retained buffer.
#[derive(Debug)]
pub(crate) struct ReplayPruned {
    /// Smallest barrier epoch still present in the buffer, or `0` when the
    /// buffer holds no barrier at all.
    pub(crate) earliest_retained: u64,
}
36
/// Per-stream state: a bounded replay buffer plus the broadcast sender used
/// for live delivery, both guarded by one mutex.
struct StreamLog {
    /// All mutable state lives behind this single lock so that `send` and
    /// `subscribe` observe the buffer and the channel together.
    inner: Mutex<StreamLogInner>,
}
40
/// The lock-protected contents of a `StreamLog`.
struct StreamLogInner {
    /// Retained messages, oldest at the front; eviction pops from the front.
    buf: VecDeque<MvUpdate>,
    /// Running sum of `approx_size` over everything currently in `buf`.
    bytes: usize,
    /// Retention budget in (approximate) bytes. `0` disables retention:
    /// `send` does not buffer anything while the cap is zero.
    cap: usize,
    /// Broadcast sender feeding every live receiver of this stream.
    sender: broadcast::Sender<MvUpdate>,
}
47
48impl StreamLog {
49 fn new(cap: usize) -> Self {
50 let (sender, _) = broadcast::channel(BROADCAST_CAPACITY);
51 Self {
52 inner: Mutex::new(StreamLogInner {
53 buf: VecDeque::new(),
54 bytes: 0,
55 cap,
56 sender,
57 }),
58 }
59 }
60
61 fn send(&self, msg: MvUpdate) {
62 let mut g = self.inner.lock();
63 if g.cap > 0 {
64 g.bytes += approx_size(&msg);
65 g.buf.push_back(msg.clone());
66 evict_to_cap(&mut g);
67 }
68 let _ = g.sender.send(msg);
69 }
70
71 fn subscribe(
72 &self,
73 start: SubscribeStart,
74 ) -> Result<(Vec<MvUpdate>, broadcast::Receiver<MvUpdate>), ReplayPruned> {
75 let g = self.inner.lock();
76 let replay = match start {
77 SubscribeStart::Tail => Vec::new(),
78 SubscribeStart::AsOfEpoch(n) => slice_after_epoch(&g.buf, n)?,
79 };
80 let rx = g.sender.subscribe();
81 Ok((replay, rx))
82 }
83
84 fn set_cap(&self, cap: usize) {
85 let mut g = self.inner.lock();
86 g.cap = cap;
87 evict_to_cap(&mut g);
88 }
89
90 fn subscriber_count(&self) -> usize {
91 self.inner.lock().sender.receiver_count()
92 }
93}
94
95fn evict_to_cap(g: &mut StreamLogInner) {
96 while g.bytes > g.cap && !g.buf.is_empty() {
97 if let Some(evicted) = g.buf.pop_front() {
98 g.bytes = g.bytes.saturating_sub(approx_size(&evicted));
99 }
100 }
101}
102
103fn slice_after_epoch(buf: &VecDeque<MvUpdate>, n: u64) -> Result<Vec<MvUpdate>, ReplayPruned> {
104 let mut found_at = None;
105 let mut earliest = u64::MAX;
106 for (i, msg) in buf.iter().enumerate() {
107 if let MvUpdate::Barrier { epoch, .. } = msg {
108 earliest = earliest.min(*epoch);
109 if *epoch == n {
110 found_at = Some(i);
111 }
112 }
113 }
114 match found_at {
115 Some(i) => Ok(buf.iter().skip(i + 1).cloned().collect()),
116 None => Err(ReplayPruned {
117 earliest_retained: if earliest == u64::MAX { 0 } else { earliest },
118 }),
119 }
120}
121
122fn approx_size(msg: &MvUpdate) -> usize {
123 match msg {
124 MvUpdate::Batch(b) => b.get_array_memory_size(),
125 MvUpdate::Barrier { .. } => 16,
126 }
127}
128
/// Registry mapping a stream name to its `StreamLog`.
pub(crate) struct SubscriptionRegistry {
    /// Name → log. The logs are reference-counted so callers can clone the
    /// `Arc` out of the map.
    streams: RwLock<HashMap<String, Arc<StreamLog>>>,
}
132
133impl SubscriptionRegistry {
134 pub(crate) fn new() -> Self {
135 Self {
136 streams: RwLock::new(HashMap::new()),
137 }
138 }
139
140 pub(crate) fn configure(&self, name: &str, cap: usize) {
142 let log = self.get_or_create(name);
143 log.set_cap(cap);
144 }
145
146 pub(crate) fn subscribe(
147 &self,
148 name: &str,
149 start: SubscribeStart,
150 ) -> Result<(Vec<MvUpdate>, broadcast::Receiver<MvUpdate>), ReplayPruned> {
151 let log = self.get_or_create(name);
152 log.subscribe(start)
153 }
154
155 pub(crate) fn send_batch(&self, name: &str, batch: RecordBatch) {
156 if let Some(log) = self.streams.read().get(name).cloned() {
157 log.send(MvUpdate::Batch(batch));
158 }
159 }
160
161 pub(crate) fn broadcast_barrier(&self, epoch: u64, checkpoint_id: u64) {
162 let msg = MvUpdate::Barrier {
163 epoch,
164 checkpoint_id,
165 };
166 for log in self.streams.read().values() {
167 log.send(msg.clone());
168 }
169 }
170
171 pub(crate) fn drop_name(&self, name: &str) -> bool {
172 self.streams.write().remove(name).is_some()
173 }
174
175 pub(crate) fn subscriber_count(&self, name: &str) -> usize {
176 self.streams
177 .read()
178 .get(name)
179 .map_or(0, |log| log.subscriber_count())
180 }
181
182 fn get_or_create(&self, name: &str) -> Arc<StreamLog> {
183 if let Some(log) = self.streams.read().get(name) {
184 return Arc::clone(log);
185 }
186 Arc::clone(
187 self.streams
188 .write()
189 .entry(name.to_string())
190 .or_insert_with(|| Arc::new(StreamLog::new(0))),
191 )
192 }
193}
194
195impl Default for SubscriptionRegistry {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
// Unit tests for the registry: live delivery, subscriber accounting, replay
// from a barrier epoch, and byte-cap eviction.
#[cfg(test)]
mod tests {
    use std::sync::Arc as StdArc;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Builds a single-column (`id: Int64`) batch with one row per element of
    /// `ids`.
    fn batch(ids: &[i64]) -> RecordBatch {
        let schema = StdArc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        RecordBatch::try_new(schema, vec![StdArc::new(Int64Array::from(ids.to_vec()))]).unwrap()
    }

    /// A tail subscriber receives batches and barriers in publish order.
    #[tokio::test]
    async fn round_trip() {
        let reg = SubscriptionRegistry::new();
        let (replay, mut rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        assert!(replay.is_empty());

        reg.send_batch("mv", batch(&[1, 2, 3]));
        let MvUpdate::Batch(b) = rx.recv().await.unwrap() else {
            panic!("expected batch");
        };
        assert_eq!(b.num_rows(), 3);

        reg.broadcast_barrier(7, 42);
        let MvUpdate::Barrier {
            epoch,
            checkpoint_id,
        } = rx.recv().await.unwrap()
        else {
            panic!("expected barrier");
        };
        assert_eq!((epoch, checkpoint_id), (7, 42));
    }

    /// Publishing to an unknown stream, or a barrier with no streams
    /// registered, neither panics nor creates a stream.
    #[tokio::test]
    async fn no_subscribers_is_noop() {
        let reg = SubscriptionRegistry::new();
        reg.send_batch("nobody", batch(&[1]));
        reg.broadcast_barrier(1, 1);
        assert_eq!(reg.subscriber_count("nobody"), 0);
    }

    /// Dropping a stream drops its sender, so attached receivers observe
    /// `Closed`.
    #[tokio::test]
    async fn drop_name_closes_receivers() {
        let reg = SubscriptionRegistry::new();
        let (_, mut rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        assert!(reg.drop_name("mv"));
        assert!(matches!(
            rx.recv().await,
            Err(broadcast::error::RecvError::Closed)
        ));
    }

    /// The subscriber count follows receivers being created and dropped.
    #[test]
    fn subscriber_count_tracks_attach_drop() {
        let reg = SubscriptionRegistry::new();
        let (_, r1) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let (_, r2) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        assert_eq!(reg.subscriber_count("mv"), 2);
        drop(r1);
        assert_eq!(reg.subscriber_count("mv"), 1);
        drop(r2);
        assert_eq!(reg.subscriber_count("mv"), 0);
    }

    /// `Tail` never replays, even when history is retained.
    #[test]
    fn tail_subscribe_does_not_replay_history() {
        let reg = SubscriptionRegistry::new();
        // Generous cap so everything published below is retained.
        reg.configure("mv", 1 << 20);
        reg.broadcast_barrier(1, 1);
        reg.send_batch("mv", batch(&[10]));
        reg.broadcast_barrier(2, 2);

        let (replay, _rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        assert!(replay.is_empty());
    }

    /// `AsOfEpoch(n)` replays exactly what follows barrier `n`, excluding the
    /// barrier itself.
    #[test]
    fn as_of_returns_messages_after_matching_barrier() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);
        reg.broadcast_barrier(1, 10);
        reg.send_batch("mv", batch(&[10]));
        reg.broadcast_barrier(2, 20);
        reg.send_batch("mv", batch(&[20]));
        reg.broadcast_barrier(3, 30);

        let (replay, _rx) = reg.subscribe("mv", SubscribeStart::AsOfEpoch(2)).unwrap();
        // Expected replay: the batch after barrier 2, then barrier 3.
        assert_eq!(replay.len(), 2);
        assert!(matches!(&replay[0], MvUpdate::Batch(b) if b.num_rows() == 1));
        assert!(matches!(
            &replay[1],
            MvUpdate::Barrier {
                epoch: 3,
                checkpoint_id: 30
            }
        ));
    }

    /// Asking for an epoch older than anything retained reports the earliest
    /// retained barrier epoch.
    #[test]
    fn as_of_below_head_is_pruned() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);
        reg.broadcast_barrier(5, 50);
        reg.send_batch("mv", batch(&[1]));
        reg.broadcast_barrier(6, 60);

        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(3))
            .unwrap_err();
        assert_eq!(err.earliest_retained, 5);
    }

    /// A stream created by `subscribe` alone has cap 0 and retains nothing.
    #[test]
    fn unconfigured_log_does_not_retain() {
        let reg = SubscriptionRegistry::new();
        // Creates the stream without configuring a retention cap.
        let _ = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        reg.broadcast_barrier(1, 1);
        reg.send_batch("mv", batch(&[1]));

        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(1))
            .unwrap_err();
        // Empty buffer: no barrier retained, so the reported epoch is 0.
        assert_eq!(err.earliest_retained, 0);
    }

    /// Once the byte cap is exceeded the oldest messages, including the
    /// first barrier, are evicted.
    #[test]
    fn eviction_trims_to_cap() {
        let reg = SubscriptionRegistry::new();
        let one_batch_bytes = batch(&[0]).get_array_memory_size();
        // Budget for one single-row batch plus one byte.
        reg.configure("mv", one_batch_bytes + 1);

        reg.broadcast_barrier(1, 1);
        for i in 0..8i64 {
            reg.send_batch("mv", batch(&[i]));
        }
        reg.broadcast_barrier(9, 9);

        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(1))
            .unwrap_err();
        // Barrier 1 must be gone; either only barrier 9 survives or no
        // barrier is retained at all.
        assert!(err.earliest_retained >= 9 || err.earliest_retained == 0);
    }

    /// A message larger than the whole cap is not kept, and neither is
    /// anything older than it.
    #[test]
    fn single_oversize_message_is_evicted_not_retained() {
        let reg = SubscriptionRegistry::new();
        // A cap of 8 bytes is below even the 16 bytes charged for a barrier,
        // so nothing published here can stay in the buffer.
        reg.configure("mv", 8);
        reg.broadcast_barrier(1, 1);
        reg.send_batch("mv", batch(&[1, 2, 3, 4, 5, 6, 7, 8]));

        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(1))
            .unwrap_err();
        assert_eq!(err.earliest_retained, 0, "buffer should be empty");
    }

    /// Reconfiguring to a smaller cap evicts already-retained history.
    #[test]
    fn lowering_cap_trims_existing_buffer() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);
        for i in 0..4i64 {
            reg.send_batch("mv", batch(&[i]));
        }
        reg.broadcast_barrier(1, 1);

        // Cap 0 forces everything out of the buffer.
        reg.configure("mv", 0);
        let err = reg
            .subscribe("mv", SubscribeStart::AsOfEpoch(1))
            .unwrap_err();
        assert_eq!(err.earliest_retained, 0);
    }
}