// laminar_core/streaming/source.rs

1//! Source — entry point for data into a streaming pipeline.
2
3use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
4use std::sync::{Arc, OnceLock};
5use std::time::Duration;
6
7use arrow::array::RecordBatch;
8use arrow::datatypes::SchemaRef;
9
10use super::channel::{channel_with_config, Producer};
11use super::config::SourceConfig;
12use super::error::{StreamingError, TryPushError};
13use super::sink::Sink;
14
15/// Trait for types that can be streamed through a Source.
16pub trait Record: Clone + Send + Sized + 'static {
17    /// Returns the Arrow schema for this record type.
18    fn schema() -> SchemaRef;
19
20    /// Converts this record to an Arrow `RecordBatch`.
21    ///
22    /// The batch will contain a single row with this record's data.
23    fn to_record_batch(&self) -> RecordBatch;
24
25    /// Returns the event time for this record, if applicable.
26    ///
27    /// Event time is used for watermark generation and window assignment.
28    /// Returns `None` if the record doesn't have an event time.
29    fn event_time(&self) -> Option<i64> {
30        None
31    }
32
33    /// Converts a batch of records to an Arrow `RecordBatch`.
34    ///
35    /// The default implementation converts each record individually and concatenates them.
36    /// Derived implementations can override this to optimize allocation and copying.
37    fn to_record_batch_from_iter<I>(records: I) -> RecordBatch
38    where
39        I: IntoIterator<Item = Self>,
40    {
41        let batches: Vec<RecordBatch> = records.into_iter().map(|r| r.to_record_batch()).collect();
42        if batches.is_empty() {
43            return RecordBatch::new_empty(Self::schema());
44        }
45        arrow::compute::concat_batches(&Self::schema(), &batches)
46            .unwrap_or_else(|_| RecordBatch::new_empty(Self::schema()))
47    }
48}
49
/// Internal message type that wraps records and control signals.
///
/// Flows through the channel from `Source` (producer) to `Sink` (consumer);
/// `Clone` is required because the channel may fan out to multiple readers.
#[derive(Clone)]
pub(crate) enum SourceMessage<T> {
    /// A data record.
    Record(T),

    /// A batch of Arrow records, pushed via `Source::push_arrow`.
    Batch(RecordBatch),

    /// A watermark timestamp. The payload is currently unread downstream
    /// (the authoritative watermark lives in the shared atomic), hence the
    /// `dead_code` allowance.
    Watermark(#[allow(dead_code)] i64),
}
62
/// Shared state for watermark tracking.
struct SourceWatermark {
    /// Current watermark value.
    /// Atomically updated to support multi-producer scenarios.
    /// Wrapped in `Arc` so the checkpoint manager can read it without locking.
    /// Initialized to `i64::MIN`, meaning "no watermark emitted yet".
    current: Arc<AtomicI64>,
}
70
71impl SourceWatermark {
72    fn new() -> Self {
73        Self {
74            current: Arc::new(AtomicI64::new(i64::MIN)),
75        }
76    }
77
78    fn from_arc(arc: Arc<AtomicI64>) -> Self {
79        Self { current: arc }
80    }
81
82    fn update(&self, timestamp: i64) {
83        // Only advance watermark, never go backwards
84        let mut current = self.current.load(Ordering::Acquire);
85        while timestamp > current {
86            match self.current.compare_exchange_weak(
87                current,
88                timestamp,
89                Ordering::AcqRel,
90                Ordering::Acquire,
91            ) {
92                Ok(_) => break,
93                Err(actual) => current = actual,
94            }
95        }
96    }
97
98    fn get(&self) -> i64 {
99        self.current.load(Ordering::Acquire)
100    }
101
102    fn arc(&self) -> Arc<AtomicI64> {
103        Arc::clone(&self.current)
104    }
105}
106
/// Shared state for a Source/Sink pair.
///
/// Held behind an `Arc` by every `Source` handle; note that `Clone for
/// Source` builds a *new* `SourceInner` (with a cloned producer) rather
/// than sharing this one, so only the `Arc`-wrapped fields below are
/// truly shared across clones.
struct SourceInner<T: Record> {
    /// Channel producer for sending records.
    producer: Producer<SourceMessage<T>>,

