Skip to main content

laminar_core/streaming/
source.rs

1//! Source — entry point for data into a streaming pipeline.
2
3use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
4use std::sync::{Arc, OnceLock};
5use std::time::Duration;
6
7use arrow::array::RecordBatch;
8use arrow::datatypes::SchemaRef;
9
10use super::channel::{channel_with_config, Producer};
11use super::config::SourceConfig;
12use super::error::{StreamingError, TryPushError};
13use super::sink::Sink;
14
15/// Trait for types that can be streamed through a Source.
16pub trait Record: Clone + Send + Sized + 'static {
17    /// Returns the Arrow schema for this record type.
18    fn schema() -> SchemaRef;
19
20    /// Converts this record to an Arrow `RecordBatch`.
21    ///
22    /// The batch will contain a single row with this record's data.
23    fn to_record_batch(&self) -> RecordBatch;
24
25    /// Returns the event time for this record, if applicable.
26    ///
27    /// Event time is used for watermark generation and window assignment.
28    /// Returns `None` if the record doesn't have an event time.
29    fn event_time(&self) -> Option<i64> {
30        None
31    }
32
33    /// Converts a batch of records to an Arrow `RecordBatch`.
34    ///
35    /// The default implementation converts each record individually and concatenates them.
36    /// Derived implementations can override this to optimize allocation and copying.
37    fn to_record_batch_from_iter<I>(records: I) -> RecordBatch
38    where
39        I: IntoIterator<Item = Self>,
40    {
41        let batches: Vec<RecordBatch> = records.into_iter().map(|r| r.to_record_batch()).collect();
42        if batches.is_empty() {
43            return RecordBatch::new_empty(Self::schema());
44        }
45        arrow::compute::concat_batches(&Self::schema(), &batches)
46            .unwrap_or_else(|_| RecordBatch::new_empty(Self::schema()))
47    }
48}
49
50/// Internal message type that wraps records and control signals.
51#[derive(Clone)]
52pub(crate) enum SourceMessage<T> {
53    /// A data record.
54    Record(T),
55
56    /// A batch of Arrow records.
57    Batch(RecordBatch),
58}
59
/// Monotonic watermark cell that can be shared between several handles.
struct SourceWatermark {
    /// Latest watermark value.
    ///
    /// Updated atomically so several producers can advance it concurrently,
    /// and kept behind an `Arc` so other components (e.g. the checkpoint
    /// manager) can read it without taking a lock.
    current: Arc<AtomicI64>,
}

impl SourceWatermark {
    /// Creates a fresh watermark initialised to `i64::MIN` ("nothing seen yet").
    fn new() -> Self {
        Self {
            current: Arc::new(AtomicI64::new(i64::MIN)),
        }
    }

    /// Wraps an existing atomic so this handle shares state with its other owners.
    fn from_arc(arc: Arc<AtomicI64>) -> Self {
        Self { current: arc }
    }

