Skip to main content

laminar_core/streaming/
subscription.rs

1//! Subscription — receive records from a Sink.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use arrow::array::RecordBatch;
7use arrow::datatypes::SchemaRef;
8use tokio::sync::broadcast;
9
10use super::error::RecvError;
11use super::source::{Record, SourceMessage};
12
13/// A subscription to a streaming sink. Each subscriber independently receives
14/// every message via broadcast.
15pub struct Subscription<T: Record> {
16    rx: broadcast::Receiver<SourceMessage<T>>,
17    schema: SchemaRef,
18    closed: bool,
19}
20
21impl<T: Record> Subscription<T> {
22    pub(crate) fn new(rx: broadcast::Receiver<SourceMessage<T>>, schema: SchemaRef) -> Self {
23        Self {
24            rx,
25            schema,
26            closed: false,
27        }
28    }
29
30    /// Non-blocking poll. Returns the next batch, skipping watermarks.
31    /// Returns `None` on empty or closed channel. Check `is_disconnected()`
32    /// to distinguish.
33    pub fn poll(&mut self) -> Option<RecordBatch> {
34        loop {
35            match self.rx.try_recv() {
36                Ok(msg) => {
37                    if let Some(batch) = message_to_batch(msg) {
38                        return Some(batch);
39                    }
40                }
41                Err(broadcast::error::TryRecvError::Empty) => return None,
42                Err(broadcast::error::TryRecvError::Closed) => {
43                    self.closed = true;
44                    return None;
45                }
46                Err(broadcast::error::TryRecvError::Lagged(_)) => {}
47            }
48        }
49    }
50
51    /// Async receive. Awaits the next batch, skipping watermarks.
52    ///
53    /// # Errors
54    ///
55    /// Returns `RecvError::Disconnected` if the source has been dropped.
56    pub async fn recv_async(&mut self) -> Result<RecordBatch, RecvError> {
57        loop {
58            match self.rx.recv().await {
59                Ok(msg) => {
60                    if let Some(batch) = message_to_batch(msg) {
61                        return Ok(batch);
62                    }
63                }
64                Err(broadcast::error::RecvError::Closed) => {
65                    self.closed = true;
66                    return Err(RecvError::Disconnected);
67                }
68                Err(broadcast::error::RecvError::Lagged(_)) => {}
69            }
70        }
71    }
72
73    /// Blocking receive. Uses tokio's waker-based `blocking_recv`.
74    ///
75    /// # Errors
76    ///
77    /// Returns `RecvError::Disconnected` if the source has been dropped.
78    pub fn recv(&mut self) -> Result<RecordBatch, RecvError> {
79        loop {
80            match self.rx.blocking_recv() {
81                Ok(msg) => {
82                    if let Some(batch) = message_to_batch(msg) {
83                        return Ok(batch);
84                    }
85                }
86                Err(broadcast::error::RecvError::Closed) => {
87                    self.closed = true;
88                    return Err(RecvError::Disconnected);
89                }
90                Err(broadcast::error::RecvError::Lagged(_)) => {}
91            }
92        }
93    }
94
95    /// Blocking receive with timeout. Requires a tokio runtime in the current
96    /// thread context.
97    ///
98    /// # Errors
99    ///
100    /// Returns `RecvError::Timeout` or `RecvError::Disconnected`.
101    pub fn recv_timeout(&mut self, timeout: Duration) -> Result<RecordBatch, RecvError> {
102        let handle = tokio::runtime::Handle::current();
103        match handle.block_on(tokio::time::timeout(timeout, self.recv_async())) {
104            Ok(Ok(batch)) => Ok(batch),
105            Ok(Err(e)) => Err(e),
106            Err(_) => Err(RecvError::Timeout),
107        }
108    }
109
110    /// Returns true if the channel has been observed closed.
111    #[must_use]
112    pub fn is_disconnected(&self) -> bool {
113        self.closed
114    }
115
116    /// Returns the schema for records in this subscription.
117    #[must_use]
118    pub fn schema(&self) -> SchemaRef {
119        Arc::clone(&self.schema)
120    }
121}
122
123fn message_to_batch<T: Record>(msg: SourceMessage<T>) -> Option<RecordBatch> {
124    match msg {
125        SourceMessage::Record(r) => Some(r.to_record_batch()),
126        SourceMessage::Batch(b) => Some(b),
127        SourceMessage::Watermark(_) => None,
128    }
129}
130
131impl<T: Record + std::fmt::Debug> std::fmt::Debug for Subscription<T> {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        f.debug_struct("Subscription")
134            .field("closed", &self.closed)
135            .field("schema", &self.schema)
136            .finish_non_exhaustive()
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::source::create;
    use arrow::array::{Float64Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Upper bound on how long a test waits for the broadcast channel to
    /// close, so a regression fails the test instead of hanging the run.
    const CLOSE_TIMEOUT: Duration = Duration::from_secs(5);

    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            Arc::new(Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ]))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                ],
            )
            .unwrap()
        }
    }

    #[tokio::test]
    async fn test_poll_empty() {
        let (_source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();
        assert!(sub.poll().is_none());
        // An empty but still-open channel must not be reported as closed.
        assert!(!sub.is_disconnected());
    }

    #[tokio::test]
    async fn test_single_subscriber_async() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
        let batch = sub.recv_async().await.unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    #[tokio::test]
    async fn test_multiple_subscribers_all_receive() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub1 = sink.subscribe();
        let mut sub2 = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        let b1 = sub1.recv_async().await.unwrap();
        let b2 = sub2.recv_async().await.unwrap();
        assert_eq!(b1.num_rows(), 1);
        assert_eq!(b2.num_rows(), 1);
    }

    #[tokio::test]
    async fn test_disconnected_after_source_and_sink_drop() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();

        drop(source);
        drop(sink);
        // Drain task exits on source disconnect; once Sink is dropped too,
        // the broadcast closes and recv_async returns Disconnected. Await
        // that directly (bounded) instead of sleeping for a fixed interval.
        let result = tokio::time::timeout(CLOSE_TIMEOUT, sub.recv_async())
            .await
            .expect("broadcast channel did not close after source and sink were dropped");

        assert!(matches!(result, Err(RecvError::Disconnected)));
        assert!(sub.is_disconnected());
    }

    #[tokio::test]
    async fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();
        assert_eq!(sub.schema().fields().len(), 2);
    }
}