1use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
31use std::sync::{Arc, OnceLock};
32
33use arrow::array::RecordBatch;
34use arrow::datatypes::SchemaRef;
35
36use super::channel::{channel_with_config, ChannelMode, Producer};
37use super::config::SourceConfig;
38use super::error::{StreamingError, TryPushError};
39use super::sink::Sink;
40
41pub trait Record: Send + Sized + 'static {
81 fn schema() -> SchemaRef;
83
84 fn to_record_batch(&self) -> RecordBatch;
88
89 fn event_time(&self) -> Option<i64> {
94 None
95 }
96
97 fn to_record_batch_from_iter<I>(records: I) -> RecordBatch
102 where
103 I: IntoIterator<Item = Self>,
104 {
105 let batches: Vec<RecordBatch> = records.into_iter().map(|r| r.to_record_batch()).collect();
106 if batches.is_empty() {
107 return RecordBatch::new_empty(Self::schema());
108 }
109 arrow::compute::concat_batches(&Self::schema(), &batches)
110 .unwrap_or_else(|_| RecordBatch::new_empty(Self::schema()))
111 }
112}
113
114pub(crate) enum SourceMessage<T> {
116 Record(T),
118
119 Batch(RecordBatch),
121
122 Watermark(i64),
124}
125
/// Monotonic tracker of the highest timestamp observed, stored in an `Arc<AtomicI64>` so that
/// several holders can share one value.
///
/// Starts at `i64::MIN`, meaning "nothing observed yet".
struct SourceWatermark {
    current: Arc<AtomicI64>,
}

impl SourceWatermark {
    /// Creates a tracker backed by a fresh counter initialised to `i64::MIN`.
    fn new() -> Self {
        Self::from_arc(Arc::new(AtomicI64::new(i64::MIN)))
    }

    /// Wraps an existing shared counter (used when cloning a `Source`).
    fn from_arc(arc: Arc<AtomicI64>) -> Self {
        Self { current: arc }
    }

