1use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5
6use arrow_array::RecordBatch;
7use arrow_schema::SchemaRef;
8use crossfire::{mpsc, AsyncRx, MAsyncTx};
9use datafusion::physical_expr::PhysicalExpr;
10use tokio::sync::{broadcast, Notify};
11
12use super::registry::MvUpdate;
13
14#[derive(Debug, Clone)]
16pub enum PortalFrame {
17 Batch(RecordBatch),
19 Barrier {
21 epoch: u64,
23 checkpoint_id: u64,
25 },
26 Lagged(u64),
31}
32
/// Capacity of the bounded channel that carries frames from the pump task to
/// the portal's consumer.
const OUTBOUND_CAPACITY: usize = 256;

/// Upper bound on concurrent subscribers for one materialized view.
///
/// NOTE(review): not referenced in this file; presumably enforced by the
/// subscription registry — confirm against the caller.
pub(crate) const MAX_SUBSCRIBERS_PER_MV: usize = 64;
35
36#[derive(Debug)]
38pub struct SubscriptionPortal {
39 schema: SchemaRef,
40 outbound: AsyncRx<mpsc::Array<PortalFrame>>,
41 closed: Arc<AtomicBool>,
42 wake: Arc<Notify>,
43}
44
45impl SubscriptionPortal {
46 pub(crate) fn open(
47 name: impl Into<String>,
48 schema: SchemaRef,
49 replay: Vec<MvUpdate>,
50 rx: broadcast::Receiver<MvUpdate>,
51 ) -> Self {
52 Self::spawn(name, schema, replay, rx, None)
53 }
54
55 pub(crate) fn open_with_filter(
56 name: impl Into<String>,
57 schema: SchemaRef,
58 replay: Vec<MvUpdate>,
59 rx: broadcast::Receiver<MvUpdate>,
60 filter: Arc<dyn PhysicalExpr>,
61 ) -> Self {
62 Self::spawn(name, schema, replay, rx, Some(filter))
63 }
64
65 fn spawn(
66 name: impl Into<String>,
67 schema: SchemaRef,
68 replay: Vec<MvUpdate>,
69 rx: broadcast::Receiver<MvUpdate>,
70 filter: Option<Arc<dyn PhysicalExpr>>,
71 ) -> Self {
72 let (tx, outbound) = mpsc::bounded_async::<PortalFrame>(OUTBOUND_CAPACITY);
73 let closed = Arc::new(AtomicBool::new(false));
74 let wake = Arc::new(Notify::new());
75 tokio::spawn(pump_loop(
76 name.into(),
77 replay,
78 rx,
79 tx,
80 Arc::clone(&closed),
81 Arc::clone(&wake),
82 filter,
83 ));
84 Self {
85 schema,
86 outbound,
87 closed,
88 wake,
89 }
90 }
91
92 #[must_use]
94 pub fn schema(&self) -> SchemaRef {
95 Arc::clone(&self.schema)
96 }
97
98 pub async fn next_frame(&mut self) -> Option<PortalFrame> {
100 self.outbound.recv().await.ok()
101 }
102
103 pub fn close(&self) {
106 self.closed.store(true, Ordering::Release);
107 self.wake.notify_waiters();
108 }
109
110 #[must_use]
112 pub fn is_closed(&self) -> bool {
113 self.closed.load(Ordering::Acquire)
114 }
115}
116
117impl Drop for SubscriptionPortal {
118 fn drop(&mut self) {
119 self.close();
120 }
121}
122
123fn translate(
126 msg: MvUpdate,
127 filter: Option<&Arc<dyn PhysicalExpr>>,
128 name: &str,
129) -> Result<Option<PortalFrame>, ()> {
130 match msg {
131 MvUpdate::Batch(batch) => match filter {
132 Some(f) => match crate::filter_compile::apply(&batch, f.as_ref()) {
133 Ok(Some(b)) => Ok(Some(PortalFrame::Batch(b))),
134 Ok(None) => Ok(None),
135 Err(e) => {
136 tracing::warn!(subscription = %name, error = %e, "filter failed; closing");
137 Err(())
138 }
139 },
140 None => Ok(Some(PortalFrame::Batch(batch))),
141 },
142 MvUpdate::Barrier {
143 epoch,
144 checkpoint_id,
145 } => Ok(Some(PortalFrame::Barrier {
146 epoch,
147 checkpoint_id,
148 })),
149 }
150}
151
152async fn pump_loop(
153 name: String,
154 replay: Vec<MvUpdate>,
155 mut broadcast_rx: broadcast::Receiver<MvUpdate>,
156 tx: MAsyncTx<mpsc::Array<PortalFrame>>,
157 closed: Arc<AtomicBool>,
158 wake: Arc<Notify>,
159 filter: Option<Arc<dyn PhysicalExpr>>,
160) {
161 for msg in replay {
162 if closed.load(Ordering::Acquire) {
163 return;
164 }
165 match translate(msg, filter.as_ref(), &name) {
166 Ok(Some(frame)) => {
167 if tx.send(frame).await.is_err() {
168 return;
169 }
170 }
171 Ok(None) => {}
172 Err(()) => return,
173 }
174 }
175 while !closed.load(Ordering::Acquire) {
176 let recv = tokio::select! {
177 biased;
178 () = wake.notified() => continue,
179 r = broadcast_rx.recv() => r,
180 };
181 let msg = match recv {
182 Ok(m) => m,
183 Err(broadcast::error::RecvError::Closed) => return,
184 Err(broadcast::error::RecvError::Lagged(n)) => {
185 tracing::warn!(subscription = %name, skipped = n, "lagged; closing");
186 let _ = tx.send(PortalFrame::Lagged(n)).await;
187 return;
188 }
189 };
190 match translate(msg, filter.as_ref(), &name) {
191 Ok(Some(frame)) => {
192 if tx.send(frame).await.is_err() {
193 return;
194 }
195 }
196 Ok(None) => {}
197 Err(()) => return,
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use std::sync::Arc as StdArc;
    use std::time::Duration;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::super::registry::{SubscribeStart, SubscriptionRegistry};
    use super::*;

    /// Single non-nullable `id: Int64` column used by every test.
    fn schema() -> SchemaRef {
        StdArc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]))
    }

    /// Builds a one-column batch holding `ids`.
    fn batch(ids: &[i64]) -> RecordBatch {
        RecordBatch::try_new(schema(), vec![StdArc::new(Int64Array::from(ids.to_vec()))]).unwrap()
    }

    /// A live batch and a live barrier arrive in the order they were sent.
    #[tokio::test]
    async fn portal_forwards_batch_and_barrier() {
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        reg.send_batch("mv", batch(&[1, 2]));
        reg.broadcast_barrier(7, 99);

        let f1 = portal.next_frame().await.expect("frame 1");
        let PortalFrame::Batch(b) = f1 else {
            panic!("expected batch, got {f1:?}");
        };
        assert_eq!(b.num_rows(), 2);

        let f2 = portal.next_frame().await.expect("frame 2");
        let PortalFrame::Barrier {
            epoch,
            checkpoint_id,
        } = f2
        else {
            panic!("expected barrier, got {f2:?}");
        };
        assert_eq!(epoch, 7);
        assert_eq!(checkpoint_id, 99);
    }

    /// After the name is dropped from the registry the stream ends.
    #[tokio::test]
    async fn portal_closes_on_drop_name() {
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        reg.send_batch("mv", batch(&[1]));
        let first = portal.next_frame().await;
        assert!(
            matches!(first, Some(PortalFrame::Batch(_))),
            "expected the batch sent before drop_name, got: {first:?}"
        );

        reg.drop_name("mv");

        let frame = tokio::time::timeout(Duration::from_millis(500), portal.next_frame())
            .await
            .unwrap();
        assert!(frame.is_none());
    }

    /// Overrunning the broadcast buffer ends the stream with a `Lagged` frame.
    #[tokio::test]
    async fn portal_emits_lagged_as_final_frame() {
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        for i in 0..1024 {
            reg.send_batch("mv", batch(&[i]));
        }

        let frames = tokio::time::timeout(Duration::from_secs(1), async {
            let mut frames = Vec::new();
            while let Some(frame) = portal.next_frame().await {
                frames.push(frame);
            }
            frames
        })
        .await
        .expect("portal must close after lag");

        assert!(
            matches!(frames.last(), Some(PortalFrame::Lagged(_))),
            "last frame must be Lagged, got: {:?}",
            frames.last()
        );
    }

    /// Replayed updates are delivered before anything sent after subscribing.
    #[tokio::test]
    async fn portal_drains_replay_before_live() {
        let reg = SubscriptionRegistry::new();
        reg.configure("mv", 1 << 20);

        reg.broadcast_barrier(1, 1);
        reg.send_batch("mv", batch(&[10]));
        reg.broadcast_barrier(2, 2);
        reg.send_batch("mv", batch(&[20]));

        let (replay, rx) = reg.subscribe("mv", SubscribeStart::AsOfEpoch(1)).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        reg.send_batch("mv", batch(&[30]));

        let mut row_seq = Vec::new();
        let mut barriers = Vec::new();
        for _ in 0..4 {
            match portal.next_frame().await.unwrap() {
                PortalFrame::Batch(b) => {
                    let v = b
                        .column(0)
                        .as_any()
                        .downcast_ref::<Int64Array>()
                        .unwrap()
                        .value(0);
                    row_seq.push(v);
                }
                PortalFrame::Barrier { epoch, .. } => barriers.push(epoch),
                PortalFrame::Lagged(n) => panic!("unexpected lag: {n}"),
            }
        }
        assert_eq!(row_seq, vec![10, 20, 30]);
        assert_eq!(barriers, vec![2]);
    }

    /// Closing twice is harmless and leaves the portal reporting closed.
    #[tokio::test]
    async fn close_is_idempotent() {
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        portal.close();
        portal.close();
        assert!(portal.is_closed());
    }

    /// `close()` must end the stream even when no further update is ever
    /// published: the pump has to be woken out of its wait on the broadcast
    /// receiver rather than waiting for the next message.
    #[tokio::test]
    async fn close_ends_stream_without_further_updates() {
        let reg = SubscriptionRegistry::new();
        let (replay, rx) = reg.subscribe("mv", SubscribeStart::Tail).unwrap();
        let mut portal = SubscriptionPortal::open("mv", schema(), replay, rx);

        // Give the pump task a chance to run and park on the receiver.
        tokio::task::yield_now().await;
        portal.close();

        let frame = tokio::time::timeout(Duration::from_millis(500), portal.next_frame())
            .await
            .expect("pump must exit after close()");
        assert!(frame.is_none());
    }
}