1use std::sync::Arc;
33use std::time::{Duration, Instant};
34
35use arrow::array::RecordBatch;
36use arrow::datatypes::SchemaRef;
37
38use super::channel::Consumer;
39use super::error::RecvError;
40use super::sink::SinkInner;
41use super::source::{Record, SourceMessage};
42
43pub struct Subscription<T: Record> {
53 inner: SubscriptionInner<T>,
54 schema: SchemaRef,
55}
56
57enum SubscriptionInner<T: Record> {
58 Direct(Arc<SinkInner<T>>),
60 Broadcast(Consumer<SourceMessage<T>>),
62}
63
64impl<T: Record> Subscription<T> {
65 pub(crate) fn new_direct(sink_inner: Arc<SinkInner<T>>) -> Self {
67 let schema = sink_inner.schema();
68 Self {
69 inner: SubscriptionInner::Direct(sink_inner),
70 schema,
71 }
72 }
73
74 pub(crate) fn new_broadcast(consumer: Consumer<SourceMessage<T>>, schema: SchemaRef) -> Self {
76 Self {
77 inner: SubscriptionInner::Broadcast(consumer),
78 schema,
79 }
80 }
81
82 #[must_use]
88 pub fn poll(&self) -> Option<RecordBatch> {
89 let msg = match &self.inner {
90 SubscriptionInner::Direct(sink) => sink.consumer().poll(),
91 SubscriptionInner::Broadcast(consumer) => consumer.poll(),
92 }?;
93
94 Self::message_to_batch(msg)
95 }
96
97 #[must_use]
101 pub fn poll_message(&self) -> Option<SubscriptionMessage<T>> {
102 let msg = match &self.inner {
103 SubscriptionInner::Direct(sink) => sink.consumer().poll(),
104 SubscriptionInner::Broadcast(consumer) => consumer.poll(),
105 }?;
106
107 Some(Self::convert_message(msg))
108 }
109
110 pub fn recv(&self) -> Result<RecordBatch, RecvError> {
117 let mut spins = 0u32;
118 loop {
119 if let Some(batch) = self.poll() {
120 return Ok(batch);
121 }
122
123 if self.is_disconnected() {
124 return Err(RecvError::Disconnected);
125 }
126
127 if spins < 64 {
129 std::hint::spin_loop();
130 } else if spins < 128 {
131 std::thread::yield_now();
132 } else {
133 std::thread::park_timeout(Duration::from_micros(100));
134 }
135 spins = spins.saturating_add(1);
136 }
137 }
138
139 pub fn recv_timeout(&self, timeout: Duration) -> Result<RecordBatch, RecvError> {
146 let deadline = Instant::now() + timeout;
147 let mut spins = 0u32;
148
149 loop {
150 if let Some(batch) = self.poll() {
151 return Ok(batch);
152 }
153
154 if self.is_disconnected() {
155 return Err(RecvError::Disconnected);
156 }
157
158 if Instant::now() >= deadline {
159 return Err(RecvError::Timeout);
160 }
161
162 if spins < 64 {
164 std::hint::spin_loop();
165 } else if spins < 128 {
166 std::thread::yield_now();
167 } else {
168 let remaining = deadline.saturating_duration_since(Instant::now());
169 std::thread::park_timeout(remaining.min(Duration::from_micros(100)));
170 }
171 spins = spins.saturating_add(1);
172 }
173 }
174
175 #[cold]
185 #[must_use]
186 pub fn poll_batch(&self, max_count: usize) -> Vec<RecordBatch> {
187 let mut batches = Vec::with_capacity(max_count);
188
189 for _ in 0..max_count {
190 if let Some(batch) = self.poll() {
191 batches.push(batch);
192 } else {
193 break;
194 }
195 }
196
197 batches
198 }
199
200 pub fn poll_batch_into(&self, buffer: &mut Vec<RecordBatch>, max_count: usize) -> usize {
219 let mut count = 0;
220
221 for _ in 0..max_count {
222 if let Some(batch) = self.poll() {
223 buffer.push(batch);
224 count += 1;
225 } else {
226 break;
227 }
228 }
229
230 count
231 }
232
233 pub fn poll_each<F>(&self, max_count: usize, mut f: F) -> usize
242 where
243 F: FnMut(RecordBatch) -> bool,
244 {
245 let mut count = 0;
246
247 for _ in 0..max_count {
248 if let Some(batch) = self.poll() {
249 count += 1;
250 if !f(batch) {
251 break;
252 }
253 } else {
254 break;
255 }
256 }
257
258 count
259 }
260
261 #[must_use]
263 pub fn is_disconnected(&self) -> bool {
264 match &self.inner {
265 SubscriptionInner::Direct(sink) => sink.is_disconnected(),
266 SubscriptionInner::Broadcast(consumer) => consumer.is_disconnected(),
267 }
268 }
269
270 #[must_use]
272 pub fn pending(&self) -> usize {
273 match &self.inner {
274 SubscriptionInner::Direct(sink) => sink.consumer().len(),
275 SubscriptionInner::Broadcast(consumer) => consumer.len(),
276 }
277 }
278
279 #[must_use]
281 pub fn schema(&self) -> SchemaRef {
282 Arc::clone(&self.schema)
283 }
284
285 fn message_to_batch(msg: SourceMessage<T>) -> Option<RecordBatch> {
286 match msg {
287 SourceMessage::Record(record) => Some(record.to_record_batch()),
288 SourceMessage::Batch(batch) => Some(batch),
289 SourceMessage::Watermark(_) => {
290 None
292 }
293 }
294 }
295
296 fn convert_message(msg: SourceMessage<T>) -> SubscriptionMessage<T> {
297 match msg {
298 SourceMessage::Record(record) => SubscriptionMessage::Record(record),
299 SourceMessage::Batch(batch) => SubscriptionMessage::Batch(batch),
300 SourceMessage::Watermark(ts) => SubscriptionMessage::Watermark(ts),
301 }
302 }
303}
304
305#[derive(Debug)]
307pub enum SubscriptionMessage<T> {
308 Record(T),
310 Batch(RecordBatch),
312 Watermark(i64),
314}
315
316impl<T: Record> SubscriptionMessage<T> {
317 #[must_use]
319 pub fn is_record(&self) -> bool {
320 matches!(self, Self::Record(_))
321 }
322
323 #[must_use]
325 pub fn is_batch(&self) -> bool {
326 matches!(self, Self::Batch(_))
327 }
328
329 #[must_use]
331 pub fn is_watermark(&self) -> bool {
332 matches!(self, Self::Watermark(_))
333 }
334
335 #[must_use]
337 pub fn to_batch(self) -> Option<RecordBatch> {
338 match self {
339 Self::Record(r) => Some(r.to_record_batch()),
340 Self::Batch(b) => Some(b),
341 Self::Watermark(_) => None,
342 }
343 }
344
345 #[must_use]
347 pub fn watermark(&self) -> Option<i64> {
348 match self {
349 Self::Watermark(ts) => Some(*ts),
350 _ => None,
351 }
352 }
353}
354
355impl<T: Record> Iterator for Subscription<T> {
360 type Item = RecordBatch;
361
362 fn next(&mut self) -> Option<Self::Item> {
363 self.recv().ok()
364 }
365}
366
367impl<T: Record + std::fmt::Debug> std::fmt::Debug for Subscription<T> {
368 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369 let mode = match &self.inner {
370 SubscriptionInner::Direct(_) => "Direct",
371 SubscriptionInner::Broadcast(_) => "Broadcast",
372 };
373
374 f.debug_struct("Subscription")
375 .field("mode", &mode)
376 .field("pending", &self.pending())
377 .field("is_disconnected", &self.is_disconnected())
378 .field("schema", &self.schema)
379 .finish()
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::source::create;
    use arrow::array::{Float64Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Minimal two-column record used to drive the subscription under test.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            let ids = Int64Array::from(vec![self.id]);
            let values = Float64Array::from(vec![self.value]);
            RecordBatch::try_new(Self::schema(), vec![Arc::new(ids), Arc::new(values)]).unwrap()
        }
    }

    /// Shorthand constructor for a test event.
    fn event(id: i64, value: f64) -> TestEvent {
        TestEvent { id, value }
    }

    #[test]
    fn test_poll_empty() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        // Nothing was pushed, so there is nothing to poll.
        let polled = sub.poll();
        assert!(polled.is_none());
    }

    #[test]
    fn test_poll_records() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        for (id, value) in [(1, 1.0), (2, 2.0)] {
            source.push(event(id, value)).unwrap();
        }

        // Each pushed record comes back as its own single-row batch.
        for _ in 0..2 {
            let batch = sub.poll().unwrap();
            assert_eq!(batch.num_rows(), 1);
        }

        assert!(sub.poll().is_none());
    }

    #[test]
    fn test_poll_message() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        source.push(event(1, 1.0)).unwrap();

        assert!(sub.poll_message().unwrap().is_record());
    }

    #[test]
    fn test_recv_timeout() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        // The source stays alive but idle, so the call has to time out.
        let outcome = sub.recv_timeout(Duration::from_millis(10));
        assert!(matches!(outcome, Err(RecvError::Timeout)));
    }

    #[test]
    fn test_recv_timeout_success() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        source.push(event(1, 1.0)).unwrap();

        let outcome = sub.recv_timeout(Duration::from_secs(1));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_poll_batch() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        for (id, value) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            source.push(event(id, value)).unwrap();
        }

        // Asking for more than is queued returns only what is there.
        assert_eq!(sub.poll_batch(10).len(), 3);
    }

    #[test]
    fn test_poll_each() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        for (id, value) in [(1, 1.0), (2, 2.0)] {
            source.push(event(id, value)).unwrap();
        }

        let mut total_rows = 0;
        let count = sub.poll_each(10, |batch| {
            total_rows += batch.num_rows();
            true
        });

        assert_eq!(count, 2);
        assert_eq!(total_rows, 2);
    }

    #[test]
    fn test_poll_each_early_stop() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        for (id, value) in [(1, 1.0), (2, 2.0), (3, 3.0)] {
            source.push(event(id, value)).unwrap();
        }

        let mut seen = 0;
        let count = sub.poll_each(10, |_| {
            seen += 1;
            // Ask to stop once the second batch has been seen.
            seen < 2
        });

        assert_eq!(count, 2);
        assert_eq!(seen, 2);
        // The third event was never polled.
        assert_eq!(sub.pending(), 1);
    }

    #[test]
    fn test_disconnected() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        assert!(!sub.is_disconnected());

        drop(source);

        assert!(sub.is_disconnected());
    }

    #[test]
    fn test_pending() {
        let (source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        assert_eq!(sub.pending(), 0);

        for (id, value) in [(1, 1.0), (2, 2.0)] {
            source.push(event(id, value)).unwrap();
        }

        assert_eq!(sub.pending(), 2);
    }

    #[test]
    fn test_schema() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        let field_count = sub.schema().fields().len();
        assert_eq!(field_count, 2);
    }

    #[test]
    fn test_subscription_message() {
        let msg = SubscriptionMessage::Record(event(1, 1.0));
        assert!(msg.is_record());
        assert!(!msg.is_batch());
        assert!(!msg.is_watermark());

        let batch = msg.to_batch().unwrap();
        assert_eq!(batch.num_rows(), 1);

        let wm = SubscriptionMessage::<TestEvent>::Watermark(1000);
        assert!(wm.is_watermark());
        assert_eq!(wm.watermark(), Some(1000));
    }

    #[test]
    fn test_iterator() {
        let (source, sink) = create::<TestEvent>(16);
        let mut sub = sink.subscribe();

        for (id, value) in [(1, 1.0), (2, 2.0)] {
            source.push(event(id, value)).unwrap();
        }

        // Dropping the source lets the blocking iterator terminate.
        drop(source);

        let batches = sub.by_ref().collect::<Vec<RecordBatch>>();
        assert_eq!(batches.len(), 2);
    }

    #[test]
    fn test_debug_format() {
        let (_source, sink) = create::<TestEvent>(16);
        let sub = sink.subscribe();

        let debug = format!("{sub:?}");
        for expected in ["Subscription", "Direct"] {
            assert!(debug.contains(expected));
        }
    }
}