1use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
4use std::sync::{Arc, OnceLock};
5use std::time::Duration;
6
7use arrow::array::RecordBatch;
8use arrow::datatypes::SchemaRef;
9
10use super::channel::{channel_with_config, Producer};
11use super::config::SourceConfig;
12use super::error::{StreamingError, TryPushError};
13use super::sink::Sink;
14
15pub trait Record: Clone + Send + Sized + 'static {
17 fn schema() -> SchemaRef;
19
20 fn to_record_batch(&self) -> RecordBatch;
24
25 fn event_time(&self) -> Option<i64> {
30 None
31 }
32
33 fn to_record_batch_from_iter<I>(records: I) -> RecordBatch
38 where
39 I: IntoIterator<Item = Self>,
40 {
41 let batches: Vec<RecordBatch> = records.into_iter().map(|r| r.to_record_batch()).collect();
42 if batches.is_empty() {
43 return RecordBatch::new_empty(Self::schema());
44 }
45 arrow::compute::concat_batches(&Self::schema(), &batches)
46 .unwrap_or_else(|_| RecordBatch::new_empty(Self::schema()))
47 }
48}
49
50#[derive(Clone)]
52pub(crate) enum SourceMessage<T> {
53 Record(T),
55
56 Batch(RecordBatch),
58}
59
/// A monotonically non-decreasing timestamp held in a shareable atomic.
struct SourceWatermark {
    /// Highest timestamp observed so far; starts at `i64::MIN`.
    current: Arc<AtomicI64>,
}

impl SourceWatermark {
    /// Creates a watermark initialised to `i64::MIN`.
    fn new() -> Self {
        Self::from_arc(Arc::new(AtomicI64::new(i64::MIN)))
    }

    /// Wraps an existing atomic so that several watermarks share one value.
    fn from_arc(arc: Arc<AtomicI64>) -> Self {
        Self { current: arc }
    }

    /// Raises the watermark to `timestamp` if it is greater than the stored
    /// value; smaller or equal timestamps leave the value unchanged.
    fn update(&self, timestamp: i64) {
        // `fetch_max` is the atomic "store the larger of the two" operation.
        self.current.fetch_max(timestamp, Ordering::AcqRel);
    }

    /// Returns the current watermark value.
    fn get(&self) -> i64 {
        self.current.load(Ordering::Acquire)
    }

