Skip to main content

laminar_core/streaming/
channel.rs

//! Bounded MPSC channel backed by crossfire.
//!
//! Producers push without blocking; the single consumer receives asynchronously.
2
3use std::ops::Deref;
4
5use crossfire::mpsc;
6
7use super::config::ChannelConfig;
8use super::error::TryPushError;
9
10type Flavor<T> = mpsc::Array<T>;
11
12/// Sender half. Cloneable for multi-producer use.
13pub struct Producer<T: Send + 'static> {
14    tx: crossfire::MTx<Flavor<T>>,
15}
16
17/// Async receiver half.
18pub struct AsyncConsumer<T: Send + 'static> {
19    rx: crossfire::AsyncRx<Flavor<T>>,
20}
21
22/// Creates a bounded channel with blocking sender and async receiver.
23#[must_use]
24pub fn channel<T: Send + 'static>(buffer_size: usize) -> (Producer<T>, AsyncConsumer<T>) {
25    let cap = buffer_size.max(2);
26    let (tx, rx) = mpsc::bounded_blocking_async::<T>(cap);
27    (Producer { tx }, AsyncConsumer { rx })
28}
29
30/// Creates a bounded channel from a [`ChannelConfig`].
31#[must_use]
32pub(crate) fn channel_with_config<T: Send + 'static>(
33    config: &ChannelConfig,
34) -> (Producer<T>, AsyncConsumer<T>) {
35    channel(config.buffer_size)
36}
37
// -- Producer -----------------------------------------------------------------
39
40impl<T: Send + 'static> Producer<T> {
41    /// Non-blocking send.
42    ///
43    /// # Errors
44    ///
45    /// Returns the item if the channel is full or the receiver was dropped.
46    pub fn push(&self, item: T) -> Result<(), T> {
47        self.tx
48            .try_send(item)
49            .map_err(crossfire::TrySendError::into_inner)
50    }
51
52    /// Non-blocking send with typed error.
53    ///
54    /// # Errors
55    ///
56    /// Returns `TryPushError` containing the item if the channel is full.
57    pub fn try_push(&self, item: T) -> Result<(), TryPushError<T>> {
58        self.tx.try_send(item).map_err(|e| match e {
59            crossfire::TrySendError::Full(v) => TryPushError::full(v),
60            crossfire::TrySendError::Disconnected(v) => TryPushError::disconnected(v),
61        })
62    }
63
64    /// Returns `true` if the receiver has been dropped.
65    pub fn is_closed(&self) -> bool {
66        self.tx.is_disconnected()
67    }
68
69    /// Number of items currently buffered.
70    pub fn len(&self) -> usize {
71        self.tx.deref().len()
72    }
73
74    /// Buffer capacity.
75    pub fn capacity(&self) -> usize {
76        self.tx.deref().capacity().unwrap_or(0)
77    }
78
79    /// Whether the buffer is empty.
80    pub fn is_empty(&self) -> bool {
81        self.tx.deref().is_empty()
82    }
83}
84
85impl<T: Send + 'static> Clone for Producer<T> {
86    fn clone(&self) -> Self {
87        Self {
88            tx: self.tx.clone(),
89        }
90    }
91}
92
93impl<T: Send + 'static> std::fmt::Debug for Producer<T> {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        f.debug_struct("Producer")
96            .field("len", &self.len())
97            .field("capacity", &self.capacity())
98            .finish()
99    }
100}
101
// -- AsyncConsumer ------------------------------------------------------------
103
104impl<T: Send + 'static> AsyncConsumer<T> {
105    /// Async receive. Suspends until a message arrives or the sender disconnects.
106    ///
107    /// # Errors
108    ///
109    /// Returns `crossfire::RecvError` if the sender was dropped.
110    pub async fn recv(&mut self) -> Result<T, crossfire::RecvError> {
111        self.rx.recv().await
112    }
113
114    /// Non-blocking receive. Returns `Err` immediately if the channel is
115    /// empty or the senders have all disconnected.
116    ///
117    /// # Errors
118    ///
119    /// Returns `crossfire::TryRecvError::Empty` when no items are buffered,
120    /// `crossfire::TryRecvError::Disconnected` after all senders are dropped
121    /// and the buffer is drained.
122    pub fn try_recv(&self) -> Result<T, crossfire::TryRecvError> {
123        self.rx.try_recv()
124    }
125
126    /// Returns `true` if the sender has been dropped.
127    #[must_use]
128    pub fn is_disconnected(&self) -> bool {
129        self.rx.is_disconnected()
130    }
131}
132
133impl<T: Send + 'static> std::fmt::Debug for AsyncConsumer<T> {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        f.debug_struct("AsyncConsumer")
136            .field("disconnected", &self.is_disconnected())
137            .finish()
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_send_recv() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (tx, mut rx) = channel::<i32>(16);
        tx.push(42).unwrap();
        let val = rt.block_on(rx.recv()).unwrap();
        assert_eq!(val, 42);
    }

    #[test]
    fn test_try_push_full() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (tx, mut rx) = channel::<i32>(2);
        assert!(tx.try_push(1).is_ok());
        assert!(tx.try_push(2).is_ok());
        let err = tx.try_push(3);
        assert!(err.is_err());
        assert_eq!(err.unwrap_err().into_inner(), 3);
        assert_eq!(rt.block_on(rx.recv()).unwrap(), 1);
        assert!(tx.try_push(3).is_ok());
    }

    /// A requested size below the minimum behaves like a capacity of 2.
    #[test]
    fn test_small_buffer_clamped_to_two() {
        // Bind the consumer (not `_`) so the channel stays connected.
        let (tx, _rx) = channel::<i32>(0);
        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_ok());
        assert_eq!(tx.push(3), Err(3));
    }

    #[tokio::test]
    async fn test_disconnected_on_drop() {
        let (tx, rx) = channel::<i32>(16);
        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn test_disconnected_only_after_all_producers_dropped() {
        let (tx, rx) = channel::<i32>(16);
        let tx2 = tx.clone();
        drop(tx);
        assert!(!rx.is_disconnected());
        drop(tx2);
        assert!(rx.is_disconnected());
    }

    #[test]
    fn test_closed_on_drop() {
        let (tx, rx) = channel::<i32>(16);
        assert!(!tx.is_closed());
        drop(rx);
        assert!(tx.is_closed());
    }

    #[test]
    fn test_push_after_consumer_dropped_returns_item() {
        let (tx, rx) = channel::<i32>(16);
        drop(rx);
        assert_eq!(tx.push(7), Err(7));
        assert_eq!(tx.try_push(8).unwrap_err().into_inner(), 8);
    }

    #[test]
    fn test_try_recv_empty_then_drains_after_disconnect() {
        let (tx, rx) = channel::<i32>(16);
        assert!(matches!(rx.try_recv(), Err(crossfire::TryRecvError::Empty)));
        tx.push(5).unwrap();
        drop(tx);
        assert_eq!(rx.try_recv().ok(), Some(5));
        assert!(matches!(
            rx.try_recv(),
            Err(crossfire::TryRecvError::Disconnected)
        ));
    }

    #[test]
    fn test_recv_drains_buffer_before_error() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (tx, mut rx) = channel::<i32>(16);
        tx.push(9).unwrap();
        drop(tx);
        assert_eq!(rt.block_on(rx.recv()).unwrap(), 9);
        assert!(rt.block_on(rx.recv()).is_err());
    }

    #[test]
    fn test_clone_multi_producer() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (tx, mut rx) = channel::<i32>(16);
        let tx2 = tx.clone();
        tx.push(1).unwrap();
        tx2.push(2).unwrap();
        let a = rt.block_on(rx.recv()).unwrap();
        let b = rt.block_on(rx.recv()).unwrap();
        let mut items = vec![a, b];
        items.sort_unstable();
        assert_eq!(items, vec![1, 2]);
    }

    #[tokio::test]
    async fn test_async_recv() {
        let (tx, mut rx) = channel::<i32>(64);
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
            tx.push(42).unwrap();
        });
        let val = rx.recv().await.unwrap();
        assert_eq!(val, 42);
    }
}