laminar_core/streaming/
subscription.rs1use std::sync::Arc;
4use std::time::Duration;
5
6use arrow::array::RecordBatch;
7use arrow::datatypes::SchemaRef;
8use tokio::sync::broadcast;
9
10use super::error::RecvError;
11use super::source::{Record, SourceMessage};
12
13pub struct Subscription<T: Record> {
16 rx: broadcast::Receiver<SourceMessage<T>>,
17 schema: SchemaRef,
18 closed: bool,
19}
20
21impl<T: Record> Subscription<T> {
22 pub(crate) fn new(rx: broadcast::Receiver<SourceMessage<T>>, schema: SchemaRef) -> Self {
23 Self {
24 rx,
25 schema,
26 closed: false,
27 }
28 }
29
30 pub fn poll(&mut self) -> Option<RecordBatch> {
34 loop {
35 match self.rx.try_recv() {
36 Ok(msg) => return Some(to_batch(msg)),
37 Err(broadcast::error::TryRecvError::Empty) => return None,
38 Err(broadcast::error::TryRecvError::Closed) => {
39 self.closed = true;
40 return None;
41 }
42 Err(broadcast::error::TryRecvError::Lagged(_)) => {}
43 }
44 }
45 }
46
47 pub async fn recv_async(&mut self) -> Result<RecordBatch, RecvError> {
53 loop {
54 match self.rx.recv().await {
55 Ok(msg) => return Ok(to_batch(msg)),
56 Err(broadcast::error::RecvError::Closed) => {
57 self.closed = true;
58 return Err(RecvError::Disconnected);
59 }
60 Err(broadcast::error::RecvError::Lagged(_)) => {}
61 }
62 }
63 }
64
65 pub fn recv(&mut self) -> Result<RecordBatch, RecvError> {
71 loop {
72 match self.rx.blocking_recv() {
73 Ok(msg) => return Ok(to_batch(msg)),
74 Err(broadcast::error::RecvError::Closed) => {
75 self.closed = true;
76 return Err(RecvError::Disconnected);
77 }
78 Err(broadcast::error::RecvError::Lagged(_)) => {}
79 }
80 }
81 }
82
83 pub fn recv_timeout(&mut self, timeout: Duration) -> Result<RecordBatch, RecvError> {
90 let handle = tokio::runtime::Handle::current();
91 match handle.block_on(tokio::time::timeout(timeout, self.recv_async())) {
92 Ok(Ok(batch)) => Ok(batch),
93 Ok(Err(e)) => Err(e),
94 Err(_) => Err(RecvError::Timeout),
95 }
96 }
97
98 #[must_use]
100 pub fn is_disconnected(&self) -> bool {
101 self.closed
102 }
103
104 #[must_use]
106 pub fn schema(&self) -> SchemaRef {
107 Arc::clone(&self.schema)
108 }
109}
110
111fn to_batch<T: Record>(msg: SourceMessage<T>) -> RecordBatch {
112 match msg {
113 SourceMessage::Record(r) => r.to_record_batch(),
114 SourceMessage::Batch(b) => b,
115 }
116}
117
118impl<T: Record + std::fmt::Debug> std::fmt::Debug for Subscription<T> {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 f.debug_struct("Subscription")
121 .field("closed", &self.closed)
122 .field("schema", &self.schema)
123 .finish_non_exhaustive()
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::source::create;
    use arrow::array::{Float64Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Minimal two-column record used to drive the subscription tests.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        /// Schema with a non-nullable `id: Int64` and `value: Float64`.
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        /// Builds a single-row batch holding this event.
        fn to_record_batch(&self) -> RecordBatch {
            let id_column = Int64Array::from(vec![self.id]);
            let value_column = Float64Array::from(vec![self.value]);
            RecordBatch::try_new(
                Self::schema(),
                vec![Arc::new(id_column), Arc::new(value_column)],
            )
            .unwrap()
        }
    }

    /// Polling a subscription with nothing pushed yields no batch.
    #[tokio::test]
    async fn test_poll_empty() {
        let (_source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        let polled = subscription.poll();
        assert!(polled.is_none());
    }

    /// A pushed event reaches a single subscriber as a one-row batch.
    #[tokio::test]
    async fn test_single_subscriber_async() {
        let (source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        let received = subscription.recv_async().await.unwrap();
        assert_eq!(received.num_rows(), 1);
    }

    /// Every subscriber gets its own copy of a pushed event.
    #[tokio::test]
    async fn test_multiple_subscribers_all_receive() {
        let (source, sink) = create::<TestEvent>(16);
        let mut first = sink.subscribe();
        let mut second = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        let first_batch = first.recv_async().await.unwrap();
        let second_batch = second.recv_async().await.unwrap();
        assert_eq!(first_batch.num_rows(), 1);
        assert_eq!(second_batch.num_rows(), 1);
    }

    /// Dropping both the source and the sink disconnects the subscriber.
    #[tokio::test]
    async fn test_disconnected_after_source_and_sink_drop() {
        let (source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        drop(source);
        drop(sink);
        // NOTE(review): presumably gives the channel time to observe both
        // drops before receiving; confirm whether the pause is still needed.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let outcome = subscription.recv_async().await;
        assert!(outcome.is_err());
        assert!(subscription.is_disconnected());
    }

    /// The subscription exposes the two-field schema of `TestEvent`.
    #[tokio::test]
    async fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let subscription = sink.subscribe();

        let field_count = subscription.schema().fields().len();
        assert_eq!(field_count, 2);
    }
}