1use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
4use std::sync::{Arc, OnceLock};
5use std::time::Duration;
6
7use arrow::array::RecordBatch;
8use arrow::datatypes::SchemaRef;
9
10use super::channel::{channel_with_config, Producer};
11use super::config::SourceConfig;
12use super::error::{StreamingError, TryPushError};
13use super::sink::Sink;
14
15pub trait Record: Clone + Send + Sized + 'static {
17 fn schema() -> SchemaRef;
19
20 fn to_record_batch(&self) -> RecordBatch;
24
25 fn event_time(&self) -> Option<i64> {
30 None
31 }
32
33 fn to_record_batch_from_iter<I>(records: I) -> RecordBatch
38 where
39 I: IntoIterator<Item = Self>,
40 {
41 let batches: Vec<RecordBatch> = records.into_iter().map(|r| r.to_record_batch()).collect();
42 if batches.is_empty() {
43 return RecordBatch::new_empty(Self::schema());
44 }
45 arrow::compute::concat_batches(&Self::schema(), &batches)
46 .unwrap_or_else(|_| RecordBatch::new_empty(Self::schema()))
47 }
48}
49
/// Internal message envelope flowing from a `Source` to its paired `Sink`.
#[derive(Clone)]
pub(crate) enum SourceMessage<T> {
    /// A single typed record pushed via `push`/`try_push`.
    Record(T),

    /// A pre-built Arrow batch pushed via `push_arrow`.
    Batch(RecordBatch),

    /// An explicit watermark timestamp, broadcast best-effort by
    /// `Source::watermark`. Units are whatever the caller uses for
    /// event time (not established here).
    Watermark(#[allow(dead_code)] i64),
}
62
/// Monotonically non-decreasing event-time watermark, sharable between
/// source handles through its inner `Arc`.
struct SourceWatermark {
    /// Highest timestamp observed so far; `i64::MIN` means "no events yet".
    current: Arc<AtomicI64>,
}

impl SourceWatermark {
    /// Creates a watermark initialized to `i64::MIN` (nothing observed).
    fn new() -> Self {
        Self {
            current: Arc::new(AtomicI64::new(i64::MIN)),
        }
    }

    /// Wraps an existing shared counter so that several handles observe
    /// (and advance) the same watermark.
    fn from_arc(arc: Arc<AtomicI64>) -> Self {
        Self { current: arc }
    }

    /// Advances the watermark to `timestamp` if it is greater than the
    /// current value; smaller timestamps are ignored, so the watermark
    /// never regresses.
    fn update(&self, timestamp: i64) {
        // `fetch_max` atomically stores max(current, timestamp), replacing
        // the manual compare-exchange loop with the stdlib primitive.
        self.current.fetch_max(timestamp, Ordering::AcqRel);
    }

    /// Returns the current watermark value.
    fn get(&self) -> i64 {
        self.current.load(Ordering::Acquire)
    }

    /// Returns the shared atomic, e.g. for external monitoring.
    fn arc(&self) -> Arc<AtomicI64> {
        Arc::clone(&self.current)
    }
}
106
/// Shared state behind a `Source` handle.
struct SourceInner<T: Record> {
    /// Producer side of the channel feeding the paired `Sink`.
    producer: Producer<SourceMessage<T>>,

    /// Event-time watermark; shared across clones via its inner `Arc`.
    watermark: SourceWatermark,

    /// Arrow schema of the records, captured once from `T::schema()`.
    schema: SchemaRef,

    /// Optional human-readable source name from `SourceConfig`.
    name: Option<String>,

    /// Count of successfully pushed messages (records and batches).
    sequence: Arc<AtomicU64>,

    /// Write-once name of the event-time column; the first `set` wins.
    event_time_column: OnceLock<String>,

