Skip to main content

laminar_core/streaming/
sink.rs

1//! Sink — consumption endpoint with broadcast fan-out to multiple subscribers.
2
3use std::sync::Arc;
4
5use arrow::datatypes::SchemaRef;
6use tokio::sync::broadcast;
7
8use super::channel::AsyncConsumer;
9use super::source::{Record, SourceMessage};
10use super::subscription::Subscription;
11
/// Capacity (in messages) of the broadcast channel created by `Sink::new`.
///
/// With tokio's broadcast channel, a receiver that falls more than this many
/// messages behind loses the oldest ones.
const DEFAULT_BROADCAST_CAPACITY: usize = 1 << 11;
13
14/// A streaming data sink. Each `subscribe()` call returns an independent
15/// receiver that gets a copy of every message via broadcast.
16pub struct Sink<T: Record> {
17    broadcast_tx: broadcast::Sender<SourceMessage<T>>,
18    schema: SchemaRef,
19}
20
21impl<T: Record> Sink<T> {
22    pub(crate) fn new(consumer: AsyncConsumer<SourceMessage<T>>, schema: SchemaRef) -> Self {
23        let (broadcast_tx, _) = broadcast::channel(DEFAULT_BROADCAST_CAPACITY);
24        let tx = broadcast_tx.clone();
25
26        tokio::spawn(async move {
27            drain_loop(consumer, tx).await;
28        });
29
30        Self {
31            broadcast_tx,
32            schema,
33        }
34    }
35
36    /// Subscribe to this sink. Returns an independent receiver.
37    #[must_use]
38    pub fn subscribe(&self) -> Subscription<T> {
39        Subscription::new(self.broadcast_tx.subscribe(), Arc::clone(&self.schema))
40    }
41
42    /// Returns the schema for this sink.
43    #[must_use]
44    pub fn schema(&self) -> SchemaRef {
45        Arc::clone(&self.schema)
46    }
47
48    /// Number of active broadcast subscribers.
49    #[must_use]
50    pub fn subscriber_count(&self) -> usize {
51        self.broadcast_tx.receiver_count()
52    }
53}
54
55impl<T: Record + std::fmt::Debug> std::fmt::Debug for Sink<T> {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("Sink")
58            .field("subscribers", &self.subscriber_count())
59            .finish()
60    }
61}
62
63async fn drain_loop<T: Record>(
64    mut consumer: AsyncConsumer<SourceMessage<T>>,
65    tx: broadcast::Sender<SourceMessage<T>>,
66) {
67    while let Ok(msg) = consumer.recv().await {
68        let _ = tx.send(msg);
69    }
70}
71
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{Float64Array, Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

    use crate::streaming::source::{create, Record};

    /// Minimal two-column record used to drive the sink in these tests.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl TestEvent {
        /// Shorthand constructor to keep the test bodies compact.
        fn new(id: i64, value: f64) -> Self {
            Self { id, value }
        }
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            let ids = Int64Array::from(vec![self.id]);
            let values = Float64Array::from(vec![self.value]);
            RecordBatch::try_new(Self::schema(), vec![Arc::new(ids), Arc::new(values)])
                .unwrap()
        }
    }

    #[tokio::test]
    async fn test_single_subscriber() {
        let (source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        source.push(TestEvent::new(1, 1.0)).unwrap();

        let received = subscription.recv_async().await.unwrap();
        assert_eq!(received.num_rows(), 1);
    }

    #[tokio::test]
    async fn test_multiple_subscribers_all_receive() {
        let (source, sink) = create::<TestEvent>(16);
        let mut subscribers = [sink.subscribe(), sink.subscribe()];

        source.push(TestEvent::new(1, 1.0)).unwrap();

        // Every subscriber must observe its own copy of the single pushed event.
        for subscriber in &mut subscribers {
            let batch = subscriber.recv_async().await.unwrap();
            assert_eq!(batch.num_rows(), 1);
        }
    }

    #[tokio::test]
    async fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let schema = sink.schema();
        assert_eq!(schema.fields().len(), 2);
    }

    #[tokio::test]
    async fn test_subscriber_count() {
        let (_source, sink) = create::<TestEvent>(16);
        assert_eq!(sink.subscriber_count(), 0);

        let first = sink.subscribe();
        assert_eq!(sink.subscriber_count(), 1);

        let _second = sink.subscribe();
        assert_eq!(sink.subscriber_count(), 2);

        // Dropping a subscription releases its receiver.
        drop(first);
        assert_eq!(sink.subscriber_count(), 1);
    }
}