Skip to main content

laminar_core/streaming/
subscription.rs

1//! Subscription — receive records from a Sink.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use arrow::array::RecordBatch;
7use arrow::datatypes::SchemaRef;
8use tokio::sync::broadcast;
9
10use super::error::RecvError;
11use super::source::{Record, SourceMessage};
12
/// A subscription to a streaming sink. Each subscriber independently receives
/// every message via broadcast.
///
/// NOTE: the receive methods on this type skip broadcast `Lagged` errors
/// without reporting them, so a subscriber that falls behind may (per tokio's
/// broadcast semantics) miss messages without being told.
pub struct Subscription<T: Record> {
    /// Receiving half of the broadcast channel carrying the messages.
    rx: broadcast::Receiver<SourceMessage<T>>,
    /// Arrow schema handed out by `schema()`.
    schema: SchemaRef,
    /// Set to `true` once a receive call has observed the channel as closed.
    closed: bool,
}
20
21impl<T: Record> Subscription<T> {
22    pub(crate) fn new(rx: broadcast::Receiver<SourceMessage<T>>, schema: SchemaRef) -> Self {
23        Self {
24            rx,
25            schema,
26            closed: false,
27        }
28    }
29
30    /// Non-blocking poll. Returns the next batch.
31    /// Returns `None` on empty or closed channel. Check `is_disconnected()`
32    /// to distinguish.
33    pub fn poll(&mut self) -> Option<RecordBatch> {
34        loop {
35            match self.rx.try_recv() {
36                Ok(msg) => return Some(to_batch(msg)),
37                Err(broadcast::error::TryRecvError::Empty) => return None,
38                Err(broadcast::error::TryRecvError::Closed) => {
39                    self.closed = true;
40                    return None;
41                }
42                Err(broadcast::error::TryRecvError::Lagged(_)) => {}
43            }
44        }
45    }
46
47    /// Async receive. Awaits the next batch.
48    ///
49    /// # Errors
50    ///
51    /// Returns `RecvError::Disconnected` if the source has been dropped.
52    pub async fn recv_async(&mut self) -> Result<RecordBatch, RecvError> {
53        loop {
54            match self.rx.recv().await {
55                Ok(msg) => return Ok(to_batch(msg)),
56                Err(broadcast::error::RecvError::Closed) => {
57                    self.closed = true;
58                    return Err(RecvError::Disconnected);
59                }
60                Err(broadcast::error::RecvError::Lagged(_)) => {}
61            }
62        }
63    }
64
65    /// Blocking receive. Uses tokio's waker-based `blocking_recv`.
66    ///
67    /// # Errors
68    ///
69    /// Returns `RecvError::Disconnected` if the source has been dropped.
70    pub fn recv(&mut self) -> Result<RecordBatch, RecvError> {
71        loop {
72            match self.rx.blocking_recv() {
73                Ok(msg) => return Ok(to_batch(msg)),
74                Err(broadcast::error::RecvError::Closed) => {
75                    self.closed = true;
76                    return Err(RecvError::Disconnected);
77                }
78                Err(broadcast::error::RecvError::Lagged(_)) => {}
79            }
80        }
81    }
82
83    /// Blocking receive with timeout. Requires a tokio runtime in the current
84    /// thread context.
85    ///
86    /// # Errors
87    ///
88    /// Returns `RecvError::Timeout` or `RecvError::Disconnected`.
89    pub fn recv_timeout(&mut self, timeout: Duration) -> Result<RecordBatch, RecvError> {
90        let handle = tokio::runtime::Handle::current();
91        match handle.block_on(tokio::time::timeout(timeout, self.recv_async())) {
92            Ok(Ok(batch)) => Ok(batch),
93            Ok(Err(e)) => Err(e),
94            Err(_) => Err(RecvError::Timeout),
95        }
96    }
97
98    /// Returns true if the channel has been observed closed.
99    #[must_use]
100    pub fn is_disconnected(&self) -> bool {
101        self.closed
102    }
103
104    /// Returns the schema for records in this subscription.
105    #[must_use]
106    pub fn schema(&self) -> SchemaRef {
107        Arc::clone(&self.schema)
108    }
109}
110
111fn to_batch<T: Record>(msg: SourceMessage<T>) -> RecordBatch {
112    match msg {
113        SourceMessage::Record(r) => r.to_record_batch(),
114        SourceMessage::Batch(b) => b,
115    }
116}
117
118impl<T: Record + std::fmt::Debug> std::fmt::Debug for Subscription<T> {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        f.debug_struct("Subscription")
121            .field("closed", &self.closed)
122            .field("schema", &self.schema)
123            .finish_non_exhaustive()
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::source::create;
    use arrow::array::{Float64Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Minimal two-column record used to drive the subscription tests.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            // One row per event: a single-element array for each column.
            let built = RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                ],
            );
            built.unwrap()
        }
    }

    /// Polling a subscription with nothing pushed yields `None`.
    #[tokio::test]
    async fn test_poll_empty() {
        let (_producer, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        let polled = subscription.poll();
        assert!(polled.is_none());
    }

    /// A single subscriber receives a pushed event as a one-row batch.
    #[tokio::test]
    async fn test_single_subscriber_async() {
        let (producer, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        producer.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        let received = subscription.recv_async().await.unwrap();
        assert_eq!(received.num_rows(), 1);
    }

    /// Two subscribers on the same sink each receive the pushed event.
    #[tokio::test]
    async fn test_multiple_subscribers_all_receive() {
        let (producer, sink) = create::<TestEvent>(16);
        let mut first = sink.subscribe();
        let mut second = sink.subscribe();

        producer.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        let from_first = first.recv_async().await.unwrap();
        let from_second = second.recv_async().await.unwrap();
        assert_eq!(from_first.num_rows(), 1);
        assert_eq!(from_second.num_rows(), 1);
    }

    /// With both source and sink gone, receiving fails and the subscription
    /// reports itself as disconnected.
    #[tokio::test]
    async fn test_disconnected_after_source_and_sink_drop() {
        let (producer, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        drop(producer);
        drop(sink);
        // Drain task exits on source disconnect; once Sink is dropped too,
        // the broadcast closes and recv_async returns Disconnected.
        let settle = Duration::from_millis(50);
        tokio::time::sleep(settle).await;

        let outcome = subscription.recv_async().await;
        assert!(outcome.is_err());
        assert!(subscription.is_disconnected());
    }

    /// The subscription exposes the two-field schema of `TestEvent`.
    #[tokio::test]
    async fn test_schema() {
        let (_producer, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        let field_count = subscription.schema().fields().len();
        assert_eq!(field_count, 2);
    }
}