    /// Raises the watermark to `timestamp` if that is strictly greater than
    /// the current value; otherwise leaves it untouched.
    fn update(&self, timestamp: i64) {
        // `fetch_update` runs the load / compare-exchange retry loop for us.
        // When the closure declines (returns `None`) no store is performed,
        // so a stale or equal timestamp can never move the watermark back.
        let _ = self
            .current
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |seen| {
                (timestamp > seen).then_some(timestamp)
            });
    }

    /// Returns the current watermark value.
    fn get(&self) -> i64 {
        self.current.load(Ordering::Acquire)
    }

    /// Returns another handle to the underlying atomic.
    fn arc(&self) -> Arc<AtomicI64> {
        Arc::clone(&self.current)
    }
}
103
104/// Shared state for a Source/Sink pair.
105struct SourceInner<T: Record> {
106    /// Channel producer for sending records.
107    producer: Producer<SourceMessage<T>>,
108
109    /// Watermark state.
110    watermark: SourceWatermark,
111
112    /// Schema for type validation.
113    schema: SchemaRef,
114
115    /// Source name (for debugging/metrics).
116    name: Option<String>,
117
118    /// Monotonic sequence counter, incremented on each successful push.
119    /// Wrapped in `Arc` so the checkpoint manager can read it without locking.
120    sequence: Arc<AtomicU64>,
121
122    /// Event-time column name set via programmatic API.
123    /// Read once at pipeline startup, not on the hot path.
124    event_time_column: OnceLock<String>,
125
126    /// Max out-of-orderness bound, paired with `event_time_column`.
127    /// Read once at pipeline startup, not on the hot path.
128    max_out_of_orderness: OnceLock<Duration>,
129}
130
131/// A streaming data source. Cloneable for multi-producer use.
132pub struct Source<T: Record> {
133    inner: Arc<SourceInner<T>>,
134}
135
136impl<T: Record> Source<T> {
137    /// Creates a new Source/Sink pair.
138    pub(crate) fn new(config: SourceConfig) -> (Self, Sink<T>) {
139        let channel_config = config.channel;
140        let (producer, consumer) = channel_with_config::<SourceMessage<T>>(&channel_config);
141
142        let schema = T::schema();
143
144        let inner = Arc::new(SourceInner {
145            producer,
146            watermark: SourceWatermark::new(),
147            schema: schema.clone(),
148            name: config.name,
149            sequence: Arc::new(AtomicU64::new(0)),
150            event_time_column: OnceLock::new(),
151            max_out_of_orderness: OnceLock::new(),
152        });
153
154        let source = Self { inner };
155        let sink = Sink::new(consumer, schema);
156
157        (source, sink)
158    }
159
160    /// Pushes a record. Non-blocking — returns `ChannelFull` if the buffer is full.
161    ///
162    /// # Errors
163    ///
164    /// Returns `StreamingError::ChannelFull` if the buffer is full or the sink was dropped.
165    pub fn push(&self, record: T) -> Result<(), StreamingError> {
166        if let Some(event_time) = record.event_time() {
167            self.inner.watermark.update(event_time);
168        }
169
170        self.inner
171            .producer
172            .push(SourceMessage::Record(record))
173            .map_err(|_| StreamingError::ChannelFull)?;
174
175        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
176        Ok(())
177    }
178
179    /// Pushes a record, returning it on failure.
180    ///
181    /// # Errors
182    ///
183    /// Returns `TryPushError` containing the record if the channel is full.
184    pub fn try_push(&self, record: T) -> Result<(), TryPushError<T>> {
185        if let Some(event_time) = record.event_time() {
186            self.inner.watermark.update(event_time);
187        }
188
189        self.inner
190            .producer
191            .push(SourceMessage::Record(record))
192            .map_err(|msg| match msg {
193                SourceMessage::Record(r) => TryPushError {
194                    value: r,
195                    error: StreamingError::ChannelFull,
196                },
197                SourceMessage::Batch(_) => unreachable!("only Record is pushed here"),
198            })?;
199
200        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
201        Ok(())
202    }
203
204    /// Pushes multiple records (cloned). Stops at the first failure.
205    pub fn push_batch(&self, records: &[T]) -> usize
206    where
207        T: Clone,
208    {
209        self.push_batch_drain(records.iter().cloned())
210    }
211
212    /// Pushes records from an iterator, consuming them (zero-clone).
213    /// Stops at the first failure. Returns the number pushed.
214    pub fn push_batch_drain<I>(&self, records: I) -> usize
215    where
216        I: IntoIterator<Item = T>,
217    {
218        let mut count = 0;
219        for record in records {
220            if self.push(record).is_err() {
221                break;
222            }
223            count += 1;
224        }
225        count
226    }
227
228    /// Pushes an Arrow `RecordBatch` directly.
229    ///
230    /// This is more efficient than pushing individual records when you
231    /// already have data in Arrow format.
232    ///
233    /// # Errors
234    ///
235    /// Returns `StreamingError::SchemaMismatch` if the batch schema doesn't match.
236    /// Returns `StreamingError::ChannelClosed` if the sink has been dropped.
237    pub fn push_arrow(&self, batch: RecordBatch) -> Result<(), StreamingError> {
238        // Validate schema matches (skip for type-erased sources with empty schema)
239        if !self.inner.schema.fields().is_empty() && batch.schema() != self.inner.schema {
240            return Err(StreamingError::SchemaMismatch {
241                expected: self
242                    .inner
243                    .schema
244                    .fields()
245                    .iter()
246                    .map(|f| f.name().clone())
247                    .collect(),
248                actual: batch
249                    .schema()
250                    .fields()
251                    .iter()
252                    .map(|f| f.name().clone())
253                    .collect(),
254            });
255        }
256
257        self.inner
258            .producer
259            .push(SourceMessage::Batch(batch))
260            .map_err(|_| StreamingError::ChannelFull)?;
261
262        self.inner.sequence.fetch_add(1, Ordering::Relaxed);
263        Ok(())
264    }
265
266    /// Emits a watermark timestamp.
267    ///
268    /// Watermarks signal that no events with timestamps less than or equal
269    /// to this value will arrive in the future. This enables window triggers
270    /// and garbage collection.
271    ///
272    /// Watermarks are monotonically increasing - if a lower timestamp is
273    /// passed, it will be ignored.
274    pub fn watermark(&self, timestamp: i64) {
275        // The shared atomic is the authoritative watermark: the pipeline's
276        // watermark UDF, late-row filter, and checkpoint registration all
277        // read it via `watermark_atomic()`. Subscribers receive data only,
278        // so there is no in-band watermark message to emit.
279        self.inner.watermark.update(timestamp);
280    }
281
282    /// Returns the current watermark value.
283    #[must_use]
284    pub fn current_watermark(&self) -> i64 {
285        self.inner.watermark.get()
286    }
287
288    /// Returns the schema for this source.
289    #[must_use]
290    pub fn schema(&self) -> SchemaRef {
291        Arc::clone(&self.inner.schema)
292    }
293
294    /// Returns the source name, if configured.
295    #[must_use]
296    pub fn name(&self) -> Option<&str> {
297        self.inner.name.as_deref()
298    }
299
300    /// Returns true if the sink has been dropped.
301    #[must_use]
302    pub fn is_closed(&self) -> bool {
303        self.inner.producer.is_closed()
304    }
305
306    /// Returns the number of pending items in the buffer.
307    #[must_use]
308    pub fn pending(&self) -> usize {
309        self.inner.producer.len()
310    }
311
312    /// Returns the buffer capacity.
313    #[must_use]
314    pub fn capacity(&self) -> usize {
315        self.inner.producer.capacity()
316    }
317
318    /// Returns the current sequence number (total successful pushes).
319    #[must_use]
320    pub fn sequence(&self) -> u64 {
321        self.inner.sequence.load(Ordering::Acquire)
322    }
323
324    /// Returns the shared sequence counter for checkpoint registration.
325    #[must_use]
326    pub fn sequence_counter(&self) -> Arc<AtomicU64> {
327        Arc::clone(&self.inner.sequence)
328    }
329
330    /// Returns the shared watermark atomic for checkpoint registration.
331    #[must_use]
332    pub fn watermark_atomic(&self) -> Arc<AtomicI64> {
333        self.inner.watermark.arc()
334    }
335
336    /// Declare which column in the source data represents event time.
337    ///
338    /// When set, `source.watermark()` enables late-row filtering
339    /// without a SQL `WATERMARK FOR` clause.
340    ///
341    /// Only the first call takes effect; subsequent calls are silently ignored.
342    pub fn set_event_time_column(&self, column: &str) {
343        let _ = self.inner.event_time_column.set(column.to_owned());
344    }
345
346    /// Returns the configured event-time column, if any.
347    #[must_use]
348    pub fn event_time_column(&self) -> Option<String> {
349        self.inner.event_time_column.get().cloned()
350    }
351
352    /// Set the max out-of-orderness bound for watermark generation.
353    ///
354    /// Only the first call takes effect; subsequent calls are silently ignored.
355    pub fn set_max_out_of_orderness(&self, dur: Duration) {
356        let _ = self.inner.max_out_of_orderness.set(dur);
357    }
358
359    /// Returns the configured max out-of-orderness, if any.
360    #[must_use]
361    pub fn max_out_of_orderness(&self) -> Option<Duration> {
362        self.inner.max_out_of_orderness.get().copied()
363    }
364}
365
366impl<T: Record> Clone for Source<T> {
367    fn clone(&self) -> Self {
368        let producer = self.inner.producer.clone();
369        let event_time_col = self.inner.event_time_column.get().cloned();
370        let event_time_column = OnceLock::new();
371        if let Some(col) = event_time_col {
372            let _ = event_time_column.set(col);
373        }
374        let max_ooo = self.inner.max_out_of_orderness.get().copied();
375        let max_out_of_orderness = OnceLock::new();
376        if let Some(dur) = max_ooo {
377            let _ = max_out_of_orderness.set(dur);
378        }
379        Self {
380            inner: Arc::new(SourceInner {
381                producer,
382                watermark: SourceWatermark::from_arc(self.inner.watermark.arc()),
383                schema: Arc::clone(&self.inner.schema),
384                name: self.inner.name.clone(),
385                sequence: Arc::clone(&self.inner.sequence),
386                event_time_column,
387                max_out_of_orderness,
388            }),
389        }
390    }
391}
392
393impl<T: Record + std::fmt::Debug> std::fmt::Debug for Source<T> {
394    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395        f.debug_struct("Source")
396            .field("name", &self.inner.name)
397            .field("pending", &self.pending())
398            .field("capacity", &self.capacity())
399            .field("watermark", &self.current_watermark())
400            .finish()
401    }
402}
403
404/// Creates a new Source/Sink pair with the given buffer size.
405#[must_use]
406pub fn create<T: Record>(buffer_size: usize) -> (Source<T>, Sink<T>) {
407    Source::new(SourceConfig::with_buffer_size(buffer_size))
408}
409
410/// Creates a new Source/Sink pair with custom configuration.
411#[must_use]
412pub fn create_with_config<T: Record>(config: SourceConfig) -> (Source<T>, Sink<T>) {
413    Source::new(config)
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Minimal record type used to exercise the source API.
    #[derive(Clone, Debug)]
    struct TestEvent {
        id: i64,
        value: f64,
        timestamp: i64,
    }

    /// Shorthand constructor so the tests below stay compact.
    fn event(id: i64, value: f64, timestamp: i64) -> TestEvent {
        TestEvent {
            id,
            value,
            timestamp,
        }
    }

    impl Record for TestEvent {
        fn schema() -> SchemaRef {
            let fields = vec![
                Field::new("id", DataType::Int64, false),
                Field::new("value", DataType::Float64, false),
                Field::new("timestamp", DataType::Int64, false),
            ];
            Arc::new(Schema::new(fields))
        }

        fn to_record_batch(&self) -> RecordBatch {
            RecordBatch::try_new(
                Self::schema(),
                vec![
                    Arc::new(Int64Array::from(vec![self.id])),
                    Arc::new(Float64Array::from(vec![self.value])),
                    Arc::new(Int64Array::from(vec![self.timestamp])),
                ],
            )
            .unwrap()
        }

        fn event_time(&self) -> Option<i64> {
            Some(self.timestamp)
        }
    }

    // A freshly created source is open and has nothing buffered.
    #[tokio::test]
    async fn test_create_source_sink() {
        let (source, _sink) = create::<TestEvent>(1024);

        assert!(!source.is_closed());
        assert_eq!(source.pending(), 0);
    }

    // One successful push leaves exactly one pending item.
    #[tokio::test]
    async fn test_push_single() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.push(event(1, 42.0, 1000)).is_ok());
        assert_eq!(source.pending(), 1);
    }

    // `try_push` succeeds while there is room in the buffer.
    #[tokio::test]
    async fn test_try_push() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.try_push(event(1, 42.0, 1000)).is_ok());
    }

    // `push_batch` reports how many records were enqueued.
    #[tokio::test]
    async fn test_push_batch() {
        let (source, _sink) = create::<TestEvent>(16);

        let events = vec![
            event(1, 1.0, 1000),
            event(2, 2.0, 2000),
            event(3, 3.0, 3000),
        ];

        let count = source.push_batch(&events);
        assert_eq!(count, 3);
        assert_eq!(source.pending(), 3);
    }

    // A batch with the source's own schema is accepted.
    #[tokio::test]
    async fn test_push_arrow() {
        let (source, _sink) = create::<TestEvent>(16);

        let batch = RecordBatch::try_new(
            TestEvent::schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
                Arc::new(Int64Array::from(vec![1000, 2000, 3000])),
            ],
        )
        .unwrap();

        assert!(source.push_arrow(batch).is_ok());
    }

    // A batch with a foreign schema is rejected with `SchemaMismatch`.
    #[tokio::test]
    async fn test_push_arrow_schema_mismatch() {
        let (source, _sink) = create::<TestEvent>(16);

        // Create batch with different schema
        let wrong_schema = Arc::new(Schema::new(vec![Field::new(
            "wrong",
            DataType::Utf8,
            false,
        )]));

        let batch = RecordBatch::try_new(
            wrong_schema,
            vec![Arc::new(StringArray::from(vec!["test"]))],
        )
        .unwrap();

        let result = source.push_arrow(batch);
        assert!(matches!(result, Err(StreamingError::SchemaMismatch { .. })));
    }

    // Explicit watermarks advance monotonically.
    #[tokio::test]
    async fn test_watermark() {
        let (source, _sink) = create::<TestEvent>(16);

        assert_eq!(source.current_watermark(), i64::MIN);

        source.watermark(1000);
        assert_eq!(source.current_watermark(), 1000);

        source.watermark(2000);
        assert_eq!(source.current_watermark(), 2000);

        // Watermark should not go backwards
        source.watermark(1500);
        assert_eq!(source.current_watermark(), 2000);
    }

    // Pushing a record advances the watermark to its event time.
    #[tokio::test]
    async fn test_watermark_from_event_time() {
        let (source, _sink) = create::<TestEvent>(16);

        source.push(event(1, 42.0, 5000)).unwrap();

        // Watermark should be updated from event time
        assert_eq!(source.current_watermark(), 5000);
    }

    // Records pushed through a clone reach the same subscriber.
    #[tokio::test]
    async fn test_clone_multi_producer() {
        let (source, sink) = create::<TestEvent>(16);
        let source2 = source.clone();
        let mut sub = sink.subscribe(); // subscribe before push

        source.push(event(1, 1.0, 1000)).unwrap();
        source2.push(event(2, 2.0, 2000)).unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert!(sub.poll().is_some());
        assert!(sub.poll().is_some());
    }

    // The source exposes the record type's schema.
    #[tokio::test]
    async fn test_schema() {
        let (source, _sink) = create::<TestEvent>(16);

        let schema = source.schema();
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "id");
        assert_eq!(schema.field(1).name(), "value");
        assert_eq!(schema.field(2).name(), "timestamp");
    }

    // A configured name is reported back.
    #[tokio::test]
    async fn test_named_source() {
        let config = SourceConfig::named("my_source");
        let (source, _sink) = create_with_config::<TestEvent>(config);

        assert_eq!(source.name(), Some("my_source"));
    }

    // The `Debug` output mentions the type name.
    #[tokio::test]
    async fn test_debug_format() {
        let (source, _sink) = create::<TestEvent>(16);

        let debug = format!("{source:?}");
        assert!(debug.contains("Source"));
    }

    // The event-time column is unset by default and readable once set.
    #[tokio::test]
    async fn test_set_event_time_column() {
        let (source, _sink) = create::<TestEvent>(16);

        assert!(source.event_time_column().is_none());

        source.set_event_time_column("timestamp");
        assert_eq!(source.event_time_column(), Some("timestamp".to_string()));
    }

    // A clone made after the column was set carries the same setting.
    #[tokio::test]
    async fn test_event_time_column_preserved_on_clone() {
        let (source, _sink) = create::<TestEvent>(16);
        source.set_event_time_column("ts");

        let source2 = source.clone();
        assert_eq!(source2.event_time_column(), Some("ts".to_string()));
    }
}