Skip to main content

laminar_core/streaming/
subscription.rs

//! Streaming Subscription API.
//!
//! A Subscription provides access to records from a Sink. It supports:
//!
//! - Non-blocking poll
//! - Blocking receive with optional timeout
//! - Iterator interface
//! - Zero-allocation batch operations
//!
//! ## Usage
//!
//! ```rust,ignore
//! let subscription = sink.subscribe();
//!
//! // Non-blocking poll
//! while let Some(batch) = subscription.poll() {
//!     process(batch);
//! }
//!
//! // Blocking receive
//! let batch = subscription.recv()?;
//!
//! // With timeout
//! let batch = subscription.recv_timeout(Duration::from_secs(1))?;
//!
//! // As iterator
//! for batch in subscription {
//!     process(batch);
//! }
//! ```
31
32use std::sync::Arc;
33use std::time::{Duration, Instant};
34
35use arrow::array::RecordBatch;
36use arrow::datatypes::SchemaRef;
37
38use super::channel::Consumer;
39use super::error::RecvError;
40use super::sink::SinkInner;
41use super::source::{Record, SourceMessage};
42
43/// A subscription to a streaming sink.
44///
45/// Subscriptions receive records from a Sink and provide them via
46/// polling, blocking receive, or iterator interfaces.
47///
48/// ## Modes
49///
50/// - **Direct**: First subscriber, reads directly from source channel
51/// - **Broadcast**: Additional subscribers, reads from dedicated channel
52pub struct Subscription<T: Record> {
53    inner: SubscriptionInner<T>,
54    schema: SchemaRef,
55}
56
57enum SubscriptionInner<T: Record> {
58    /// Direct subscription to sink's consumer.
59    Direct(Arc<SinkInner<T>>),
60    /// Broadcast subscription with dedicated channel.
61    Broadcast(Consumer<SourceMessage<T>>),
62}
63
64impl<T: Record> Subscription<T> {
65    /// Creates a direct subscription (first subscriber).
66    pub(crate) fn new_direct(sink_inner: Arc<SinkInner<T>>) -> Self {
67        let schema = sink_inner.schema();
68        Self {
69            inner: SubscriptionInner::Direct(sink_inner),
70            schema,
71        }
72    }
73
74    /// Creates a broadcast subscription (additional subscribers).
75    pub(crate) fn new_broadcast(consumer: Consumer<SourceMessage<T>>, schema: SchemaRef) -> Self {
76        Self {
77            inner: SubscriptionInner::Broadcast(consumer),
78            schema,
79        }
80    }
81
82    /// Polls for the next record batch without blocking.
83    ///
84    /// Returns `Some(RecordBatch)` if data is available, `None` if empty.
85    ///
86    /// Records are automatically converted to Arrow `RecordBatch` format.
87    #[must_use]
88    pub fn poll(&self) -> Option<RecordBatch> {
89        let msg = match &self.inner {
90            SubscriptionInner::Direct(sink) => sink.consumer().poll(),
91            SubscriptionInner::Broadcast(consumer) => consumer.poll(),
92        }?;
93
94        Self::message_to_batch(msg)
95    }
96
97    /// Polls for raw messages (without conversion to `RecordBatch`).
98    ///
99    /// This is useful when you need to handle watermarks separately.
100    #[must_use]
101    pub fn poll_message(&self) -> Option<SubscriptionMessage<T>> {
102        let msg = match &self.inner {
103            SubscriptionInner::Direct(sink) => sink.consumer().poll(),
104            SubscriptionInner::Broadcast(consumer) => consumer.poll(),
105        }?;
106
107        Some(Self::convert_message(msg))
108    }
109
110    /// Receives the next record batch, blocking until available.
111    ///
112    /// # Errors
113    ///
114    /// Returns `RecvError::Disconnected` if the source has been dropped
115    /// and there are no more buffered records.
116    pub fn recv(&self) -> Result<RecordBatch, RecvError> {
117        let mut spins = 0u32;
118        loop {
119            if let Some(batch) = self.poll() {
120                return Ok(batch);
121            }
122
123            if self.is_disconnected() {
124                return Err(RecvError::Disconnected);
125            }
126
127            // Progressive backoff: spin → yield → park
128            if spins < 64 {
129                std::hint::spin_loop();
130            } else if spins < 128 {
131                std::thread::yield_now();
132            } else {
133                std::thread::park_timeout(Duration::from_micros(100));
134            }
135            spins = spins.saturating_add(1);
136        }
137    }
138
139    /// Receives the next record batch with a timeout.
140    ///
141    /// # Errors
142    ///
143    /// Returns `RecvError::Timeout` if no record becomes available within the timeout.
144    /// Returns `RecvError::Disconnected` if the source has been dropped.
145    pub fn recv_timeout(&self, timeout: Duration) -> Result<RecordBatch, RecvError> {
146        let deadline = Instant::now() + timeout;
147        let mut spins = 0u32;
148
149        loop {
150            if let Some(batch) = self.poll() {
151                return Ok(batch);
152            }
153
154            if self.is_disconnected() {
155                return Err(RecvError::Disconnected);
156            }
157
158            if Instant::now() >= deadline {
159                return Err(RecvError::Timeout);
160            }
161
162            // Progressive backoff: spin → yield → park
163            if spins < 64 {
164                std::hint::spin_loop();
165            } else if spins < 128 {
166                std::thread::yield_now();
167            } else {
168                let remaining = deadline.saturating_duration_since(Instant::now());
169                std::thread::park_timeout(remaining.min(Duration::from_micros(100)));
170            }
171            spins = spins.saturating_add(1);
172        }
173    }
174
175    /// Polls multiple record batches into a vector.
176    ///
177    /// Returns up to `max_count` batches.
178    ///
179    /// # Performance Warning
180    ///
181    /// **This method allocates a `Vec` on every call.** Do not use on hot paths
182    /// where allocation overhead matters. For zero-allocation consumption, use
183    /// [`poll_each`](Self::poll_each) or [`poll_batch_into`](Self::poll_batch_into).
184    #[cold]
185    #[must_use]
186    pub fn poll_batch(&self, max_count: usize) -> Vec<RecordBatch> {
187        let mut batches = Vec::with_capacity(max_count);
188
189        for _ in 0..max_count {
190            if let Some(batch) = self.poll() {
191                batches.push(batch);
192            } else {
193                break;
194            }
195        }
196
197        batches
198    }
199
200    /// Polls multiple record batches into a pre-allocated vector (zero-allocation).
201    ///
202    /// Appends up to `max_count` batches to the provided vector.
203    /// Returns the number of batches added.
204    ///
205    /// # Example
206    ///
207    /// ```rust,ignore
208    /// let mut buffer = Vec::with_capacity(100);
209    /// loop {
210    ///     buffer.clear();
211    ///     let count = subscription.poll_batch_into(&mut buffer, 100);
212    ///     if count == 0 { break; }
213    ///     for batch in &buffer {
214    ///         process(batch);
215    ///     }
216    /// }
217    /// ```
218    pub fn poll_batch_into(&self, buffer: &mut Vec<RecordBatch>, max_count: usize) -> usize {
219        let mut count = 0;
220
221        for _ in 0..max_count {
222            if let Some(batch) = self.poll() {
223                buffer.push(batch);
224                count += 1;
225            } else {
226                break;
227            }
228        }
229
230        count
231    }
232
233    /// Processes records with a callback (zero-allocation).
234    ///
235    /// The callback receives each `RecordBatch`. Processing stops when:
236    /// - `max_count` batches have been processed
237    /// - No more batches are available
238    /// - The callback returns `false`
239    ///
240    /// Returns the number of batches processed.
241    pub fn poll_each<F>(&self, max_count: usize, mut f: F) -> usize
242    where
243        F: FnMut(RecordBatch) -> bool,
244    {
245        let mut count = 0;
246
247        for _ in 0..max_count {
248            if let Some(batch) = self.poll() {
249                count += 1;
250                if !f(batch) {
251                    break;
252                }
253            } else {
254                break;
255            }
256        }
257
258        count
259    }
260
261    /// Returns true if the source has been dropped and buffer is empty.
262    #[must_use]
263    pub fn is_disconnected(&self) -> bool {
264        match &self.inner {
265            SubscriptionInner::Direct(sink) => sink.is_disconnected(),
266            SubscriptionInner::Broadcast(consumer) => consumer.is_disconnected(),
267        }
268    }
269
270    /// Returns the number of pending items.
271    #[must_use]
272    pub fn pending(&self) -> usize {
273        match &self.inner {
274            SubscriptionInner::Direct(sink) => sink.consumer().len(),
275            SubscriptionInner::Broadcast(consumer) => consumer.len(),
276        }
277    }
278
279    /// Returns the schema for records in this subscription.
280    #[must_use]
281    pub fn schema(&self) -> SchemaRef {
282        Arc::clone(&self.schema)
283    }
284
285    fn message_to_batch(msg: SourceMessage<T>) -> Option<RecordBatch> {
286        match msg {
287            SourceMessage::Record(record) => Some(record.to_record_batch()),
288            SourceMessage::Batch(batch) => Some(batch),
289            SourceMessage::Watermark(_) => {
290                // Skip watermarks in poll(), they're handled separately
291                None
292            }
293        }
294    }
295
296    fn convert_message(msg: SourceMessage<T>) -> SubscriptionMessage<T> {
297        match msg {
298            SourceMessage::Record(record) => SubscriptionMessage::Record(record),
299            SourceMessage::Batch(batch) => SubscriptionMessage::Batch(batch),
300            SourceMessage::Watermark(ts) => SubscriptionMessage::Watermark(ts),
301        }
302    }
303}
304
305/// Message types that can be received from a subscription.
306#[derive(Debug)]
307pub enum SubscriptionMessage<T> {
308    /// A single record.
309    Record(T),
310    /// A batch of records.
311    Batch(RecordBatch),
312    /// A watermark timestamp.
313    Watermark(i64),
314}
315
316impl<T: Record> SubscriptionMessage<T> {
317    /// Returns true if this is a record message.
318    #[must_use]
319    pub fn is_record(&self) -> bool {
320        matches!(self, Self::Record(_))
321    }
322
323    /// Returns true if this is a batch message.
324    #[must_use]
325    pub fn is_batch(&self) -> bool {
326        matches!(self, Self::Batch(_))
327    }
328
329    /// Returns true if this is a watermark message.
330    #[must_use]
331    pub fn is_watermark(&self) -> bool {
332        matches!(self, Self::Watermark(_))
333    }
334
335    /// Converts to a `RecordBatch` if this is a data message.
336    #[must_use]
337    pub fn to_batch(self) -> Option<RecordBatch> {
338        match self {
339            Self::Record(r) => Some(r.to_record_batch()),
340            Self::Batch(b) => Some(b),
341            Self::Watermark(_) => None,
342        }
343    }
344
345    /// Returns the watermark timestamp if this is a watermark message.
346    #[must_use]
347    pub fn watermark(&self) -> Option<i64> {
348        match self {
349            Self::Watermark(ts) => Some(*ts),
350            _ => None,
351        }
352    }
353}
354
355/// Iterator implementation for Subscription.
356///
357/// Iterates over record batches, blocking on each call to `next()`.
358/// Iteration stops when the source is disconnected.
359impl<T: Record> Iterator for Subscription<T> {
360    type Item = RecordBatch;
361
362    fn next(&mut self) -> Option<Self::Item> {
363        self.recv().ok()
364    }
365}
366
367impl<T: Record + std::fmt::Debug> std::fmt::Debug for Subscription<T> {
368    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369        let mode = match &self.inner {
370            SubscriptionInner::Direct(_) => "Direct",
371            SubscriptionInner::Broadcast(_) => "Broadcast",
372        };
373
374        f.debug_struct("Subscription")
375            .field("mode", &mode)
376            .field("pending", &self.pending())
377            .field("is_disconnected", &self.is_disconnected())
378            .field("schema", &self.schema)
379            .finish()
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::source::create;
    use arrow::array::{Float64Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Minimal two-column record used to drive the subscription under test.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            let ids = Int64Array::from(vec![self.id]);
            let values = Float64Array::from(vec![self.value]);
            RecordBatch::try_new(Self::schema(), vec![Arc::new(ids), Arc::new(values)]).unwrap()
        }
    }

    /// Builds the event `{ id: n, value: n.0 }`.
    fn event(n: i32) -> TestEvent {
        TestEvent {
            id: i64::from(n),
            value: f64::from(n),
        }
    }

    #[test]
    fn test_poll_empty() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        assert!(sub.poll().is_none());
    }

    #[test]
    fn test_poll_records() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        for n in 1..=2 {
            source.push(event(n)).unwrap();
        }

        // Each pushed record comes back as its own single-row batch.
        for _ in 0..2 {
            let batch = sub.poll().unwrap();
            assert_eq!(batch.num_rows(), 1);
        }

        assert!(sub.poll().is_none());
    }

    #[test]
    fn test_poll_message() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        source.push(event(1)).unwrap();

        let message = sub.poll_message().unwrap();
        assert!(message.is_record());
    }

    #[test]
    fn test_recv_timeout() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        // Nothing was pushed, so the receive must give up with a timeout.
        let outcome = sub.recv_timeout(Duration::from_millis(10));
        assert!(matches!(outcome, Err(RecvError::Timeout)));
    }

    #[test]
    fn test_recv_timeout_success() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        source.push(event(1)).unwrap();

        let outcome = sub.recv_timeout(Duration::from_secs(1));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_poll_batch() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        for n in 1..=3 {
            source.push(event(n)).unwrap();
        }

        let batches = sub.poll_batch(10);
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_poll_each() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        for n in 1..=2 {
            source.push(event(n)).unwrap();
        }

        let mut total_rows = 0;
        let count = sub.poll_each(10, |batch| {
            total_rows += batch.num_rows();
            true
        });

        assert_eq!(count, 2);
        assert_eq!(total_rows, 2);
    }

    #[test]
    fn test_poll_each_early_stop() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        for n in 1..=3 {
            source.push(event(n)).unwrap();
        }

        let mut seen = 0;
        let count = sub.poll_each(10, |_| {
            seen += 1;
            // Returning false on the second batch stops the iteration.
            seen < 2
        });

        assert_eq!(count, 2);
        assert_eq!(seen, 2);
        // The third event was never polled.
        assert_eq!(sub.pending(), 1);
    }

    #[test]
    fn test_disconnected() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        assert!(!sub.is_disconnected());

        drop(source);

        assert!(sub.is_disconnected());
    }

    #[test]
    fn test_pending() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        assert_eq!(sub.pending(), 0);

        for n in 1..=2 {
            source.push(event(n)).unwrap();
        }

        assert_eq!(sub.pending(), 2);
    }

    #[test]
    fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        let field_count = sub.schema().fields().len();
        assert_eq!(field_count, 2);
    }

    #[test]
    fn test_subscription_message() {
        let record_msg = SubscriptionMessage::Record(event(1));
        assert!(record_msg.is_record());
        assert!(!record_msg.is_batch());
        assert!(!record_msg.is_watermark());

        let batch = record_msg.to_batch().unwrap();
        assert_eq!(batch.num_rows(), 1);

        let watermark_msg = SubscriptionMessage::<TestEvent>::Watermark(1000);
        assert!(watermark_msg.is_watermark());
        assert_eq!(watermark_msg.watermark(), Some(1000));
    }

    #[test]
    fn test_iterator() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();

        for n in 1..=2 {
            source.push(event(n)).unwrap();
        }

        // Dropping the source lets the blocking iterator terminate.
        drop(source);

        let batches = sub.by_ref().collect::<Vec<_>>();
        assert_eq!(batches.len(), 2);
    }

    #[test]
    fn test_debug_format() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        let rendered = format!("{sub:?}");
        assert!(rendered.contains("Subscription"));
        assert!(rendered.contains("Direct"));
    }
}
597}