Skip to main content

laminar_db/subscription/
portal.rs

1//! Per-subscriber portal: pump task forwards broadcast updates to the wire.
2
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5
6use arrow_array::RecordBatch;
7use arrow_schema::SchemaRef;
8use crossfire::{mpsc, AsyncRx, MAsyncTx};
9use datafusion::physical_expr::PhysicalExpr;
10use tokio::sync::{broadcast, Notify};
11
12use super::registry::MvUpdate;
13
14/// One frame emitted toward the wire.
15#[derive(Debug, Clone)]
16pub enum PortalFrame {
17    /// Rows produced in a cycle.
18    Batch(RecordBatch),
19    /// Rows preceding this marker are durable as of `epoch`.
20    Barrier {
21        /// Engine checkpoint epoch.
22        epoch: u64,
23        /// Engine checkpoint id.
24        checkpoint_id: u64,
25    },
26    /// Consumer fell behind by `skipped` messages and the broadcast dropped
27    /// them. The portal closes immediately after this frame; the wire layer
28    /// translates it into a client-visible error so the disconnect isn't
29    /// silent.
30    Lagged(u64),
31}
32
/// Frames buffered in each portal's outbound channel; once it is full the
/// pump's `send` waits for the consumer to catch up.
const OUTBOUND_CAPACITY: usize = 256;
/// Upper bound on subscribers per MV. Not referenced in this file —
/// presumably enforced by the registry (verify at the use site).
pub(crate) const MAX_SUBSCRIBERS_PER_MV: usize = 64;
35
36/// One SUBSCRIBE consumer.
37#[derive(Debug)]
38pub struct SubscriptionPortal {
39    schema: SchemaRef,
40    outbound: AsyncRx<mpsc::Array<PortalFrame>>,
41    closed: Arc<AtomicBool>,
42    wake: Arc<Notify>,
43}
44
45impl SubscriptionPortal {
46    pub(crate) fn open(
47        name: impl Into<String>,
48        schema: SchemaRef,
49        replay: Vec<MvUpdate>,
50        rx: broadcast::Receiver<MvUpdate>,
51    ) -> Self {
52        Self::spawn(name, schema, replay, rx, None)
53    }
54
55    pub(crate) fn open_with_filter(
56        name: impl Into<String>,
57        schema: SchemaRef,
58        replay: Vec<MvUpdate>,
59        rx: broadcast::Receiver<MvUpdate>,
60        filter: Arc<dyn PhysicalExpr>,
61    ) -> Self {
62        Self::spawn(name, schema, replay, rx, Some(filter))
63    }
64
65    fn spawn(
66        name: impl Into<String>,
67        schema: SchemaRef,
68        replay: Vec<MvUpdate>,
69        rx: broadcast::Receiver<MvUpdate>,
70        filter: Option<Arc<dyn PhysicalExpr>>,
71    ) -> Self {
72        let (tx, outbound) = mpsc::bounded_async::<PortalFrame>(OUTBOUND_CAPACITY);
73        let closed = Arc::new(AtomicBool::new(false));
74        let wake = Arc::new(Notify::new());
75        tokio::spawn(pump_loop(
76            name.into(),
77            replay,
78            rx,
79            tx,
80            Arc::clone(&closed),
81            Arc::clone(&wake),
82            filter,
83        ));
84        Self {
85            schema,
86            outbound,
87            closed,
88            wake,
89        }
90    }
91
92    /// Schema of the subscribed object.
93    #[must_use]
94    pub fn schema(&self) -> SchemaRef {
95        Arc::clone(&self.schema)
96    }
97
98    /// Next frame, or `None` once the pump exits.
99    pub async fn next_frame(&mut self) -> Option<PortalFrame> {
100        self.outbound.recv().await.ok()
101    }
102
103    /// Signal the pump to stop. Idempotent. Wakes the pump if it's parked
104    /// on `broadcast_rx.recv()` so it can re-check the flag and exit.
105    pub fn close(&self) {
106        self.closed.store(true, Ordering::Release);
107        self.wake.notify_waiters();
108    }
109
110    /// True after `close()` has been called.
111    #[must_use]
112    pub fn is_closed(&self) -> bool {
113        self.closed.load(Ordering::Acquire)
114    }
115}
116
117impl Drop for SubscriptionPortal {
118    fn drop(&mut self) {
119        self.close();
120    }
121}
122
123/// `Ok(None)` when the filter excluded the batch; `Err(())` when the
124/// filter itself failed (caller should close the pump).
125fn translate(
126    msg: MvUpdate,
127    filter: Option<&Arc<dyn PhysicalExpr>>,
128    name: &str,
129) -> Result<Option<PortalFrame>, ()> {
130    match msg {
131        MvUpdate::Batch(batch) => match filter {
132            Some(f) => match crate::filter_compile::apply(&batch, f.as_ref()) {
133                Ok(Some(b)) => Ok(Some(PortalFrame::Batch(b))),
134                Ok(None) => Ok(None),
135                Err(e) => {
136                    tracing::warn!(subscription = %name, error = %e, "filter failed; closing");
137                    Err(())
138                }
139            },
140            None => Ok(Some(PortalFrame::Batch(batch))),
141        },
142        MvUpdate::Barrier {
143            epoch,
144            checkpoint_id,
145        } => Ok(Some(PortalFrame::Barrier {
146            epoch,
147            checkpoint_id,
148        })),
149    }
150}
151
152async fn pump_loop(
153    name: String,
154    replay: Vec<MvUpdate>,
155    mut broadcast_rx: broadcast::Receiver<MvUpdate>,
156    tx: MAsyncTx<mpsc::Array<PortalFrame>>,
157    closed: Arc<AtomicBool>,
158    wake: Arc<Notify>,
159    filter: Option<Arc<dyn PhysicalExpr>>,
160) {
161    for msg in replay {
162        if closed.load(Ordering::Acquire) {
163            return;
164        }
165        match translate(msg, filter.as_ref(), &name) {
166            Ok(Some(frame)) => {
167                if tx.send(frame).await.is_err() {
168                    return;
169                }
170            }
171            Ok(None) => {}
172            Err(()) => return,
173        }
174    }
175    while !closed.load(Ordering::Acquire) {
176        let recv = tokio::select! {
177            biased;
178            () = wake.notified() => continue,
179            r = broadcast_rx.recv() => r,
180        };
181        let msg = match recv {
182            Ok(m) => m,
183            Err(broadcast::error::RecvError::Closed) => return,
184            Err(broadcast::error::RecvError::Lagged(n)) => {
185                tracing::warn!(subscription = %name, skipped = n, "lagged; closing");
186                let _ = tx.send(PortalFrame::Lagged(n)).await;
187                return;
188            }
189        };
190        match translate(msg, filter.as_ref(), &name) {
191            Ok(Some(frame)) => {
192                if tx.send(frame).await.is_err() {
193                    return;
194                }
195            }
196            Ok(None) => {}
197            Err(()) => return,
198        }
199    }
200}
201
#[cfg(test)]
mod tests {
    use std::sync::Arc as StdArc;
    use std::time::Duration;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::super::registry::{SubscribeStart, SubscriptionRegistry};
    use super::*;

