laminar_core/streaming/
sink.rs1use std::sync::Arc;
4
5use arrow::datatypes::SchemaRef;
6use tokio::sync::broadcast;
7
8use super::channel::AsyncConsumer;
9use super::source::{Record, SourceMessage};
10use super::subscription::Subscription;

/// Capacity of the broadcast channel created by [`Sink::new`]: how many
/// messages the channel retains before a subscriber that has fallen that far
/// behind starts lagging (see tokio's `broadcast` channel semantics).
const DEFAULT_BROADCAST_CAPACITY: usize = 2_048;

14pub struct Sink<T: Record> {
23 broadcast_tx: broadcast::Sender<SourceMessage<T>>,
24 schema: SchemaRef,
25}
26
27impl<T: Record> Sink<T> {
28 pub(crate) fn new(consumer: AsyncConsumer<SourceMessage<T>>, schema: SchemaRef) -> Self {
29 let (broadcast_tx, _) = broadcast::channel(DEFAULT_BROADCAST_CAPACITY);
30 let tx = broadcast_tx.clone();
31
32 tokio::spawn(async move {
36 drain_loop(consumer, tx).await;
37 });
38
39 Self {
40 broadcast_tx,
41 schema,
42 }
43 }
44
45 #[must_use]
47 pub fn subscribe(&self) -> Subscription<T> {
48 Subscription::new(self.broadcast_tx.subscribe(), Arc::clone(&self.schema))
49 }
50
51 #[must_use]
53 pub fn schema(&self) -> SchemaRef {
54 Arc::clone(&self.schema)
55 }
56
57 #[must_use]
59 pub fn subscriber_count(&self) -> usize {
60 self.broadcast_tx.receiver_count()
61 }
62}
63
64impl<T: Record + std::fmt::Debug> std::fmt::Debug for Sink<T> {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("Sink")
67 .field("subscribers", &self.subscriber_count())
68 .finish()
69 }
70}
71
72async fn drain_loop<T: Record>(
73 mut consumer: AsyncConsumer<SourceMessage<T>>,
74 tx: broadcast::Sender<SourceMessage<T>>,
75) {
76 while let Ok(msg) = consumer.recv().await {
77 let _ = tx.send(msg);
78 }
79}
80
#[cfg(test)]
mod tests {
    use crate::streaming::source::create;
    use crate::streaming::source::Record;
    use arrow::array::{Float64Array, Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use std::sync::Arc;

    /// Minimal two-column record shared by every test in this module.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            let ids = Int64Array::from(vec![self.id]);
            let values = Float64Array::from(vec![self.value]);
            RecordBatch::try_new(Self::schema(), vec![Arc::new(ids), Arc::new(values)]).unwrap()
        }
    }

    #[tokio::test]
    async fn test_single_subscriber() {
        let (source, sink) = create::<TestEvent>(16);
        let mut subscription = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        let received = subscription.recv_async().await.unwrap();
        assert_eq!(received.num_rows(), 1);
    }

    #[tokio::test]
    async fn test_multiple_subscribers_all_receive() {
        let (source, sink) = create::<TestEvent>(16);
        let mut first = sink.subscribe();
        let mut second = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        // Receive on both subscribers first, then check each batch.
        let received = [
            first.recv_async().await.unwrap(),
            second.recv_async().await.unwrap(),
        ];
        for batch in received {
            assert_eq!(batch.num_rows(), 1);
        }
    }

    #[tokio::test]
    async fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let schema = sink.schema();
        assert_eq!(schema.fields().len(), 2);
    }

    #[tokio::test]
    async fn test_subscriber_count() {
        let (_source, sink) = create::<TestEvent>(16);
        assert_eq!(sink.subscriber_count(), 0);

        let first = sink.subscribe();
        assert_eq!(sink.subscriber_count(), 1);

        let _second = sink.subscribe();
        assert_eq!(sink.subscriber_count(), 2);

        drop(first);
        assert_eq!(sink.subscriber_count(), 1);
    }
}