    /// Returns another handle to the underlying atomic.
    fn arc(&self) -> Arc<AtomicI64> {
        Arc::clone(&self.current)
    }
}
103
104struct SourceInner<T: Record> {
106 producer: Producer<SourceMessage<T>>,
108
109 watermark: SourceWatermark,
111
112 schema: SchemaRef,
114
115 name: Option<String>,
117
118 sequence: Arc<AtomicU64>,
121
122 event_time_column: OnceLock<String>,
125
126 max_out_of_orderness: OnceLock<Duration>,
129}
130
131pub struct Source<T: Record> {
133 inner: Arc<SourceInner<T>>,
134}
135
136impl<T: Record> Source<T> {
137 pub(crate) fn new(config: SourceConfig) -> (Self, Sink<T>) {
139 let channel_config = config.channel;
140 let (producer, consumer) = channel_with_config::<SourceMessage<T>>(&channel_config);
141
142 let schema = T::schema();
143
144 let inner = Arc::new(SourceInner {
145 producer,
146 watermark: SourceWatermark::new(),
147 schema: schema.clone(),
148 name: config.name,
149 sequence: Arc::new(AtomicU64::new(0)),
150 event_time_column: OnceLock::new(),
151 max_out_of_orderness: OnceLock::new(),
152 });
153
154 let source = Self { inner };
155 let sink = Sink::new(consumer, schema);
156
157 (source, sink)
158 }
159
160 pub fn push(&self, record: T) -> Result<(), StreamingError> {
166 if let Some(event_time) = record.event_time() {
167 self.inner.watermark.update(event_time);
168 }
169
170 self.inner
171 .producer
172 .push(SourceMessage::Record(record))
173 .map_err(|_| StreamingError::ChannelFull)?;
174
175 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
176 Ok(())
177 }
178
179 pub fn try_push(&self, record: T) -> Result<(), TryPushError<T>> {
185 if let Some(event_time) = record.event_time() {
186 self.inner.watermark.update(event_time);
187 }
188
189 self.inner
190 .producer
191 .push(SourceMessage::Record(record))
192 .map_err(|msg| match msg {
193 SourceMessage::Record(r) => TryPushError {
194 value: r,
195 error: StreamingError::ChannelFull,
196 },
197 SourceMessage::Batch(_) => unreachable!("only Record is pushed here"),
198 })?;
199
200 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
201 Ok(())
202 }
203
204 pub fn push_batch(&self, records: &[T]) -> usize
206 where
207 T: Clone,
208 {
209 self.push_batch_drain(records.iter().cloned())
210 }
211
212 pub fn push_batch_drain<I>(&self, records: I) -> usize
215 where
216 I: IntoIterator<Item = T>,
217 {
218 let mut count = 0;
219 for record in records {
220 if self.push(record).is_err() {
221 break;
222 }
223 count += 1;
224 }
225 count
226 }
227
228 pub fn push_arrow(&self, batch: RecordBatch) -> Result<(), StreamingError> {
238 if !self.inner.schema.fields().is_empty() && batch.schema() != self.inner.schema {
240 return Err(StreamingError::SchemaMismatch {
241 expected: self
242 .inner
243 .schema
244 .fields()
245 .iter()
246 .map(|f| f.name().clone())
247 .collect(),
248 actual: batch
249 .schema()
250 .fields()
251 .iter()
252 .map(|f| f.name().clone())
253 .collect(),
254 });
255 }
256
257 self.inner
258 .producer
259 .push(SourceMessage::Batch(batch))
260 .map_err(|_| StreamingError::ChannelFull)?;
261
262 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
263 Ok(())
264 }
265
266 pub fn watermark(&self, timestamp: i64) {
275 self.inner.watermark.update(timestamp);
280 }
281
282 #[must_use]
284 pub fn current_watermark(&self) -> i64 {
285 self.inner.watermark.get()
286 }
287
288 #[must_use]
290 pub fn schema(&self) -> SchemaRef {
291 Arc::clone(&self.inner.schema)
292 }
293
294 #[must_use]
296 pub fn name(&self) -> Option<&str> {
297 self.inner.name.as_deref()
298 }
299
300 #[must_use]
302 pub fn is_closed(&self) -> bool {
303 self.inner.producer.is_closed()
304 }
305
306 #[must_use]
308 pub fn pending(&self) -> usize {
309 self.inner.producer.len()
310 }
311
312 #[must_use]
314 pub fn capacity(&self) -> usize {
315 self.inner.producer.capacity()
316 }
317
318 #[must_use]
320 pub fn sequence(&self) -> u64 {
321 self.inner.sequence.load(Ordering::Acquire)
322 }
323
324 #[must_use]
326 pub fn sequence_counter(&self) -> Arc<AtomicU64> {
327 Arc::clone(&self.inner.sequence)
328 }
329
330 #[must_use]
332 pub fn watermark_atomic(&self) -> Arc<AtomicI64> {
333 self.inner.watermark.arc()
334 }
335
336 pub fn set_event_time_column(&self, column: &str) {
343 let _ = self.inner.event_time_column.set(column.to_owned());
344 }
345
346 #[must_use]
348 pub fn event_time_column(&self) -> Option<String> {
349 self.inner.event_time_column.get().cloned()
350 }
351
352 pub fn set_max_out_of_orderness(&self, dur: Duration) {
356 let _ = self.inner.max_out_of_orderness.set(dur);
357 }
358
359 #[must_use]
361 pub fn max_out_of_orderness(&self) -> Option<Duration> {
362 self.inner.max_out_of_orderness.get().copied()
363 }
364}
365
366impl<T: Record> Clone for Source<T> {
367 fn clone(&self) -> Self {
368 let producer = self.inner.producer.clone();
369 let event_time_col = self.inner.event_time_column.get().cloned();
370 let event_time_column = OnceLock::new();
371 if let Some(col) = event_time_col {
372 let _ = event_time_column.set(col);
373 }
374 let max_ooo = self.inner.max_out_of_orderness.get().copied();
375 let max_out_of_orderness = OnceLock::new();
376 if let Some(dur) = max_ooo {
377 let _ = max_out_of_orderness.set(dur);
378 }
379 Self {
380 inner: Arc::new(SourceInner {
381 producer,
382 watermark: SourceWatermark::from_arc(self.inner.watermark.arc()),
383 schema: Arc::clone(&self.inner.schema),
384 name: self.inner.name.clone(),
385 sequence: Arc::clone(&self.inner.sequence),
386 event_time_column,
387 max_out_of_orderness,
388 }),
389 }
390 }
391}
392
393impl<T: Record + std::fmt::Debug> std::fmt::Debug for Source<T> {
394 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395 f.debug_struct("Source")
396 .field("name", &self.inner.name)
397 .field("pending", &self.pending())
398 .field("capacity", &self.capacity())
399 .field("watermark", &self.current_watermark())
400 .finish()
401 }
402}
403
404#[must_use]
406pub fn create<T: Record>(buffer_size: usize) -> (Source<T>, Sink<T>) {
407 Source::new(SourceConfig::with_buffer_size(buffer_size))
408}
409
410#[must_use]
412pub fn create_with_config<T: Record>(config: SourceConfig) -> (Source<T>, Sink<T>) {
413 Source::new(config)
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Minimal record with an id, a value and an event timestamp.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
        timestamp: i64,
    }

    /// Shorthand constructor so the tests stay focused on the assertions.
    fn event(id: i64, value: f64, timestamp: i64) -> TestEvent {
        TestEvent {
            id,
            value,
            timestamp,
        }
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
                Field::new("timestamp", DataType::Int64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                    Arc::new(Int64Array::from(vec![self.timestamp])),
                ],
            )
            .unwrap()
        }

        fn event_time(&self) -> Option<i64> {
            Some(self.timestamp)
        }
    }

    #[tokio::test]
    async fn test_create_source_sink() {
        let (source, _sink) = create::<TestEvent>(1024);

        assert!(!source.is_closed());
        assert_eq!(source.pending(), 0);
    }

    #[tokio::test]
    async fn test_push_single() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.push(event(1, 42.0, 1000)).is_ok());
        assert_eq!(source.pending(), 1);
    }

    #[tokio::test]
    async fn test_try_push() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.try_push(event(1, 42.0, 1000)).is_ok());
    }

    #[tokio::test]
    async fn test_push_batch() {
        let (source, _sink) = create::<TestEvent>(16);

        let events = vec![
            event(1, 1.0, 1000),
            event(2, 2.0, 2000),
            event(3, 3.0, 3000),
        ];

        let count = source.push_batch(&events);
        assert_eq!(count, 3);
        assert_eq!(source.pending(), 3);
    }

    #[tokio::test]
    async fn test_push_arrow() {
        let (source, _sink) = create::<TestEvent>(16);

        let batch = RecordBatch::try_new(
            TestEvent::schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
                Arc::new(Int64Array::from(vec![1000, 2000, 3000])),
            ],
        )
        .unwrap();

        assert!(source.push_arrow(batch).is_ok());
    }

    #[tokio::test]
    async fn test_push_arrow_schema_mismatch() {
        let (source, _sink) = create::<TestEvent>(16);

        // A single Utf8 column cannot match the three-column test schema.
        let wrong_schema = Arc::new(Schema::new(vec![Field::new(
            "wrong",
            DataType::Utf8,
            false,
        )]));
        let batch = RecordBatch::try_new(
            wrong_schema,
            vec![Arc::new(StringArray::from(vec!["test"]))],
        )
        .unwrap();

        let result = source.push_arrow(batch);
        assert!(matches!(result, Err(StreamingError::SchemaMismatch { .. })));
    }

    #[tokio::test]
    async fn test_watermark() {
        let (source, _sink) = create::<TestEvent>(16);

        assert_eq!(source.current_watermark(), i64::MIN);

        source.watermark(1000);
        assert_eq!(source.current_watermark(), 1000);

        source.watermark(2000);
        assert_eq!(source.current_watermark(), 2000);

        // An older timestamp must not move the watermark backwards.
        source.watermark(1500);
        assert_eq!(source.current_watermark(), 2000);
    }

    #[tokio::test]
    async fn test_watermark_from_event_time() {
        let (source, _sink) = create::<TestEvent>(16);

        source.push(event(1, 42.0, 5000)).unwrap();

        assert_eq!(source.current_watermark(), 5000);
    }

    #[tokio::test]
    async fn test_clone_multi_producer() {
        let (source, sink) = create::<TestEvent>(16);
        let source2 = source.clone();
        let mut sub = sink.subscribe();

        source.push(event(1, 1.0, 1000)).unwrap();
        source2.push(event(2, 2.0, 2000)).unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert!(sub.poll().is_some());
        assert!(sub.poll().is_some());
    }

    #[tokio::test]
    async fn test_schema() {
        let (source, _sink) = create::<TestEvent>(16);

        let schema = source.schema();
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "value");
        assert_eq!(schema.field(2).name(), "timestamp");
    }

    #[tokio::test]
    async fn test_named_source() {
        let config = SourceConfig::named("my_source");
        let (source, _sink) = create_with_config::<TestEvent>(config);

        assert_eq!(source.name(), Some("my_source"));
    }

    #[tokio::test]
    async fn test_debug_format() {
        let (source, _sink) = create::<TestEvent>(16);

        let debug = format!("{source:?}");
        assert!(debug.contains("Source"));
    }

    #[tokio::test]
    async fn test_set_event_time_column() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.event_time_column().is_none());

        source.set_event_time_column("timestamp");
        assert_eq!(source.event_time_column(), Some("timestamp".to_string()));
    }

    #[tokio::test]
    async fn test_event_time_column_preserved_on_clone() {
        let (source, _sink) = create::<TestEvent>(16);
        source.set_event_time_column("ts");

        let source2 = source.clone();
        assert_eq!(source2.event_time_column(), Some("ts".to_string()));
    }
}
661}