    /// Write-once bound on event-time disorder; the first `set` wins.
    max_out_of_orderness: OnceLock<Duration>,
}
133
/// Cloneable producer handle for pushing records into a stream.
pub struct Source<T: Record> {
    /// Shared state; see the `Clone` impl for how clones relate to it.
    inner: Arc<SourceInner<T>>,
}
138
139impl<T: Record> Source<T> {
140 pub(crate) fn new(config: SourceConfig) -> (Self, Sink<T>) {
142 let channel_config = config.channel;
143 let (producer, consumer) = channel_with_config::<SourceMessage<T>>(&channel_config);
144
145 let schema = T::schema();
146
147 let inner = Arc::new(SourceInner {
148 producer,
149 watermark: SourceWatermark::new(),
150 schema: schema.clone(),
151 name: config.name,
152 sequence: Arc::new(AtomicU64::new(0)),
153 event_time_column: OnceLock::new(),
154 max_out_of_orderness: OnceLock::new(),
155 });
156
157 let source = Self { inner };
158 let sink = Sink::new(consumer, schema);
159
160 (source, sink)
161 }
162
163 pub fn push(&self, record: T) -> Result<(), StreamingError> {
169 if let Some(event_time) = record.event_time() {
170 self.inner.watermark.update(event_time);
171 }
172
173 self.inner
174 .producer
175 .push(SourceMessage::Record(record))
176 .map_err(|_| StreamingError::ChannelFull)?;
177
178 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
179 Ok(())
180 }
181
182 pub fn try_push(&self, record: T) -> Result<(), TryPushError<T>> {
188 if let Some(event_time) = record.event_time() {
189 self.inner.watermark.update(event_time);
190 }
191
192 self.inner
193 .producer
194 .push(SourceMessage::Record(record))
195 .map_err(|msg| match msg {
196 SourceMessage::Record(r) => TryPushError {
197 value: r,
198 error: StreamingError::ChannelFull,
199 },
200 _ => unreachable!(),
201 })?;
202
203 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
204 Ok(())
205 }
206
207 pub fn push_batch(&self, records: &[T]) -> usize
209 where
210 T: Clone,
211 {
212 self.push_batch_drain(records.iter().cloned())
213 }
214
215 pub fn push_batch_drain<I>(&self, records: I) -> usize
218 where
219 I: IntoIterator<Item = T>,
220 {
221 let mut count = 0;
222 for record in records {
223 if self.push(record).is_err() {
224 break;
225 }
226 count += 1;
227 }
228 count
229 }
230
231 pub fn push_arrow(&self, batch: RecordBatch) -> Result<(), StreamingError> {
241 if !self.inner.schema.fields().is_empty() && batch.schema() != self.inner.schema {
243 return Err(StreamingError::SchemaMismatch {
244 expected: self
245 .inner
246 .schema
247 .fields()
248 .iter()
249 .map(|f| f.name().clone())
250 .collect(),
251 actual: batch
252 .schema()
253 .fields()
254 .iter()
255 .map(|f| f.name().clone())
256 .collect(),
257 });
258 }
259
260 self.inner
261 .producer
262 .push(SourceMessage::Batch(batch))
263 .map_err(|_| StreamingError::ChannelFull)?;
264
265 self.inner.sequence.fetch_add(1, Ordering::Relaxed);
266 Ok(())
267 }
268
269 pub fn watermark(&self, timestamp: i64) {
278 self.inner.watermark.update(timestamp);
279
280 let _ = self
283 .inner
284 .producer
285 .try_push(SourceMessage::Watermark(timestamp));
286 }
287
288 #[must_use]
290 pub fn current_watermark(&self) -> i64 {
291 self.inner.watermark.get()
292 }
293
294 #[must_use]
296 pub fn schema(&self) -> SchemaRef {
297 Arc::clone(&self.inner.schema)
298 }
299
300 #[must_use]
302 pub fn name(&self) -> Option<&str> {
303 self.inner.name.as_deref()
304 }
305
306 #[must_use]
308 pub fn is_closed(&self) -> bool {
309 self.inner.producer.is_closed()
310 }
311
312 #[must_use]
314 pub fn pending(&self) -> usize {
315 self.inner.producer.len()
316 }
317
318 #[must_use]
320 pub fn capacity(&self) -> usize {
321 self.inner.producer.capacity()
322 }
323
324 #[must_use]
326 pub fn sequence(&self) -> u64 {
327 self.inner.sequence.load(Ordering::Acquire)
328 }
329
330 #[must_use]
332 pub fn sequence_counter(&self) -> Arc<AtomicU64> {
333 Arc::clone(&self.inner.sequence)
334 }
335
336 #[must_use]
338 pub fn watermark_atomic(&self) -> Arc<AtomicI64> {
339 self.inner.watermark.arc()
340 }
341
342 pub fn set_event_time_column(&self, column: &str) {
349 let _ = self.inner.event_time_column.set(column.to_owned());
350 }
351
352 #[must_use]
354 pub fn event_time_column(&self) -> Option<String> {
355 self.inner.event_time_column.get().cloned()
356 }
357
358 pub fn set_max_out_of_orderness(&self, dur: Duration) {
362 let _ = self.inner.max_out_of_orderness.set(dur);
363 }
364
365 #[must_use]
367 pub fn max_out_of_orderness(&self) -> Option<Duration> {
368 self.inner.max_out_of_orderness.get().copied()
369 }
370}
371
372impl<T: Record> Clone for Source<T> {
373 fn clone(&self) -> Self {
374 let producer = self.inner.producer.clone();
375 let event_time_col = self.inner.event_time_column.get().cloned();
376 let event_time_column = OnceLock::new();
377 if let Some(col) = event_time_col {
378 let _ = event_time_column.set(col);
379 }
380 let max_ooo = self.inner.max_out_of_orderness.get().copied();
381 let max_out_of_orderness = OnceLock::new();
382 if let Some(dur) = max_ooo {
383 let _ = max_out_of_orderness.set(dur);
384 }
385 Self {
386 inner: Arc::new(SourceInner {
387 producer,
388 watermark: SourceWatermark::from_arc(self.inner.watermark.arc()),
389 schema: Arc::clone(&self.inner.schema),
390 name: self.inner.name.clone(),
391 sequence: Arc::clone(&self.inner.sequence),
392 event_time_column,
393 max_out_of_orderness,
394 }),
395 }
396 }
397}
398
399impl<T: Record + std::fmt::Debug> std::fmt::Debug for Source<T> {
400 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401 f.debug_struct("Source")
402 .field("name", &self.inner.name)
403 .field("pending", &self.pending())
404 .field("capacity", &self.capacity())
405 .field("watermark", &self.current_watermark())
406 .finish()
407 }
408}
409
410#[must_use]
412pub fn create<T: Record>(buffer_size: usize) -> (Source<T>, Sink<T>) {
413 Source::new(SourceConfig::with_buffer_size(buffer_size))
414}
415
416#[must_use]
418pub fn create_with_config<T: Record>(config: SourceConfig) -> (Source<T>, Sink<T>) {
419 Source::new(config)
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Minimal record type used to exercise the `Source` API.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
        timestamp: i64,
    }

    /// Shorthand fixture constructor.
    fn event(id: i64, value: f64, timestamp: i64) -> TestEvent {
        TestEvent {
            id,
            value,
            timestamp,
        }
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
                Field::new("timestamp", DataType::Int64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            let columns: Vec<arrow::array::ArrayRef> = vec![
                Arc::new(Int64Array::from(vec![self.id])),
                Arc::new(Float64Array::from(vec![self.value])),
                Arc::new(Int64Array::from(vec![self.timestamp])),
            ];
            RecordBatch::try_new(Self::schema(), columns).unwrap()
        }

        fn event_time(&self) -> Option<i64> {
            Some(self.timestamp)
        }
    }

    #[tokio::test]
    async fn test_create_source_sink() {
        let (source, _sink) = create::<TestEvent>(1024);

        assert!(!source.is_closed());
        assert_eq!(source.pending(), 0);
    }

    #[tokio::test]
    async fn test_push_single() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.push(event(1, 42.0, 1000)).is_ok());
        assert_eq!(source.pending(), 1);
    }

    #[tokio::test]
    async fn test_try_push() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.try_push(event(1, 42.0, 1000)).is_ok());
    }

    #[tokio::test]
    async fn test_push_batch() {
        let (source, _sink) = create::<TestEvent>(16);

        let events = vec![
            event(1, 1.0, 1000),
            event(2, 2.0, 2000),
            event(3, 3.0, 3000),
        ];

        assert_eq!(source.push_batch(&events), 3);
        assert_eq!(source.pending(), 3);
    }

    #[tokio::test]
    async fn test_push_arrow() {
        let (source, _sink) = create::<TestEvent>(16);

        let batch = RecordBatch::try_new(
            TestEvent::schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
                Arc::new(Int64Array::from(vec![1000, 2000, 3000])),
            ],
        )
        .unwrap();

        assert!(source.push_arrow(batch).is_ok());
    }

    #[tokio::test]
    async fn test_push_arrow_schema_mismatch() {
        let (source, _sink) = create::<TestEvent>(16);

        // A schema that shares no fields with TestEvent's.
        let wrong_schema = Arc::new(Schema::new(vec![Field::new(
            "wrong",
            DataType::Utf8,
            false,
        )]));
        let batch =
            RecordBatch::try_new(wrong_schema, vec![Arc::new(StringArray::from(vec!["test"]))])
                .unwrap();

        let result = source.push_arrow(batch);
        assert!(matches!(result, Err(StreamingError::SchemaMismatch { .. })));
    }

    #[tokio::test]
    async fn test_watermark() {
        let (source, _sink) = create::<TestEvent>(16);

        assert_eq!(source.current_watermark(), i64::MIN);

        source.watermark(1000);
        assert_eq!(source.current_watermark(), 1000);

        source.watermark(2000);
        assert_eq!(source.current_watermark(), 2000);

        // A regressing watermark must be ignored.
        source.watermark(1500);
        assert_eq!(source.current_watermark(), 2000);
    }

    #[tokio::test]
    async fn test_watermark_from_event_time() {
        let (source, _sink) = create::<TestEvent>(16);

        source.push(event(1, 42.0, 5000)).unwrap();

        assert_eq!(source.current_watermark(), 5000);
    }

    #[tokio::test]
    async fn test_clone_multi_producer() {
        let (source, sink) = create::<TestEvent>(16);
        let source2 = source.clone();
        let mut sub = sink.subscribe();

        source.push(event(1, 1.0, 1000)).unwrap();
        source2.push(event(2, 2.0, 2000)).unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert!(sub.poll().is_some());
        assert!(sub.poll().is_some());
    }

    #[tokio::test]
    async fn test_schema() {
        let (source, _sink) = create::<TestEvent>(16);

        let schema = source.schema();
        let names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
        assert_eq!(names, ["id", "value", "timestamp"]);
    }

    #[tokio::test]
    async fn test_named_source() {
        let (source, _sink) = create_with_config::<TestEvent>(SourceConfig::named("my_source"));

        assert_eq!(source.name(), Some("my_source"));
    }

    #[tokio::test]
    async fn test_debug_format() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(format!("{source:?}").contains("Source"));
    }

    #[tokio::test]
    async fn test_set_event_time_column() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.event_time_column().is_none());

        source.set_event_time_column("timestamp");
        assert_eq!(source.event_time_column(), Some("timestamp".to_string()));
    }

    #[tokio::test]
    async fn test_event_time_column_preserved_on_clone() {
        let (source, _sink) = create::<TestEvent>(16);
        source.set_event_time_column("ts");

        assert_eq!(source.clone().event_time_column(), Some("ts".to_string()));
    }
}