laminar_core/streaming/
subscription.rs1use std::sync::Arc;
4use std::time::Duration;
5
6use arrow::array::RecordBatch;
7use arrow::datatypes::SchemaRef;
8use tokio::sync::broadcast;
9
10use super::error::RecvError;
11use super::source::{Record, SourceMessage};
12
13pub struct Subscription<T: Record> {
16 rx: broadcast::Receiver<SourceMessage<T>>,
17 schema: SchemaRef,
18 closed: bool,
19}
20
21impl<T: Record> Subscription<T> {
22 pub(crate) fn new(rx: broadcast::Receiver<SourceMessage<T>>, schema: SchemaRef) -> Self {
23 Self {
24 rx,
25 schema,
26 closed: false,
27 }
28 }
29
30 pub fn poll(&mut self) -> Option<RecordBatch> {
34 loop {
35 match self.rx.try_recv() {
36 Ok(msg) => {
37 if let Some(batch) = message_to_batch(msg) {
38 return Some(batch);
39 }
40 }
41 Err(broadcast::error::TryRecvError::Empty) => return None,
42 Err(broadcast::error::TryRecvError::Closed) => {
43 self.closed = true;
44 return None;
45 }
46 Err(broadcast::error::TryRecvError::Lagged(_)) => {}
47 }
48 }
49 }
50
51 pub async fn recv_async(&mut self) -> Result<RecordBatch, RecvError> {
57 loop {
58 match self.rx.recv().await {
59 Ok(msg) => {
60 if let Some(batch) = message_to_batch(msg) {
61 return Ok(batch);
62 }
63 }
64 Err(broadcast::error::RecvError::Closed) => {
65 self.closed = true;
66 return Err(RecvError::Disconnected);
67 }
68 Err(broadcast::error::RecvError::Lagged(_)) => {}
69 }
70 }
71 }
72
73 pub fn recv(&mut self) -> Result<RecordBatch, RecvError> {
79 loop {
80 match self.rx.blocking_recv() {
81 Ok(msg) => {
82 if let Some(batch) = message_to_batch(msg) {
83 return Ok(batch);
84 }
85 }
86 Err(broadcast::error::RecvError::Closed) => {
87 self.closed = true;
88 return Err(RecvError::Disconnected);
89 }
90 Err(broadcast::error::RecvError::Lagged(_)) => {}
91 }
92 }
93 }
94
95 pub fn recv_timeout(&mut self, timeout: Duration) -> Result<RecordBatch, RecvError> {
102 let handle = tokio::runtime::Handle::current();
103 match handle.block_on(tokio::time::timeout(timeout, self.recv_async())) {
104 Ok(Ok(batch)) => Ok(batch),
105 Ok(Err(e)) => Err(e),
106 Err(_) => Err(RecvError::Timeout),
107 }
108 }
109
110 #[must_use]
112 pub fn is_disconnected(&self) -> bool {
113 self.closed
114 }
115
116 #[must_use]
118 pub fn schema(&self) -> SchemaRef {
119 Arc::clone(&self.schema)
120 }
121}
122
123fn message_to_batch<T: Record>(msg: SourceMessage<T>) -> Option<RecordBatch> {
124 match msg {
125 SourceMessage::Record(r) => Some(r.to_record_batch()),
126 SourceMessage::Batch(b) => Some(b),
127 SourceMessage::Watermark(_) => None,
128 }
129}
130
131impl<T: Record + std::fmt::Debug> std::fmt::Debug for Subscription<T> {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 f.debug_struct("Subscription")
134 .field("closed", &self.closed)
135 .field("schema", &self.schema)
136 .finish_non_exhaustive()
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::source::create;
    use arrow::array::{Array, Float64Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};

    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            Arc::new(Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ]))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                ],
            )
            .unwrap()
        }
    }

    #[tokio::test]
    async fn test_poll_empty() {
        let (_source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();
        assert!(sub.poll().is_none());
        // An empty but still-open channel must not be reported as disconnected.
        assert!(!sub.is_disconnected());
    }

    #[tokio::test]
    async fn test_poll_returns_pushed_record() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        // Yield between polls so the test does not depend on whether `push`
        // publishes synchronously or through a background task.
        let batch = tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                if let Some(batch) = sub.poll() {
                    break batch;
                }
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("pushed record was never observed by poll");
        assert_eq!(batch.num_rows(), 1);
    }

    #[tokio::test]
    async fn test_single_subscriber_async() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
        let batch = sub.recv_async().await.unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    #[tokio::test]
    async fn test_record_contents_preserved() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();

        source.push(TestEvent { id: 42, value: 2.5 }).unwrap();
        let batch = sub.recv_async().await.unwrap();

        let ids = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("id column should be Int64");
        let values = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("value column should be Float64");
        assert_eq!(ids.value(0), 42);
        assert_eq!(values.value(0), 2.5);
    }

    #[tokio::test]
    async fn test_multiple_subscribers_all_receive() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub1 = sink.subscribe();
        let mut sub2 = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        let b1 = sub1.recv_async().await.unwrap();
        let b2 = sub2.recv_async().await.unwrap();
        assert_eq!(b1.num_rows(), 1);
        assert_eq!(b2.num_rows(), 1);
    }

    #[tokio::test]
    async fn test_disconnected_after_source_and_sink_drop() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();

        drop(source);
        drop(sink);
        tokio::time::sleep(Duration::from_millis(50)).await;

        assert!(sub.recv_async().await.is_err());
        assert!(sub.is_disconnected());
    }

    #[tokio::test]
    async fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();
        assert_eq!(sub.schema().fields().len(), 2);
    }
}