laminar_core/streaming/
sink.rs1use std::sync::Arc;
4
5use arrow::datatypes::SchemaRef;
6use tokio::sync::broadcast;
7
8use super::channel::AsyncConsumer;
9use super::source::{Record, SourceMessage};
10use super::subscription::Subscription;
11
/// Number of messages each `Sink`'s broadcast channel retains. Once the buffer is
/// full, the oldest message is overwritten and subscribers that have not yet seen
/// it observe a lag.
const DEFAULT_BROADCAST_CAPACITY: usize = 2 * 1024;
13
14pub struct Sink<T: Record> {
17 broadcast_tx: broadcast::Sender<SourceMessage<T>>,
18 schema: SchemaRef,
19}
20
21impl<T: Record> Sink<T> {
22 pub(crate) fn new(consumer: AsyncConsumer<SourceMessage<T>>, schema: SchemaRef) -> Self {
23 let (broadcast_tx, _) = broadcast::channel(DEFAULT_BROADCAST_CAPACITY);
24 let tx = broadcast_tx.clone();
25
26 tokio::spawn(async move {
27 drain_loop(consumer, tx).await;
28 });
29
30 Self {
31 broadcast_tx,
32 schema,
33 }
34 }
35
36 #[must_use]
38 pub fn subscribe(&self) -> Subscription<T> {
39 Subscription::new(self.broadcast_tx.subscribe(), Arc::clone(&self.schema))
40 }
41
42 #[must_use]
44 pub fn schema(&self) -> SchemaRef {
45 Arc::clone(&self.schema)
46 }
47
48 #[must_use]
50 pub fn subscriber_count(&self) -> usize {
51 self.broadcast_tx.receiver_count()
52 }
53}
54
55impl<T: Record + std::fmt::Debug> std::fmt::Debug for Sink<T> {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("Sink")
58 .field("subscribers", &self.subscriber_count())
59 .finish()
60 }
61}
62
63async fn drain_loop<T: Record>(
64 mut consumer: AsyncConsumer<SourceMessage<T>>,
65 tx: broadcast::Sender<SourceMessage<T>>,
66) {
67 while let Ok(msg) = consumer.recv().await {
68 let _ = tx.send(msg);
69 }
70}
71
#[cfg(test)]
mod tests {
    use crate::streaming::source::create;
    use crate::streaming::source::Record;
    use arrow::array::{Float64Array, Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use std::sync::Arc;

    /// Minimal two-column record (`id: Int64`, `value: Float64`) used to drive the sink.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                ],
            )
            .unwrap()
        }
    }

    /// The single event pushed by the delivery tests.
    fn sample_event() -> TestEvent {
        TestEvent { id: 1, value: 1.0 }
    }

    #[tokio::test]
    async fn test_single_subscriber() {
        let (source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        source.push(sample_event()).unwrap();

        let received = subscription.recv_async().await.unwrap();
        assert_eq!(received.num_rows(), 1);
    }

    #[tokio::test]
    async fn test_multiple_subscribers_all_receive() {
        let (source, sink) = create::<TestEvent>(16);
        let mut first = sink.subscribe();
        let mut second = sink.subscribe();

        source.push(sample_event()).unwrap();

        // Every subscription must see its own copy of the single pushed event.
        for subscription in [&mut first, &mut second] {
            let received = subscription.recv_async().await.unwrap();
            assert_eq!(received.num_rows(), 1);
        }
    }

    #[tokio::test]
    async fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let schema = sink.schema();
        assert_eq!(schema.fields().len(), 2);
    }

    #[tokio::test]
    async fn test_subscriber_count() {
        let (_source, sink) = create::<TestEvent>(16);
        assert_eq!(sink.subscriber_count(), 0);

        let first = sink.subscribe();
        assert_eq!(sink.subscriber_count(), 1);

        let _second = sink.subscribe();
        assert_eq!(sink.subscriber_count(), 2);

        drop(first);
        assert_eq!(sink.subscriber_count(), 1);
    }
}