// laminar_core/subscription/callback.rs
use std::any::Any;
use std::panic::AssertUnwindSafe;
use std::sync::Arc;

use tokio::sync::broadcast;

use crate::subscription::event::ChangeEvent;
use crate::subscription::handle::PushSubscriptionError;
use crate::subscription::registry::{
    SubscriptionConfig, SubscriptionId, SubscriptionMetrics, SubscriptionRegistry,
};

39pub trait SubscriptionCallback: Send + Sync + 'static {
64 fn on_change(&self, event: ChangeEvent);
66
67 fn on_error(&self, error: PushSubscriptionError) {
71 tracing::warn!("subscription callback error: {}", error);
72 }
73
74 fn on_complete(&self) {}
78}
79
80struct FnCallback<F>(F);
86
87impl<F: Fn(ChangeEvent) + Send + Sync + 'static> SubscriptionCallback for FnCallback<F> {
88 fn on_change(&self, event: ChangeEvent) {
89 (self.0)(event);
90 }
91}
92
93pub struct CallbackSubscriptionHandle {
107 id: SubscriptionId,
109 registry: Arc<SubscriptionRegistry>,
111 task: Option<tokio::task::JoinHandle<()>>,
113 cancelled: bool,
115}
116
117impl CallbackSubscriptionHandle {
118 #[must_use]
124 pub fn pause(&self) -> bool {
125 self.registry.pause(self.id)
126 }
127
128 #[must_use]
132 pub fn resume(&self) -> bool {
133 self.registry.resume(self.id)
134 }
135
136 pub fn cancel(&mut self) {
141 if !self.cancelled {
142 self.cancelled = true;
143 self.registry.cancel(self.id);
144 if let Some(task) = self.task.take() {
145 task.abort();
146 }
147 }
148 }
149
150 #[must_use]
152 pub fn id(&self) -> SubscriptionId {
153 self.id
154 }
155
156 #[must_use]
158 pub fn metrics(&self) -> Option<SubscriptionMetrics> {
159 self.registry.metrics(self.id)
160 }
161
162 #[must_use]
164 pub fn is_cancelled(&self) -> bool {
165 self.cancelled
166 }
167}
168
169impl Drop for CallbackSubscriptionHandle {
170 fn drop(&mut self) {
171 if !self.cancelled {
172 self.registry.cancel(self.id);
173 if let Some(task) = self.task.take() {
174 task.abort();
175 }
176 }
177 }
178}
179
180pub fn subscribe_callback<C: SubscriptionCallback>(
201 registry: Arc<SubscriptionRegistry>,
202 source_name: String,
203 source_id: u32,
204 config: SubscriptionConfig,
205 callback: C,
206) -> CallbackSubscriptionHandle {
207 let (id, receiver) = registry.create(source_name, source_id, config);
208 let callback = Arc::new(callback);
209
210 let task = tokio::spawn(callback_runner(receiver, callback));
211
212 CallbackSubscriptionHandle {
213 id,
214 registry,
215 task: Some(task),
216 cancelled: false,
217 }
218}
219
220pub fn subscribe_fn<F>(
233 registry: Arc<SubscriptionRegistry>,
234 source_name: String,
235 source_id: u32,
236 config: SubscriptionConfig,
237 f: F,
238) -> CallbackSubscriptionHandle
239where
240 F: Fn(ChangeEvent) + Send + Sync + 'static,
241{
242 subscribe_callback(registry, source_name, source_id, config, FnCallback(f))
243}
244
245async fn callback_runner<C: SubscriptionCallback>(
252 mut receiver: broadcast::Receiver<ChangeEvent>,
253 callback: Arc<C>,
254) {
255 loop {
256 match receiver.recv().await {
257 Ok(event) => {
258 let cb = Arc::clone(&callback);
259 let result = std::panic::catch_unwind(AssertUnwindSafe(|| cb.on_change(event)));
260 if let Err(panic) = result {
261 let msg = if let Some(s) = panic.downcast_ref::<&str>() {
262 format!("callback panicked: {s}")
263 } else if let Some(s) = panic.downcast_ref::<String>() {
264 format!("callback panicked: {s}")
265 } else {
266 "callback panicked".to_string()
267 };
268 callback.on_error(PushSubscriptionError::Internal(msg));
269 }
270 }
271 Err(broadcast::error::RecvError::Lagged(n)) => {
272 callback.on_error(PushSubscriptionError::Lagged(n));
273 }
275 Err(broadcast::error::RecvError::Closed) => {
276 callback.on_complete();
277 break;
278 }
279 }
280 }
281}
282
#[cfg(test)]
// Timestamps and sequence numbers are built from small, non-negative loop indices,
// so the `i64 -> u64` / `usize -> i64` casts below can neither lose sign nor wrap.
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_possible_wrap)]
// Tweaking a single field of `SubscriptionConfig::default()` reads clearer here
// than struct-update syntax.
#[allow(clippy::field_reassign_with_default)]
// NOTE(review): presumably `std::sync::Mutex` is disallowed crate-wide via clippy
// config; these tests never hold a guard across an `.await`, so it is harmless here.
#[allow(clippy::disallowed_types)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};

    use crate::subscription::registry::SubscriptionState;

    /// Builds a single-column (`v: Int64`) batch holding the values `0..n`.
    fn make_batch(n: usize) -> arrow_array::RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
        let values: Vec<i64> = (0..n as i64).collect();
        let array = Int64Array::from(values);
        arrow_array::RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
    }

    /// Callback that records every event timestamp, error message and completion
    /// signal into shared state the test can inspect afterwards.
    #[derive(Clone)]
    struct TestCallback {
        events: Arc<Mutex<Vec<i64>>>,
        errors: Arc<Mutex<Vec<String>>>,
        completed: Arc<Mutex<bool>>,
    }

    impl TestCallback {
        fn new() -> Self {
            Self {
                events: Arc::new(Mutex::new(Vec::new())),
                errors: Arc::new(Mutex::new(Vec::new())),
                completed: Arc::new(Mutex::new(false)),
            }
        }
    }

    impl SubscriptionCallback for TestCallback {
        fn on_change(&self, event: ChangeEvent) {
            self.events.lock().unwrap().push(event.timestamp());
        }

        fn on_error(&self, error: PushSubscriptionError) {
            self.errors.lock().unwrap().push(format!("{error}"));
        }

        fn on_complete(&self) {
            *self.completed.lock().unwrap() = true;
        }
    }

    #[tokio::test]
    async fn test_callback_receives_events() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let events = Arc::clone(&cb.events);

        let _handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        let senders = registry.get_senders_for_source(0);
        assert_eq!(senders.len(), 1);

        for i in 0..5i64 {
            let batch = Arc::new(make_batch(1));
            senders[0]
                .send(ChangeEvent::insert(batch, i * 1000, i as u64))
                .unwrap();
        }

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let received = events.lock().unwrap();
        assert_eq!(received.len(), 5);
        assert_eq!(*received, vec![0, 1000, 2000, 3000, 4000]);
    }

    #[tokio::test]
    async fn test_callback_on_error_lagged() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let mut cfg = SubscriptionConfig::default();
        cfg.buffer_size = 4;
        let cb = TestCallback::new();
        let errors = Arc::clone(&cb.errors);
        let events = Arc::clone(&cb.events);

        let _handle = subscribe_callback(Arc::clone(&registry), "trades".into(), 0, cfg, cb);

        let senders = registry.get_senders_for_source(0);

        for i in 0..20i64 {
            // No `.await` between sends: on the test's single-threaded runtime the
            // runner cannot drain the small buffer, so the receiver must lag.
            let batch = Arc::new(make_batch(1));
            let _ = senders[0].send(ChangeEvent::insert(batch, i * 100, i as u64));
        }

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let errs = errors.lock().unwrap();
        assert!(!errs.is_empty(), "expected at least one lag error");
        assert!(errs[0].contains("lagged behind"));

        let evts = events.lock().unwrap();
        // After the lag is reported, delivery resumes with the events still buffered.
        assert!(!evts.is_empty(), "should receive events after lag");
    }

    #[tokio::test]
    async fn test_callback_on_complete() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let completed = Arc::clone(&cb.completed);

        let handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        // Cancel through the registry, not the handle: the runner task is not aborted,
        // so it gets to observe the closed channel and call `on_complete`.
        registry.cancel(handle.id());

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        assert!(*completed.lock().unwrap());
    }

    #[tokio::test]
    async fn test_callback_panic_caught() {
        struct PanickingCallback {
            errors: Arc<Mutex<Vec<String>>>,
        }

        impl SubscriptionCallback for PanickingCallback {
            fn on_change(&self, _event: ChangeEvent) {
                panic!("deliberate test panic");
            }

            fn on_error(&self, error: PushSubscriptionError) {
                self.errors.lock().unwrap().push(format!("{error}"));
            }
        }

        let registry = Arc::new(SubscriptionRegistry::new());
        let errors: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));

        let _handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            PanickingCallback {
                errors: Arc::clone(&errors),
            },
        );

        let senders = registry.get_senders_for_source(0);
        let batch = Arc::new(make_batch(1));
        senders[0]
            .send(ChangeEvent::insert(batch, 1000, 1))
            .unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let errs = errors.lock().unwrap();
        assert_eq!(errs.len(), 1);
        assert!(errs[0].contains("callback panicked"));
        assert!(errs[0].contains("deliberate test panic"));
    }

    #[tokio::test]
    async fn test_callback_handle_pause_resume() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        let handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        assert!(handle.pause());
        assert_eq!(registry.state(handle.id()), Some(SubscriptionState::Paused));

        // Pausing an already-paused subscription reports no change.
        assert!(!handle.pause());

        assert!(handle.resume());
        assert_eq!(registry.state(handle.id()), Some(SubscriptionState::Active));

        // Resuming an already-active subscription reports no change.
        assert!(!handle.resume());
    }

    #[tokio::test]
    async fn test_callback_handle_cancel() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        let mut handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        assert_eq!(registry.subscription_count(), 1);
        assert!(!handle.is_cancelled());

        handle.cancel();

        assert!(handle.is_cancelled());
        assert_eq!(registry.subscription_count(), 0);

        // A second cancel is a no-op.
        handle.cancel();
        assert_eq!(registry.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_callback_handle_drop_cancels() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        {
            let _handle = subscribe_callback(
                Arc::clone(&registry),
                "trades".into(),
                0,
                SubscriptionConfig::default(),
                cb,
            );
            assert_eq!(registry.subscription_count(), 1);
        }
        // Leaving the scope drops the handle, which must unregister the subscription.
        assert_eq!(registry.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_subscribe_fn() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let events: Arc<Mutex<Vec<i64>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);

        let _handle = subscribe_fn(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            move |event| {
                events_clone.lock().unwrap().push(event.timestamp());
            },
        );

        let senders = registry.get_senders_for_source(0);
        let batch = Arc::new(make_batch(1));
        senders[0]
            .send(ChangeEvent::insert(batch, 5000, 1))
            .unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let received = events.lock().unwrap();
        assert_eq!(*received, vec![5000]);
    }

    #[tokio::test]
    async fn test_callback_ordering() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();
        let events = Arc::clone(&cb.events);

        let _handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        let senders = registry.get_senders_for_source(0);

        for i in 0..10i64 {
            let batch = Arc::new(make_batch(1));
            senders[0]
                .send(ChangeEvent::insert(batch, i, i as u64))
                .unwrap();
        }

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let received = events.lock().unwrap();
        assert_eq!(received.len(), 10);
        let expected: Vec<i64> = (0..10).collect();
        assert_eq!(*received, expected);
    }

    #[tokio::test]
    async fn test_callback_handle_metrics() {
        let registry = Arc::new(SubscriptionRegistry::new());
        let cb = TestCallback::new();

        let handle = subscribe_callback(
            Arc::clone(&registry),
            "trades".into(),
            0,
            SubscriptionConfig::default(),
            cb,
        );

        let m = handle.metrics().unwrap();
        assert_eq!(m.id, handle.id());
        assert_eq!(m.source_name, "trades");
        assert_eq!(m.state, SubscriptionState::Active);
    }
}