laminar_core/streaming/
sink.rs1use std::sync::atomic::{AtomicUsize, Ordering};
27use std::sync::Arc;
28
29use arrow::datatypes::SchemaRef;
30
31use super::channel::{channel_with_config, Consumer, Producer};
32use super::config::ChannelConfig;
33use super::source::{Record, SourceMessage};
34use super::subscription::Subscription;
35
/// Delivery strategy a [`Sink`] is currently operating in.
///
/// The mode is derived from how many subscriptions have been handed out:
/// at most one means [`SinkMode::Single`], two or more means
/// [`SinkMode::Broadcast`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SinkMode {
    /// Zero or one subscriber; no fan-out takes place.
    Single,
    /// Two or more subscribers; queued messages are fanned out to each
    /// broadcast subscriber's private channel.
    Broadcast,
}
44
45struct SubscriberInner<T: Record> {
47 producer: Producer<SourceMessage<T>>,
49}
50
51pub(crate) struct SinkInner<T: Record> {
53 consumer: Consumer<SourceMessage<T>>,
55
56 schema: SchemaRef,
58
59 channel_config: ChannelConfig,
61
62 subscriber_count: AtomicUsize,
64}
65
66pub struct Sink<T: Record> {
99 inner: Arc<SinkInner<T>>,
100 subscribers: Arc<parking_lot::RwLock<Vec<SubscriberInner<T>>>>,
103}
104
105impl<T: Record> Sink<T> {
106 pub(crate) fn new(
108 consumer: Consumer<SourceMessage<T>>,
109 schema: SchemaRef,
110 channel_config: ChannelConfig,
111 ) -> Self {
112 Self {
113 inner: Arc::new(SinkInner {
114 consumer,
115 schema,
116 channel_config,
117 subscriber_count: AtomicUsize::new(0),
118 }),
119 subscribers: Arc::new(parking_lot::RwLock::new(Vec::new())),
120 }
121 }
122
123 #[must_use]
141 pub fn subscribe(&self) -> Subscription<T> {
142 let count = self.inner.subscriber_count.fetch_add(1, Ordering::AcqRel);
143
144 if count == 0 {
145 Subscription::new_direct(Arc::clone(&self.inner))
147 } else {
148 let (producer, consumer) =
150 channel_with_config::<SourceMessage<T>>(self.inner.channel_config.clone());
151
152 {
154 let mut subs = self.subscribers.write();
155 subs.push(SubscriberInner { producer });
156 }
157
158 Subscription::new_broadcast(consumer, Arc::clone(&self.inner.schema))
159 }
160 }
161
162 #[must_use]
164 pub fn subscriber_count(&self) -> usize {
165 self.inner.subscriber_count.load(Ordering::Acquire)
166 }
167
168 #[must_use]
170 pub fn mode(&self) -> SinkMode {
171 if self.subscriber_count() <= 1 {
172 SinkMode::Single
173 } else {
174 SinkMode::Broadcast
175 }
176 }
177
178 #[must_use]
180 pub fn schema(&self) -> SchemaRef {
181 Arc::clone(&self.inner.schema)
182 }
183
184 #[must_use]
186 pub fn is_disconnected(&self) -> bool {
187 self.inner.consumer.is_disconnected()
188 }
189
190 #[must_use]
192 pub fn pending(&self) -> usize {
193 self.inner.consumer.len()
194 }
195
196 #[must_use]
208 pub fn poll_and_distribute(&self) -> usize
209 where
210 T: Clone,
211 {
212 if self.mode() != SinkMode::Broadcast {
214 return 0;
215 }
216
217 let producers: smallvec::SmallVec<[Producer<SourceMessage<T>>; 4]> = {
219 let subscribers = self.subscribers.read();
220 if subscribers.is_empty() {
221 return 0;
222 }
223 subscribers.iter().map(|s| s.producer.clone()).collect()
224 };
225 let mut count = 0;
228
229 while let Some(msg) = self.inner.consumer.poll() {
231 for producer in &producers {
232 let msg_clone = match &msg {
234 SourceMessage::Record(r) => SourceMessage::Record(r.clone()),
235 SourceMessage::Batch(b) => SourceMessage::Batch(b.clone()),
236 SourceMessage::Watermark(ts) => SourceMessage::Watermark(*ts),
237 };
238 let _ = producer.try_push(msg_clone);
239 }
240 count += 1;
241 }
242
243 count
244 }
245}
246
247impl<T: Record + std::fmt::Debug> std::fmt::Debug for Sink<T> {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 f.debug_struct("Sink")
250 .field("mode", &self.mode())
251 .field("subscriber_count", &self.subscriber_count())
252 .field("pending", &self.pending())
253 .field("is_disconnected", &self.is_disconnected())
254 .finish()
255 }
256}
257
258impl<T: Record> SinkInner<T> {
260 pub(crate) fn consumer(&self) -> &Consumer<SourceMessage<T>> {
261 &self.consumer
262 }
263
264 pub(crate) fn schema(&self) -> SchemaRef {
265 Arc::clone(&self.schema)
266 }
267
268 pub(crate) fn is_disconnected(&self) -> bool {
269 self.consumer.is_disconnected()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::source::create;
    use arrow::array::{Float64Array, Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            Arc::new(Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ]))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                ],
            )
            .unwrap()
        }
    }

    #[test]
    fn test_sink_creation() {
        let (_source, sink) = create::<TestEvent>(16);

        assert_eq!(sink.subscriber_count(), 0);
        assert_eq!(sink.mode(), SinkMode::Single);
        assert!(!sink.is_disconnected());
    }

    #[test]
    fn test_single_subscriber() {
        let (_source, sink) = create::<TestEvent>(16);

        let _sub = sink.subscribe();

        assert_eq!(sink.subscriber_count(), 1);
        assert_eq!(sink.mode(), SinkMode::Single);
    }

    #[test]
    fn test_broadcast_mode() {
        let (_source, sink) = create::<TestEvent>(16);

        let _sub1 = sink.subscribe();
        let _sub2 = sink.subscribe();

        assert_eq!(sink.subscriber_count(), 2);
        assert_eq!(sink.mode(), SinkMode::Broadcast);
    }

    #[test]
    fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);

        let schema = sink.schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "value");
    }

    #[test]
    fn test_disconnected_on_source_drop() {
        let (source, sink) = create::<TestEvent>(16);

        assert!(!sink.is_disconnected());

        drop(source);

        assert!(sink.is_disconnected());
    }

    #[test]
    fn test_pending() {
        let (source, sink) = create::<TestEvent>(16);

        assert_eq!(sink.pending(), 0);

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
        source.push(TestEvent { id: 2, value: 2.0 }).unwrap();

        assert_eq!(sink.pending(), 2);
    }

    #[test]
    fn test_poll_and_distribute_noop_in_single_mode() {
        let (source, sink) = create::<TestEvent>(16);
        let _sub = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();

        // With only the direct subscriber nothing is fanned out, and the
        // message stays queued in the sink's consumer.
        assert_eq!(sink.poll_and_distribute(), 0);
        assert_eq!(sink.pending(), 1);
    }

    #[test]
    fn test_poll_and_distribute_drains_in_broadcast_mode() {
        let (source, sink) = create::<TestEvent>(16);
        let _sub1 = sink.subscribe();
        let _sub2 = sink.subscribe();

        source.push(TestEvent { id: 1, value: 1.0 }).unwrap();
        source.push(TestEvent { id: 2, value: 2.0 }).unwrap();

        // Both queued messages are drained from the sink's consumer.
        assert_eq!(sink.poll_and_distribute(), 2);
        assert_eq!(sink.pending(), 0);
    }

    #[test]
    fn test_debug_format() {
        let (_source, sink) = create::<TestEvent>(16);

        let debug = format!("{sink:?}");
        assert!(debug.contains("Sink"));
        assert!(debug.contains("Single"));
    }
}