    /// Raises the stored value to `timestamp` if `timestamp` is strictly greater.
    ///
    /// Smaller or equal timestamps leave the value unchanged, and no store is performed.
    fn update(&self, timestamp: i64) {
        // `fetch_update` loads with Acquire, then retries `compare_exchange_weak(AcqRel, Acquire)`
        // until it succeeds. It gives up without writing as soon as the closure returns `None`,
        // i.e. once the stored value is already >= `timestamp`.
        let _ = self
            .current
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |observed| {
                (timestamp > observed).then_some(timestamp)
            });
    }

    /// Returns the current value (`i64::MIN` if nothing has been observed).
    fn get(&self) -> i64 {
        self.current.load(Ordering::Acquire)
    }

    /// Returns another handle onto the shared counter.
    fn arc(&self) -> Arc<AtomicI64> {
        Arc::clone(&self.current)
    }
}
169
170struct SourceInner<T: Record> {
172 producer: Producer<SourceMessage<T>>,
174
175 watermark: SourceWatermark,
177
178 schema: SchemaRef,
180
181 name: Option<String>,
183
184 sequence: Arc<AtomicU64>,
187
188 event_time_column: OnceLock<String>,
191}
192
193pub struct Source<T: Record> {
219 inner: Arc<SourceInner<T>>,
220}
221
222impl<T: Record> Source<T> {
223 pub(crate) fn new(config: SourceConfig) -> (Self, Sink<T>) {
225 let channel_config = config.channel;
226 let (producer, consumer) = channel_with_config::<SourceMessage<T>>(channel_config.clone());
227
228 let schema = T::schema();
229
230 let inner = Arc::new(SourceInner {
231 producer,
232 watermark: SourceWatermark::new(),
233 schema: schema.clone(),
234 name: config.name,
235 sequence: Arc::new(AtomicU64::new(0)),
236 event_time_column: OnceLock::new(),
237 });
238
239 let source = Self { inner };
240 let sink = Sink::new(consumer, schema, channel_config);
241
242 (source, sink)
243 }
244
245 pub fn push(&self, record: T) -> Result<(), StreamingError> {
252 if let Some(event_time) = record.event_time() {
254 self.inner.watermark.update(event_time);
255 }
256
257 self.inner
258 .producer
259 .push(SourceMessage::Record(record))
260 .map_err(|_| StreamingError::ChannelFull)?;
261
262 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
263 Ok(())
264 }
265
266 pub fn try_push(&self, record: T) -> Result<(), TryPushError<T>> {
272 if let Some(event_time) = record.event_time() {
274 self.inner.watermark.update(event_time);
275 }
276
277 self.inner
278 .producer
279 .try_push(SourceMessage::Record(record))
280 .map_err(|e| match e.into_inner() {
281 SourceMessage::Record(r) => TryPushError {
282 value: r,
283 error: StreamingError::ChannelFull,
284 },
285 _ => unreachable!("pushed a record, got something else back"),
286 })?;
287
288 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
289 Ok(())
290 }
291
292 pub fn push_batch(&self, records: &[T]) -> usize
303 where
304 T: Clone,
305 {
306 let mut count = 0;
307 for record in records {
308 if self.try_push(record.clone()).is_err() {
309 break;
310 }
311 count += 1;
312 }
313 count
314 }
315
316 pub fn push_batch_drain<I>(&self, records: I) -> usize
328 where
329 I: IntoIterator<Item = T>,
330 {
331 let mut count = 0;
332 for record in records {
333 if let Some(event_time) = record.event_time() {
335 self.inner.watermark.update(event_time);
336 }
337
338 if self
339 .inner
340 .producer
341 .try_push(SourceMessage::Record(record))
342 .is_err()
343 {
344 break;
345 }
346 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
347 count += 1;
348 }
349 count
350 }
351
352 pub fn push_arrow(&self, batch: RecordBatch) -> Result<(), StreamingError> {
362 if !self.inner.schema.fields().is_empty() && batch.schema() != self.inner.schema {
364 return Err(StreamingError::SchemaMismatch {
365 expected: self
366 .inner
367 .schema
368 .fields()
369 .iter()
370 .map(|f| f.name().clone())
371 .collect(),
372 actual: batch
373 .schema()
374 .fields()
375 .iter()
376 .map(|f| f.name().clone())
377 .collect(),
378 });
379 }
380
381 self.inner
382 .producer
383 .push(SourceMessage::Batch(batch))
384 .map_err(|_| StreamingError::ChannelFull)?;
385
386 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
387 Ok(())
388 }
389
390 pub fn watermark(&self, timestamp: i64) {
399 self.inner.watermark.update(timestamp);
400
401 let _ = self
404 .inner
405 .producer
406 .try_push(SourceMessage::Watermark(timestamp));
407 }
408
409 #[must_use]
411 pub fn current_watermark(&self) -> i64 {
412 self.inner.watermark.get()
413 }
414
415 #[must_use]
417 pub fn schema(&self) -> SchemaRef {
418 Arc::clone(&self.inner.schema)
419 }
420
421 #[must_use]
423 pub fn name(&self) -> Option<&str> {
424 self.inner.name.as_deref()
425 }
426
427 #[must_use]
429 pub fn is_mpsc(&self) -> bool {
430 self.inner.producer.is_mpsc()
431 }
432
433 #[must_use]
435 pub fn mode(&self) -> ChannelMode {
436 self.inner.producer.mode()
437 }
438
439 #[must_use]
441 pub fn is_closed(&self) -> bool {
442 self.inner.producer.is_closed()
443 }
444
445 #[must_use]
447 pub fn pending(&self) -> usize {
448 self.inner.producer.len()
449 }
450
451 #[must_use]
453 pub fn capacity(&self) -> usize {
454 self.inner.producer.capacity()
455 }
456
457 #[must_use]
459 pub fn sequence(&self) -> u64 {
460 self.inner.sequence.load(Ordering::Acquire)
461 }
462
463 #[must_use]
465 pub fn sequence_counter(&self) -> Arc<AtomicU64> {
466 Arc::clone(&self.inner.sequence)
467 }
468
469 #[must_use]
471 pub fn watermark_atomic(&self) -> Arc<AtomicI64> {
472 self.inner.watermark.arc()
473 }
474
475 pub fn set_event_time_column(&self, column: &str) {
482 let _ = self.inner.event_time_column.set(column.to_owned());
483 }
484
485 #[must_use]
487 pub fn event_time_column(&self) -> Option<String> {
488 self.inner.event_time_column.get().cloned()
489 }
490}
491
492impl<T: Record> Clone for Source<T> {
493 fn clone(&self) -> Self {
504 let producer = self.inner.producer.clone();
506
507 let event_time_col = self.inner.event_time_column.get().cloned();
511 let event_time_column = OnceLock::new();
512 if let Some(col) = event_time_col {
513 let _ = event_time_column.set(col);
514 }
515 Self {
516 inner: Arc::new(SourceInner {
517 producer,
518 watermark: SourceWatermark::from_arc(self.inner.watermark.arc()),
519 schema: Arc::clone(&self.inner.schema),
520 name: self.inner.name.clone(),
521 sequence: Arc::clone(&self.inner.sequence),
522 event_time_column,
523 }),
524 }
525 }
526}
527
528impl<T: Record + std::fmt::Debug> std::fmt::Debug for Source<T> {
529 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
530 f.debug_struct("Source")
531 .field("name", &self.inner.name)
532 .field("mode", &self.mode())
533 .field("pending", &self.pending())
534 .field("capacity", &self.capacity())
535 .field("watermark", &self.current_watermark())
536 .finish()
537 }
538}
539
540#[must_use]
561pub fn create<T: Record>(buffer_size: usize) -> (Source<T>, Sink<T>) {
562 Source::new(SourceConfig::with_buffer_size(buffer_size))
563}
564
565#[must_use]
567pub fn create_with_config<T: Record>(config: SourceConfig) -> (Source<T>, Sink<T>) {
568 Source::new(config)
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Minimal record type: an id, a value and an event timestamp.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
        timestamp: i64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            Arc::new(Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
                Field::new("timestamp", DataType::Int64, false),
            ]))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                    Arc::new(Int64Array::from(vec![self.timestamp])),
                ],
            )
            .unwrap()
        }

        fn event_time(&self) -> Option<i64> {
            Some(self.timestamp)
        }
    }

    /// Builds a `TestEvent` so each test spells out only the values it cares about.
    fn event(id: i64, value: f64, timestamp: i64) -> TestEvent {
        TestEvent {
            id,
            value,
            timestamp,
        }
    }

    #[test]
    fn test_create_source_sink() {
        let (source, _sink) = create::<TestEvent>(1024);

        assert!(!source.is_mpsc());
        assert!(!source.is_closed());
        assert_eq!(source.pending(), 0);
    }

    #[test]
    fn test_push_single() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.push(event(1, 42.0, 1000)).is_ok());
        assert_eq!(source.pending(), 1);
    }

    #[test]
    fn test_try_push() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.try_push(event(1, 42.0, 1000)).is_ok());
    }

    #[test]
    fn test_push_batch() {
        let (source, _sink) = create::<TestEvent>(16);

        let events = vec![
            event(1, 1.0, 1000),
            event(2, 2.0, 2000),
            event(3, 3.0, 3000),
        ];

        let count = source.push_batch(&events);
        assert_eq!(count, 3);
        assert_eq!(source.pending(), 3);
    }

    #[test]
    fn test_push_batch_drain() {
        let (source, _sink) = create::<TestEvent>(16);

        let count = source.push_batch_drain(vec![
            event(1, 1.0, 1000),
            event(2, 2.0, 3000),
            event(3, 3.0, 2000),
        ]);

        assert_eq!(count, 3);
        assert_eq!(source.pending(), 3);
        assert_eq!(source.sequence(), 3);
        // The watermark tracks the maximum event time, not the most recent one.
        assert_eq!(source.current_watermark(), 3000);
    }

    #[test]
    fn test_push_arrow() {
        let (source, _sink) = create::<TestEvent>(16);

        let batch = RecordBatch::try_new(
            TestEvent::schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
                Arc::new(Int64Array::from(vec![1000, 2000, 3000])),
            ],
        )
        .unwrap();

        assert!(source.push_arrow(batch).is_ok());
    }

    #[test]
    fn test_push_arrow_schema_mismatch() {
        let (source, _sink) = create::<TestEvent>(16);

        let wrong_schema = Arc::new(Schema::new(vec![Field::new(
            "wrong",
            DataType::Utf8,
            false,
        )]));

        let batch = RecordBatch::try_new(
            wrong_schema,
            vec![Arc::new(StringArray::from(vec!["test"]))],
        )
        .unwrap();

        let result = source.push_arrow(batch);
        assert!(matches!(result, Err(StreamingError::SchemaMismatch { .. })));
    }

    #[test]
    fn test_watermark() {
        let (source, _sink) = create::<TestEvent>(16);

        assert_eq!(source.current_watermark(), i64::MIN);

        source.watermark(1000);
        assert_eq!(source.current_watermark(), 1000);

        source.watermark(2000);
        assert_eq!(source.current_watermark(), 2000);

        // A lower timestamp must not move the watermark backwards.
        source.watermark(1500);
        assert_eq!(source.current_watermark(), 2000);
    }

    #[test]
    fn test_watermark_from_event_time() {
        let (source, _sink) = create::<TestEvent>(16);

        source.push(event(1, 42.0, 5000)).unwrap();

        assert_eq!(source.current_watermark(), 5000);
    }

    #[test]
    fn test_clone_upgrades_to_mpsc() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(!source.is_mpsc());
        assert_eq!(source.mode(), ChannelMode::Spsc);

        let source2 = source.clone();

        assert!(source.is_mpsc());
        assert!(source2.is_mpsc());
    }

    #[test]
    fn test_clone_shares_watermark_and_sequence() {
        let (source, _sink) = create::<TestEvent>(16);
        let source2 = source.clone();

        source2.push(event(1, 1.0, 7000)).unwrap();

        assert_eq!(source.current_watermark(), 7000);
        assert_eq!(source.sequence(), 1);
        assert_eq!(
            source
                .watermark_atomic()
                .load(std::sync::atomic::Ordering::Acquire),
            7000
        );
    }

    #[test]
    fn test_closed_on_sink_drop() {
        let (source, sink) = create::<TestEvent>(16);

        assert!(!source.is_closed());

        drop(sink);

        assert!(source.is_closed());
    }

    #[test]
    fn test_schema() {
        let (source, _sink) = create::<TestEvent>(16);

        let schema = source.schema();
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "value");
        assert_eq!(schema.field(2).name(), "timestamp");
    }

    #[test]
    fn test_named_source() {
        let config = SourceConfig::named("my_source");
        let (source, _sink) = create_with_config::<TestEvent>(config);

        assert_eq!(source.name(), Some("my_source"));
    }

    #[test]
    fn test_debug_format() {
        let (source, _sink) = create::<TestEvent>(16);

        let debug = format!("{source:?}");
        assert!(debug.contains("Source"));
        assert!(debug.contains("Spsc"));
    }

    #[test]
    fn test_set_event_time_column() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.event_time_column().is_none());

        source.set_event_time_column("timestamp");
        assert_eq!(source.event_time_column(), Some("timestamp".to_string()));
    }

    #[test]
    fn test_event_time_column_preserved_on_clone() {
        let (source, _sink) = create::<TestEvent>(16);
        source.set_event_time_column("ts");

        let source2 = source.clone();
        assert_eq!(source2.event_time_column(), Some("ts".to_string()));
    }

    #[test]
    fn test_to_record_batch_from_iter() {
        let batch = TestEvent::to_record_batch_from_iter(vec![
            event(1, 1.0, 1000),
            event(2, 2.0, 2000),
        ]);
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.schema(), TestEvent::schema());

        let empty = TestEvent::to_record_batch_from_iter(Vec::new());
        assert_eq!(empty.num_rows(), 0);
    }
}