laminar_core/streaming/
channel.rs1use std::ops::Deref;
4
5use crossfire::mpsc;
6
7use super::config::ChannelConfig;
8use super::error::TryPushError;
9
10type Flavor<T> = mpsc::Array<T>;
11
12pub struct Producer<T: Send + 'static> {
14 tx: crossfire::MTx<Flavor<T>>,
15}
16
17pub struct AsyncConsumer<T: Send + 'static> {
19 rx: crossfire::AsyncRx<Flavor<T>>,
20}
21
22#[must_use]
24pub fn channel<T: Send + 'static>(buffer_size: usize) -> (Producer<T>, AsyncConsumer<T>) {
25 let cap = buffer_size.max(2);
26 let (tx, rx) = mpsc::bounded_blocking_async::<T>(cap);
27 (Producer { tx }, AsyncConsumer { rx })
28}
29
30#[must_use]
32pub(crate) fn channel_with_config<T: Send + 'static>(
33 config: &ChannelConfig,
34) -> (Producer<T>, AsyncConsumer<T>) {
35 channel(config.buffer_size)
36}
37
38impl<T: Send + 'static> Producer<T> {
41 pub fn push(&self, item: T) -> Result<(), T> {
47 self.tx
48 .try_send(item)
49 .map_err(crossfire::TrySendError::into_inner)
50 }
51
52 pub fn try_push(&self, item: T) -> Result<(), TryPushError<T>> {
58 self.tx.try_send(item).map_err(|e| match e {
59 crossfire::TrySendError::Full(v) => TryPushError::full(v),
60 crossfire::TrySendError::Disconnected(v) => TryPushError::disconnected(v),
61 })
62 }
63
64 pub fn is_closed(&self) -> bool {
66 self.tx.is_disconnected()
67 }
68
69 pub fn len(&self) -> usize {
71 self.tx.deref().len()
72 }
73
74 pub fn capacity(&self) -> usize {
76 self.tx.deref().capacity().unwrap_or(0)
77 }
78
79 pub fn is_empty(&self) -> bool {
81 self.tx.deref().is_empty()
82 }
83}
84
85impl<T: Send + 'static> Clone for Producer<T> {
86 fn clone(&self) -> Self {
87 Self {
88 tx: self.tx.clone(),
89 }
90 }
91}
92
93impl<T: Send + 'static> std::fmt::Debug for Producer<T> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("Producer")
96 .field("len", &self.len())
97 .field("capacity", &self.capacity())
98 .finish()
99 }
100}
101
102impl<T: Send + 'static> AsyncConsumer<T> {
105 pub async fn recv(&mut self) -> Result<T, crossfire::RecvError> {
111 self.rx.recv().await
112 }
113
114 pub fn try_recv(&self) -> Result<T, crossfire::TryRecvError> {
123 self.rx.try_recv()
124 }
125
126 #[must_use]
128 pub fn is_disconnected(&self) -> bool {
129 self.rx.is_disconnected()
130 }
131}
132
133impl<T: Send + 'static> std::fmt::Debug for AsyncConsumer<T> {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct("AsyncConsumer")
136 .field("disconnected", &self.is_disconnected())
137 .finish()
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_send_recv() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (tx, mut rx) = channel::<i32>(16);
        tx.push(42).unwrap();
        let val = rt.block_on(rx.recv()).unwrap();
        assert_eq!(val, 42);
    }

    #[test]
    fn test_try_push_full() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (tx, mut rx) = channel::<i32>(2);
        assert!(tx.try_push(1).is_ok());
        assert!(tx.try_push(2).is_ok());
        // A full channel must reject the item and hand it back unchanged.
        let err = tx.try_push(3).unwrap_err();
        assert_eq!(err.into_inner(), 3);
        assert_eq!(rt.block_on(rx.recv()).unwrap(), 1);
        assert!(tx.try_push(3).is_ok());
    }

    #[test]
    fn test_push_returns_item_when_full() {
        // `_rx` (not `_`) keeps the consumer alive so the failure is "full".
        let (tx, _rx) = channel::<i32>(2);
        tx.push(1).unwrap();
        tx.push(2).unwrap();
        assert_eq!(tx.push(3), Err(3));
    }

    #[test]
    fn test_push_after_consumer_dropped() {
        let (tx, rx) = channel::<i32>(16);
        drop(rx);
        assert_eq!(tx.push(7), Err(7));
        assert_eq!(tx.try_push(8).unwrap_err().into_inner(), 8);
    }

    #[test]
    fn test_min_capacity_clamp() {
        // Requests below two slots are rounded up to two.
        let (tx, _rx) = channel::<i32>(0);
        assert!(tx.try_push(1).is_ok());
        assert!(tx.try_push(2).is_ok());
        assert!(tx.try_push(3).is_err());
    }

    #[test]
    fn test_try_recv_empty_then_item() {
        let (tx, rx) = channel::<i32>(4);
        assert!(rx.try_recv().is_err());
        tx.push(5).unwrap();
        assert_eq!(rx.try_recv().ok(), Some(5));
    }

    #[tokio::test]
    async fn test_disconnected_on_drop() {
        let (tx, rx) = channel::<i32>(16);
        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn test_closed_on_drop() {
        let (tx, rx) = channel::<i32>(16);
        assert!(!tx.is_closed());
        drop(rx);
        assert!(tx.is_closed());
    }

    #[test]
    fn test_clone_multi_producer() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (tx, mut rx) = channel::<i32>(16);
        let tx2 = tx.clone();
        tx.push(1).unwrap();
        tx2.push(2).unwrap();
        let a = rt.block_on(rx.recv()).unwrap();
        let b = rt.block_on(rx.recv()).unwrap();
        let mut items = vec![a, b];
        items.sort_unstable();
        assert_eq!(items, vec![1, 2]);
    }

    #[tokio::test]
    async fn test_async_recv() {
        let (tx, mut rx) = channel::<i32>(64);
        let producer = tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
            tx.push(42).unwrap();
        });
        // Join the producer as well. A panic inside the spawned task then fails the
        // test through its JoinError, instead of surfacing only as a recv error.
        let (joined, received) = tokio::join!(producer, rx.recv());
        joined.unwrap();
        assert_eq!(received.unwrap(), 42);
    }
}