    fn schema() -> SchemaRef {
        StdArc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]))
    }

    fn batch(ids: &[i64]) -> RecordBatch {
        RecordBatch::try_new(schema(), vec![StdArc::new(Int64Array::from(ids.to_vec()))]).unwrap()
    }

    #[tokio::test]
    async fn portal_forwards_batch_and_barrier() {
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        reg.send_batch("mv", batch(&[1, 2]));
        reg.broadcast_barrier(7, 99);

        let f1 = portal.next_frame().await.expect("frame 1");
        let PortalFrame::Batch(b) = f1 else {
            panic!("expected batch, got {f1:?}");
        };
        assert_eq!(b.num_rows(), 2);

        let f2 = portal.next_frame().await.expect("frame 2");
        let PortalFrame::Barrier {
            epoch,
            checkpoint_id,
        } = f2
        else {
            panic!("expected barrier, got {f2:?}");
        };
        assert_eq!(epoch, 7);
        assert_eq!(checkpoint_id, 99);
    }

    #[tokio::test]
    async fn portal_closes_on_drop_name() {
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        reg.send_batch("mv", batch(&[1]));
        let _ = portal.next_frame().await;

        reg.drop_name("mv");

        let frame = tokio::time::timeout(Duration::from_millis(500), portal.next_frame())
            .await
            .unwrap();
        assert!(frame.is_none());
    }

    #[tokio::test]
    async fn portal_emits_lagged_as_final_frame() {
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        for i in 0..1024 {
            reg.send_batch("mv", batch(&[i]));
        }

        let frames = tokio::time::timeout(Duration::from_secs(1), async {
            let mut frames = Vec::new();
            while let Some(frame) = portal.next_frame().await {
                frames.push(frame);
            }
            frames
        })
        .await
        .expect("portal must close after lag");

        assert!(
            matches!(frames.last(), Some(PortalFrame::Lagged(_))),
            "last frame must be Lagged, got: {:?}",
            frames.last()
        );
    }

    #[tokio::test]
    async fn portal_drains_replay_before_live() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);

        // Two epochs of pre-existing history.
        reg.broadcast_barrier(1, 1);
        reg.send_batch("mv", batch(&[10]));
        reg.broadcast_barrier(2, 2);
        reg.send_batch("mv", batch(&[20]));

        // Subscribe AS OF EPOCH 1: the client has everything up to and
        // including the epoch-1 barrier; replay must cover everything after
        // it. That's batch[10], then barrier(2), then batch[20].
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::AsOfEpoch(1)).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        reg.send_batch("mv", batch(&[30]));

        let mut row_seq = Vec::new();
        let mut barriers = Vec::new();
        for _ in 0..4 {
            match portal.next_frame().await.unwrap() {
                PortalFrame::Batch(b) => {
                    let v = b
                        .column(0)
                        .as_any()
                        .downcast_ref::<Int64Array>()
                        .unwrap()
                        .value(0);
                    row_seq.push(v);
                }
                PortalFrame::Barrier { epoch, .. } => barriers.push(epoch),
                PortalFrame::Lagged(n) => panic!("unexpected lag: {n}"),
            }
        }
        assert_eq!(row_seq, vec![10, 20, 30]);
        assert_eq!(barriers, vec![2]);
    }

    #[tokio::test]
    async fn close_is_idempotent() {
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        portal.close();
        portal.close();
        assert!(portal.is_closed());
    }

    #[tokio::test]
    async fn close_terminates_parked_pump() {
        // `reg` stays alive for the whole test, so the stream can only end
        // because of `close()`, not because the broadcast channel closed.
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        // Let the pump drain its (empty) replay and park on the broadcast.
        tokio::task::yield_now().await;

        portal.close();

        let frame = tokio::time::timeout(Duration::from_millis(500), portal.next_frame())
            .await
            .expect("close() must stop a parked pump without further broadcasts");
        assert!(frame.is_none());
        drop(reg);
    }
}