    /// Watermark state.
    watermark: SourceWatermark,

    /// Schema for type validation.
    schema: SchemaRef,

    /// Source name (for debugging/metrics).
    name: Option<String>,

    /// Monotonic sequence counter, incremented on each successful push.
    /// Wrapped in `Arc` so the checkpoint manager can read it without locking.
    sequence: Arc<AtomicU64>,

    /// Event-time column name set via programmatic API.
    /// Read once at pipeline startup, not on the hot path.
    event_time_column: OnceLock<String>,

    /// Max out-of-orderness bound, paired with `event_time_column`.
    /// Read once at pipeline startup, not on the hot path.
    max_out_of_orderness: OnceLock<Duration>,
}
133
/// A streaming data source. Cloneable for multi-producer use.
///
/// All clones feed the same underlying channel and share watermark and
/// sequence state.
pub struct Source<T: Record> {
    // Shared (per-handle) state; see `SourceInner` for field details.
    inner: Arc<SourceInner<T>>,
}
138
139impl<T: Record> Source<T> {
140    /// Creates a new Source/Sink pair.
141    pub(crate) fn new(config: SourceConfig) -> (Self, Sink<T>) {
142        let channel_config = config.channel;
143        let (producer, consumer) = channel_with_config::<SourceMessage<T>>(&channel_config);
144
145        let schema = T::schema();
146
147        let inner = Arc::new(SourceInner {
148            producer,
149            watermark: SourceWatermark::new(),
150            schema: schema.clone(),
151            name: config.name,
152            sequence: Arc::new(AtomicU64::new(0)),
153            event_time_column: OnceLock::new(),
154            max_out_of_orderness: OnceLock::new(),
155        });
156
157        let source = Self { inner };
158        let sink = Sink::new(consumer, schema);
159
160        (source, sink)
161    }
162
163    /// Pushes a record. Non-blocking — returns `ChannelFull` if the buffer is full.
164    ///
165    /// # Errors
166    ///
167    /// Returns `StreamingError::ChannelFull` if the buffer is full or the sink was dropped.
168    pub fn push(&self, record: T) -> Result<(), StreamingError> {
169        if let Some(event_time) = record.event_time() {
170            self.inner.watermark.update(event_time);
171        }
172
173        self.inner
174            .producer
175            .push(SourceMessage::Record(record))
176            .map_err(|_| StreamingError::ChannelFull)?;
177
178        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
179        Ok(())
180    }
181
182    /// Pushes a record, returning it on failure.
183    ///
184    /// # Errors
185    ///
186    /// Returns `TryPushError` containing the record if the channel is full.
187    pub fn try_push(&self, record: T) -> Result<(), TryPushError<T>> {
188        if let Some(event_time) = record.event_time() {
189            self.inner.watermark.update(event_time);
190        }
191
192        self.inner
193            .producer
194            .push(SourceMessage::Record(record))
195            .map_err(|msg| match msg {
196                SourceMessage::Record(r) => TryPushError {
197                    value: r,
198                    error: StreamingError::ChannelFull,
199                },
200                _ => unreachable!(),
201            })?;
202
203        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
204        Ok(())
205    }
206
207    /// Pushes multiple records (cloned). Stops at the first failure.
208    pub fn push_batch(&self, records: &[T]) -> usize
209    where
210        T: Clone,
211    {
212        self.push_batch_drain(records.iter().cloned())
213    }
214
215    /// Pushes records from an iterator, consuming them (zero-clone).
216    /// Stops at the first failure. Returns the number pushed.
217    pub fn push_batch_drain<I>(&self, records: I) -> usize
218    where
219        I: IntoIterator<Item = T>,
220    {
221        let mut count = 0;
222        for record in records {
223            if self.push(record).is_err() {
224                break;
225            }
226            count += 1;
227        }
228        count
229    }
230
231    /// Pushes an Arrow `RecordBatch` directly.
232    ///
233    /// This is more efficient than pushing individual records when you
234    /// already have data in Arrow format.
235    ///
236    /// # Errors
237    ///
238    /// Returns `StreamingError::SchemaMismatch` if the batch schema doesn't match.
239    /// Returns `StreamingError::ChannelClosed` if the sink has been dropped.
240    pub fn push_arrow(&self, batch: RecordBatch) -> Result<(), StreamingError> {
241        // Validate schema matches (skip for type-erased sources with empty schema)
242        if !self.inner.schema.fields().is_empty() && batch.schema() != self.inner.schema {
243            return Err(StreamingError::SchemaMismatch {
244                expected: self
245                    .inner
246                    .schema
247                    .fields()
248                    .iter()
249                    .map(|f| f.name().clone())
250                    .collect(),
251                actual: batch
252                    .schema()
253                    .fields()
254                    .iter()
255                    .map(|f| f.name().clone())
256                    .collect(),
257            });
258        }
259
260        self.inner
261            .producer
262            .push(SourceMessage::Batch(batch))
263            .map_err(|_| StreamingError::ChannelFull)?;
264
265        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
266        Ok(())
267    }
268
269    /// Emits a watermark timestamp.
270    ///
271    /// Watermarks signal that no events with timestamps less than or equal
272    /// to this value will arrive in the future. This enables window triggers
273    /// and garbage collection.
274    ///
275    /// Watermarks are monotonically increasing - if a lower timestamp is
276    /// passed, it will be ignored.
277    pub fn watermark(&self, timestamp: i64) {
278        self.inner.watermark.update(timestamp);
279
280        // Best-effort send of watermark message
281        // It's okay if this fails - the atomic watermark state is updated
282        let _ = self
283            .inner
284            .producer
285            .try_push(SourceMessage::Watermark(timestamp));
286    }
287
288    /// Returns the current watermark value.
289    #[must_use]
290    pub fn current_watermark(&self) -> i64 {
291        self.inner.watermark.get()
292    }
293
294    /// Returns the schema for this source.
295    #[must_use]
296    pub fn schema(&self) -> SchemaRef {
297        Arc::clone(&self.inner.schema)
298    }
299
300    /// Returns the source name, if configured.
301    #[must_use]
302    pub fn name(&self) -> Option<&str> {
303        self.inner.name.as_deref()
304    }
305
306    /// Returns true if the sink has been dropped.
307    #[must_use]
308    pub fn is_closed(&self) -> bool {
309        self.inner.producer.is_closed()
310    }
311
312    /// Returns the number of pending items in the buffer.
313    #[must_use]
314    pub fn pending(&self) -> usize {
315        self.inner.producer.len()
316    }
317
318    /// Returns the buffer capacity.
319    #[must_use]
320    pub fn capacity(&self) -> usize {
321        self.inner.producer.capacity()
322    }
323
324    /// Returns the current sequence number (total successful pushes).
325    #[must_use]
326    pub fn sequence(&self) -> u64 {
327        self.inner.sequence.load(Ordering::Acquire)
328    }
329
330    /// Returns the shared sequence counter for checkpoint registration.
331    #[must_use]
332    pub fn sequence_counter(&self) -> Arc<AtomicU64> {
333        Arc::clone(&self.inner.sequence)
334    }
335
336    /// Returns the shared watermark atomic for checkpoint registration.
337    #[must_use]
338    pub fn watermark_atomic(&self) -> Arc<AtomicI64> {
339        self.inner.watermark.arc()
340    }
341
342    /// Declare which column in the source data represents event time.
343    ///
344    /// When set, `source.watermark()` enables late-row filtering
345    /// without a SQL `WATERMARK FOR` clause.
346    ///
347    /// Only the first call takes effect; subsequent calls are silently ignored.
348    pub fn set_event_time_column(&self, column: &str) {
349        let _ = self.inner.event_time_column.set(column.to_owned());
350    }
351
352    /// Returns the configured event-time column, if any.
353    #[must_use]
354    pub fn event_time_column(&self) -> Option<String> {
355        self.inner.event_time_column.get().cloned()
356    }
357
358    /// Set the max out-of-orderness bound for watermark generation.
359    ///
360    /// Only the first call takes effect; subsequent calls are silently ignored.
361    pub fn set_max_out_of_orderness(&self, dur: Duration) {
362        let _ = self.inner.max_out_of_orderness.set(dur);
363    }
364
365    /// Returns the configured max out-of-orderness, if any.
366    #[must_use]
367    pub fn max_out_of_orderness(&self) -> Option<Duration> {
368        self.inner.max_out_of_orderness.get().copied()
369    }
370}
371
372impl<T: Record> Clone for Source<T> {
373    fn clone(&self) -> Self {
374        let producer = self.inner.producer.clone();
375        let event_time_col = self.inner.event_time_column.get().cloned();
376        let event_time_column = OnceLock::new();
377        if let Some(col) = event_time_col {
378            let _ = event_time_column.set(col);
379        }
380        let max_ooo = self.inner.max_out_of_orderness.get().copied();
381        let max_out_of_orderness = OnceLock::new();
382        if let Some(dur) = max_ooo {
383            let _ = max_out_of_orderness.set(dur);
384        }
385        Self {
386            inner: Arc::new(SourceInner {
387                producer,
388                watermark: SourceWatermark::from_arc(self.inner.watermark.arc()),
389                schema: Arc::clone(&self.inner.schema),
390                name: self.inner.name.clone(),
391                sequence: Arc::clone(&self.inner.sequence),
392                event_time_column,
393                max_out_of_orderness,
394            }),
395        }
396    }
397}
398
399impl<T: Record + std::fmt::Debug> std::fmt::Debug for Source<T> {
400    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401        f.debug_struct("Source")
402            .field("name", &self.inner.name)
403            .field("pending", &self.pending())
404            .field("capacity", &self.capacity())
405            .field("watermark", &self.current_watermark())
406            .finish()
407    }
408}
409
410/// Creates a new Source/Sink pair with the given buffer size.
411#[must_use]
412pub fn create<T: Record>(buffer_size: usize) -> (Source<T>, Sink<T>) {
413    Source::new(SourceConfig::with_buffer_size(buffer_size))
414}
415
/// Creates a new Source/Sink pair with custom configuration.
///
/// Use this over [`create`] when you need to set a name or channel
/// options beyond the buffer size.
#[must_use]
pub fn create_with_config<T: Record>(config: SourceConfig) -> (Source<T>, Sink<T>) {
    Source::new(config)
}
421
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    // Test record type: three-column (id, value, timestamp) event whose
    // `event_time` is its `timestamp` field.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
        timestamp: i64,
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            Arc::new(Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
                Field::new("timestamp", DataType::Int64, false),
            ]))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                    Arc::new(Int64Array::from(vec![self.timestamp])),
                ],
            )
            .unwrap()
        }

        fn event_time(&self) -> Option<i64> {
            Some(self.timestamp)
        }
    }

    #[tokio::test]
    async fn test_create_source_sink() {
        let (source, _sink) = create::<TestEvent>(1024);

        assert!(!source.is_closed());
        assert_eq!(source.pending(), 0);
    }

    #[tokio::test]
    async fn test_push_single() {
        let (source, _sink) = create::<TestEvent>(16);

        let event = TestEvent {
            id: 1,
            value: 42.0,
            timestamp: 1000,
        };

        assert!(source.push(event).is_ok());
        assert_eq!(source.pending(), 1);
    }

    #[tokio::test]
    async fn test_try_push() {
        let (source, _sink) = create::<TestEvent>(16);

        let event = TestEvent {
            id: 1,
            value: 42.0,
            timestamp: 1000,
        };

        assert!(source.try_push(event).is_ok());
    }

    #[tokio::test]
    async fn test_push_batch() {
        let (source, _sink) = create::<TestEvent>(16);

        let events = vec![
            TestEvent {
                id: 1,
                value: 1.0,
                timestamp: 1000,
            },
            TestEvent {
                id: 2,
                value: 2.0,
                timestamp: 2000,
            },
            TestEvent {
                id: 3,
                value: 3.0,
                timestamp: 3000,
            },
        ];

        let count = source.push_batch(&events);
        assert_eq!(count, 3);
        assert_eq!(source.pending(), 3);
    }

    #[tokio::test]
    async fn test_push_arrow() {
        let (source, _sink) = create::<TestEvent>(16);

        // Batch built with the exact schema the source expects.
        let batch = RecordBatch::try_new(
            TestEvent::schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
                Arc::new(Int64Array::from(vec![1000, 2000, 3000])),
            ],
        )
        .unwrap();

        assert!(source.push_arrow(batch).is_ok());
    }

    #[tokio::test]
    async fn test_push_arrow_schema_mismatch() {
        let (source, _sink) = create::<TestEvent>(16);

        // Create batch with different schema
        let wrong_schema = Arc::new(Schema::new(vec![Field::new(
            "wrong",
            DataType::Utf8,
            false,
        )]));

        let batch = RecordBatch::try_new(
            wrong_schema,
            vec![Arc::new(StringArray::from(vec!["test"]))],
        )
        .unwrap();

        let result = source.push_arrow(batch);
        assert!(matches!(result, Err(StreamingError::SchemaMismatch { .. })));
    }

    #[tokio::test]
    async fn test_watermark() {
        let (source, _sink) = create::<TestEvent>(16);

        // Fresh source has no watermark yet (sentinel i64::MIN).
        assert_eq!(source.current_watermark(), i64::MIN);

        source.watermark(1000);
        assert_eq!(source.current_watermark(), 1000);

        source.watermark(2000);
        assert_eq!(source.current_watermark(), 2000);

        // Watermark should not go backwards
        source.watermark(1500);
        assert_eq!(source.current_watermark(), 2000);
    }

    #[tokio::test]
    async fn test_watermark_from_event_time() {
        let (source, _sink) = create::<TestEvent>(16);

        let event = TestEvent {
            id: 1,
            value: 42.0,
            timestamp: 5000,
        };

        source.push(event).unwrap();

        // Watermark should be updated from event time
        assert_eq!(source.current_watermark(), 5000);
    }

    #[tokio::test]
    async fn test_clone_multi_producer() {
        let (source, sink) = create::<TestEvent>(16);
        let source2 = source.clone();
        let mut sub = sink.subscribe(); // subscribe before push

        source
            .push(TestEvent {
                id: 1,
                value: 1.0,
                timestamp: 1000,
            })
            .unwrap();
        source2
            .push(TestEvent {
                id: 2,
                value: 2.0,
                timestamp: 2000,
            })
            .unwrap();

        // Brief sleep gives the sink's delivery path time to run before
        // polling; both records should then be visible to the subscriber.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert!(sub.poll().is_some());
        assert!(sub.poll().is_some());
    }

    #[tokio::test]
    async fn test_schema() {
        let (source, _sink) = create::<TestEvent>(16);

        let schema = source.schema();
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "value");
        assert_eq!(schema.field(2).name(), "timestamp");
    }

    #[tokio::test]
    async fn test_named_source() {
        let config = SourceConfig::named("my_source");
        let (source, _sink) = create_with_config::<TestEvent>(config);

        assert_eq!(source.name(), Some("my_source"));
    }

    #[tokio::test]
    async fn test_debug_format() {
        let (source, _sink) = create::<TestEvent>(16);

        let debug = format!("{source:?}");
        assert!(debug.contains("Source"));
    }

    #[tokio::test]
    async fn test_set_event_time_column() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.event_time_column().is_none());

        // First set wins; OnceLock semantics.
        source.set_event_time_column("timestamp");
        assert_eq!(source.event_time_column(), Some("timestamp".to_string()));
    }

    #[tokio::test]
    async fn test_event_time_column_preserved_on_clone() {
        let (source, _sink) = create::<TestEvent>(16);
        source.set_event_time_column("ts");

        // Clone snapshots settings made *before* the clone.
        let source2 = source.clone();
        assert_eq!(source2.event_time_column(), Some("ts".to_string()));
    }
}