Skip to main content

laminar_core/subscription/
callback.rs

1//! Callback-based subscriptions — [`SubscriptionCallback`] trait and
2//! [`CallbackSubscriptionHandle`].
3//!
//! Provides a callback-based subscription API where users register either a
//! closure or a value implementing [`SubscriptionCallback`], which is invoked
//! for every change event. Each subscription's callback runs on its own tokio
//! task, which consumes the broadcast receiver obtained from the
//! [`SubscriptionRegistry`].
8//!
9//! # API Styles
10//!
11//! - **Trait-based**: Implement [`SubscriptionCallback`] for full control over
12//!   change, error, and completion events.
13//! - **Closure-based**: Use [`subscribe_fn`] for simple cases where only
14//!   `on_change` is needed.
15//!
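//! # Panic Safety
//!
//! Panics in the callback's [`on_change`](SubscriptionCallback::on_change) are
//! caught via [`std::panic::catch_unwind`] and forwarded to
//! [`on_error`](SubscriptionCallback::on_error) as
//! [`PushSubscriptionError::Internal`]; the subscription keeps running.
//! Panics in `on_error` or `on_complete` are not caught and end the callback
//! task.
//!
//! # Lifecycle
//!
//! Dropping a [`CallbackSubscriptionHandle`] automatically cancels the
//! subscription and aborts the callback task. Because the task is aborted,
//! [`on_complete`](SubscriptionCallback::on_complete) is only called when the
//! event channel closes while the task is still running (e.g. the subscription
//! is removed directly from the registry); it is not guaranteed when the
//! handle is cancelled or dropped.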
27
28use std::panic::AssertUnwindSafe;
29use std::sync::Arc;
30
31use tokio::sync::broadcast;
32
33use crate::subscription::event::ChangeEvent;
34use crate::subscription::handle::PushSubscriptionError;
35use crate::subscription::registry::{
36    SubscriptionConfig, SubscriptionId, SubscriptionMetrics, SubscriptionRegistry,
37};
38
39// ---------------------------------------------------------------------------
40// SubscriptionCallback
41// ---------------------------------------------------------------------------
42
43/// Callback trait for push-based subscriptions.
44///
45/// Implement this trait to receive change events via callback. The callback
46/// runs on a dedicated tokio task and is invoked for every event pushed by
47/// the Ring 1 dispatcher.
48///
49/// # Example
50///
51/// ```rust,ignore
52/// struct MyHandler;
53///
54/// impl SubscriptionCallback for MyHandler {
55///     fn on_change(&self, event: ChangeEvent) {
56///         match event {
57///             ChangeEvent::Insert { data, .. } => println!("{} rows", data.num_rows()),
58///             _ => {}
59///         }
60///     }
61/// }
62/// ```
63pub trait SubscriptionCallback: Send + Sync + 'static {
64    /// Called for each change event.
65    fn on_change(&self, event: ChangeEvent);
66
67    /// Called when an error occurs (e.g., lagged behind, internal error).
68    ///
69    /// Default implementation logs the error via `tracing::warn!`.
70    fn on_error(&self, error: PushSubscriptionError) {
71        tracing::warn!("subscription callback error: {}", error);
72    }
73
74    /// Called when the subscription is closed (source dropped or cancelled).
75    ///
76    /// Default implementation is a no-op.
77    fn on_complete(&self) {}
78}
79
80// ---------------------------------------------------------------------------
81// FnCallback (private adapter)
82// ---------------------------------------------------------------------------
83
84/// Adapter that wraps a closure into a [`SubscriptionCallback`].
85struct FnCallback<F>(F);
86
87impl<F: Fn(ChangeEvent) + Send + Sync + 'static> SubscriptionCallback for FnCallback<F> {
88    fn on_change(&self, event: ChangeEvent) {
89        (self.0)(event);
90    }
91}
92
93// ---------------------------------------------------------------------------
94// CallbackSubscriptionHandle
95// ---------------------------------------------------------------------------
96
97/// Handle for a callback-based subscription.
98///
99/// Provides lifecycle management (pause/resume/cancel) for the callback task.
100/// The handle and the callback task share the same `SubscriptionEntry` in
101/// the registry (via [`SubscriptionId`]), so `pause()` / `cancel()` on the
102/// handle directly affects the task's event delivery.
103///
104/// Dropping the handle automatically cancels the subscription and aborts the
105/// callback task.
106pub struct CallbackSubscriptionHandle {
107    /// Subscription ID (shared with the callback task).
108    id: SubscriptionId,
109    /// Registry reference for lifecycle management.
110    registry: Arc<SubscriptionRegistry>,
111    /// Join handle for the callback runner task.
112    task: Option<tokio::task::JoinHandle<()>>,
113    /// Whether the subscription has been explicitly cancelled.
114    cancelled: bool,
115}
116
117impl CallbackSubscriptionHandle {
118    /// Pauses the subscription.
119    ///
120    /// While paused, events are buffered or dropped per the backpressure
121    /// configuration. Returns `true` if the subscription was active and is
122    /// now paused.
123    #[must_use]
124    pub fn pause(&self) -> bool {
125        self.registry.pause(self.id)
126    }
127
128    /// Resumes a paused subscription.
129    ///
130    /// Returns `true` if the subscription was paused and is now active.
131    #[must_use]
132    pub fn resume(&self) -> bool {
133        self.registry.resume(self.id)
134    }
135
136    /// Cancels the subscription and aborts the callback task.
137    ///
138    /// The subscription is removed from the registry (dropping the broadcast
139    /// sender) and the task is aborted as a safety net.
140    pub fn cancel(&mut self) {
141        if !self.cancelled {
142            self.cancelled = true;
143            self.registry.cancel(self.id);
144            if let Some(task) = self.task.take() {
145                task.abort();
146            }
147        }
148    }
149
150    /// Returns the subscription ID.
151    #[must_use]
152    pub fn id(&self) -> SubscriptionId {
153        self.id
154    }
155
156    /// Returns subscription metrics from the registry.
157    #[must_use]
158    pub fn metrics(&self) -> Option<SubscriptionMetrics> {
159        self.registry.metrics(self.id)
160    }
161
162    /// Returns `true` if the subscription has been cancelled.
163    #[must_use]
164    pub fn is_cancelled(&self) -> bool {
165        self.cancelled
166    }
167}
168
169impl Drop for CallbackSubscriptionHandle {
170    fn drop(&mut self) {
171        if !self.cancelled {
172            self.registry.cancel(self.id);
173            if let Some(task) = self.task.take() {
174                task.abort();
175            }
176        }
177    }
178}
179
180// ---------------------------------------------------------------------------
181// Factory Functions
182// ---------------------------------------------------------------------------
183
184/// Creates a callback-based subscription.
185///
186/// Registers a subscription in the registry, then spawns a tokio task that
187/// calls `callback.on_change()` for every event. Panics in the callback are
188/// caught and forwarded to `callback.on_error()`.
189///
190/// When the broadcast sender is dropped (e.g., via cancel or registry
191/// cleanup), the task calls `callback.on_complete()` and exits.
192///
193/// # Arguments
194///
195/// * `registry` — Subscription registry for lifecycle management.
196/// * `source_name` — Name of the source MV or query.
197/// * `source_id` — Ring 0 source identifier.
198/// * `config` — Subscription configuration.
199/// * `callback` — Implementation of [`SubscriptionCallback`].
200pub fn subscribe_callback<C: SubscriptionCallback>(
201    registry: Arc<SubscriptionRegistry>,
202    source_name: String,
203    source_id: u32,
204    config: SubscriptionConfig,
205    callback: C,
206) -> CallbackSubscriptionHandle {
207    let (id, receiver) = registry.create(source_name, source_id, config);
208    let callback = Arc::new(callback);
209
210    let task = tokio::spawn(callback_runner(receiver, callback));
211
212    CallbackSubscriptionHandle {
213        id,
214        registry,
215        task: Some(task),
216        cancelled: false,
217    }
218}
219
220/// Creates a closure-based subscription (convenience wrapper).
221///
222/// Equivalent to [`subscribe_callback`] with a closure wrapped in an internal
223/// adapter that implements [`SubscriptionCallback`].
224///
225/// # Example
226///
227/// ```rust,ignore
228/// let handle = subscribe_fn(registry, "trades".into(), 0, config, |event| {
229///     println!("Got: {:?}", event.event_type());
230/// });
231/// ```
232pub fn subscribe_fn<F>(
233    registry: Arc<SubscriptionRegistry>,
234    source_name: String,
235    source_id: u32,
236    config: SubscriptionConfig,
237    f: F,
238) -> CallbackSubscriptionHandle
239where
240    F: Fn(ChangeEvent) + Send + Sync + 'static,
241{
242    subscribe_callback(registry, source_name, source_id, config, FnCallback(f))
243}
244
245// ---------------------------------------------------------------------------
246// Callback Runner (internal)
247// ---------------------------------------------------------------------------
248
249/// Internal task that receives events from the broadcast channel and calls
250/// the callback. Panics in `on_change` are caught and forwarded to `on_error`.
251async fn callback_runner<C: SubscriptionCallback>(
252    mut receiver: broadcast::Receiver<ChangeEvent>,
253    callback: Arc<C>,
254) {
255    loop {
256        match receiver.recv().await {
257            Ok(event) => {
258                let cb = Arc::clone(&callback);
259                let result = std::panic::catch_unwind(AssertUnwindSafe(|| cb.on_change(event)));
260                if let Err(panic) = result {
261                    let msg = if let Some(s) = panic.downcast_ref::<&str>() {
262                        format!("callback panicked: {s}")
263                    } else if let Some(s) = panic.downcast_ref::<String>() {
264                        format!("callback panicked: {s}")
265                    } else {
266                        "callback panicked".to_string()
267                    };
268                    callback.on_error(PushSubscriptionError::Internal(msg));
269                }
270            }
271            Err(broadcast::error::RecvError::Lagged(n)) => {
272                callback.on_error(PushSubscriptionError::Lagged(n));
273                // Continue receiving after lag
274            }
275            Err(broadcast::error::RecvError::Closed) => {
276                callback.on_complete();
277                break;
278            }
279        }
280    }
281}
282
283// ===========================================================================
284// Tests
285// ===========================================================================
286
#[cfg(test)]
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::field_reassign_with_default)]
#[allow(clippy::disallowed_types)] // test-only: Mutex for mock state
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::time::Duration;

    use arrow_array::{Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};

    use crate::subscription::registry::SubscriptionState;

    /// How long each test waits for the spawned callback task to catch up.
    const SETTLE_TIME: Duration = Duration::from_millis(50);

    /// Builds a single-column (`v: Int64`, non-nullable) batch holding `0..n`.
    fn make_batch(n: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
        let column = Int64Array::from((0..n as i64).collect::<Vec<i64>>());
        RecordBatch::try_new(schema, vec![Arc::new(column)]).unwrap()
    }

    /// Builds a one-row insert event with the given timestamp and sequence.
    fn insert_event(timestamp: i64, seq: u64) -> ChangeEvent {
        ChangeEvent::insert(Arc::new(make_batch(1)), timestamp, seq)
    }

    /// Subscribes `cb` to source `0` named "trades" on `registry`.
    fn subscribe_trades<C: SubscriptionCallback>(
        registry: &Arc<SubscriptionRegistry>,
        config: SubscriptionConfig,
        cb: C,
    ) -> CallbackSubscriptionHandle {
        subscribe_callback(Arc::clone(registry), "trades".into(), 0, config, cb)
    }

    /// Yields to the runtime long enough for the callback task to drain.
    async fn settle() {
        tokio::time::sleep(SETTLE_TIME).await;
    }

    // --- Recording callback ---

    /// Records event timestamps, error messages and completion.
    #[derive(Clone)]
    struct TestCallback {
        events: Arc<Mutex<Vec<i64>>>,
        errors: Arc<Mutex<Vec<String>>>,
        completed: Arc<Mutex<bool>>,
    }

    impl TestCallback {
        fn new() -> Self {
            Self {
                events: Arc::default(),
                errors: Arc::default(),
                completed: Arc::default(),
            }
        }
    }

    impl SubscriptionCallback for TestCallback {
        fn on_change(&self, event: ChangeEvent) {
            let ts = event.timestamp();
            self.events.lock().unwrap().push(ts);
        }

        fn on_error(&self, error: PushSubscriptionError) {
            self.errors.lock().unwrap().push(error.to_string());
        }

        fn on_complete(&self) {
            let mut done = self.completed.lock().unwrap();
            *done = true;
        }
    }

    // --- Tests ---

    #[tokio::test]
    async fn test_callback_receives_events() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let events = Arc::clone(&cb.events);
        let _handle = subscribe_trades(&registry, SubscriptionConfig::default(), cb);

        let senders = registry.get_senders_for_source(0);
        assert_eq!(senders.len(), 1);
        for i in 0..5i64 {
            senders[0].send(insert_event(i * 1000, i as u64)).unwrap();
        }
        settle().await;

        let received = events.lock().unwrap();
        assert_eq!(received.len(), 5);
        assert_eq!(*received, vec![0, 1000, 2000, 3000, 4000]);
    }

    #[tokio::test]
    async fn test_callback_on_error_lagged() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let mut cfg = SubscriptionConfig::default();
        cfg.buffer_size = 4;
        let cb = TestCallback::new();
        let (errors, events) = (Arc::clone(&cb.errors), Arc::clone(&cb.events));
        let _handle = subscribe_trades(&registry, cfg, cb);

        let senders = registry.get_senders_for_source(0);
        // Publish far more than the buffer holds so the receiver lags.
        for i in 0..20i64 {
            let _ = senders[0].send(insert_event(i * 100, i as u64));
        }
        settle().await;

        let errs = errors.lock().unwrap();
        assert!(!errs.is_empty(), "expected at least one lag error");
        assert!(errs[0].contains("lagged behind"));

        // Delivery resumes once the lag has been reported.
        let evts = events.lock().unwrap();
        assert!(!evts.is_empty(), "should receive events after lag");
    }

    #[tokio::test]
    async fn test_callback_on_complete() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let completed = Arc::clone(&cb.completed);
        let handle = subscribe_trades(&registry, SubscriptionConfig::default(), cb);

        // Cancelling on the registry side drops the sender, so the task sees
        // `Closed` and calls `on_complete` without being aborted.
        registry.cancel(handle.id());
        settle().await;

        assert!(*completed.lock().unwrap());
    }

    #[tokio::test]
    async fn test_callback_panic_caught() {
        struct PanickingCallback {
            errors: Arc<Mutex<Vec<String>>>,
        }

        impl SubscriptionCallback for PanickingCallback {
            fn on_change(&self, _event: ChangeEvent) {
                panic!("deliberate test panic");
            }

            fn on_error(&self, error: PushSubscriptionError) {
                self.errors.lock().unwrap().push(error.to_string());
            }
        }

        let registry = Arc::new(SubscriptionRegistry::new());
        let errors: Arc<Mutex<Vec<String>>> = Arc::default();
        let _handle = subscribe_trades(
            &registry,
            SubscriptionConfig::default(),
            PanickingCallback {
                errors: Arc::clone(&errors),
            },
        );

        let senders = registry.get_senders_for_source(0);
        senders[0].send(insert_event(1000, 1)).unwrap();
        settle().await;

        let errs = errors.lock().unwrap();
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("callback panicked"));
        assert!(errs[0].contains("deliberate test panic"));
    }

    #[tokio::test]
    async fn test_callback_handle_pause_resume() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let handle =
            subscribe_trades(&registry, SubscriptionConfig::default(), TestCallback::new());
        let id = handle.id();

        assert!(handle.pause());
        assert_eq!(registry.state(id), Some(SubscriptionState::Paused));
        assert!(!handle.pause(), "second pause must report no transition");

        assert!(handle.resume());
        assert_eq!(registry.state(id), Some(SubscriptionState::Active));
        assert!(!handle.resume(), "second resume must report no transition");
    }

    #[tokio::test]
    async fn test_callback_handle_cancel() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let mut handle =
            subscribe_trades(&registry, SubscriptionConfig::default(), TestCallback::new());
        assert_eq!(registry.subscription_count(), 1);
        assert!(!handle.is_cancelled());

        handle.cancel();
        assert!(handle.is_cancelled());
        assert_eq!(registry.subscription_count(), 0);

        // A second cancel is a no-op.
        handle.cancel();
        assert_eq!(registry.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_callback_handle_drop_cancels() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let handle =
            subscribe_trades(&registry, SubscriptionConfig::default(), TestCallback::new());
        assert_eq!(registry.subscription_count(), 1);

        drop(handle);
        assert_eq!(registry.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_subscribe_fn() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let events: Arc<Mutex<Vec<i64>>> = Arc::default();
        let sink = Arc::clone(&events);

        let _handle = subscribe_fn(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            move |event| sink.lock().unwrap().push(event.timestamp()),
        );

        let senders = registry.get_senders_for_source(0);
        senders[0].send(insert_event(5000, 1)).unwrap();
        settle().await;

        assert_eq!(*events.lock().unwrap(), vec![5000]);
    }

    #[tokio::test]
    async fn test_callback_ordering() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let events = Arc::clone(&cb.events);
        let _handle = subscribe_trades(&registry, SubscriptionConfig::default(), cb);

        let senders = registry.get_senders_for_source(0);
        for i in 0..10i64 {
            senders[0].send(insert_event(i, i as u64)).unwrap();
        }
        settle().await;

        let received = events.lock().unwrap();
        assert_eq!(received.len(), 10);
        assert_eq!(*received, (0..10).collect::<Vec<i64>>());
    }

    #[tokio::test]
    async fn test_callback_handle_metrics() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let handle =
            subscribe_trades(&registry, SubscriptionConfig::default(), TestCallback::new());

        let m = handle.metrics().unwrap();
        assert_eq!(m.id, handle.id());
        assert_eq!(m.source_name, "trades");
        assert_eq!(m.state, SubscriptionState::Active);
    }
}