Skip to main content

laminar_sql/datafusion/
live_source.rs

1//! Swappable table provider that eliminates per-cycle catalog churn and
2//! enables physical plan caching.
3//!
4//! Register a [`LiveSourceProvider`] once at pipeline startup. Each cycle,
5//! swap batches via [`LiveSourceHandle`], then execute the cached physical
6//! plan. The internal `LiveSourceExec` reads from the shared slot at `execute()` time,
7//! so the cached plan always sees fresh data.
8
9use std::any::Any;
10use std::sync::Arc;
11
12use arrow::array::RecordBatch;
13use arrow::datatypes::SchemaRef;
14use async_trait::async_trait;
15use datafusion::catalog::Session;
16use datafusion::datasource::TableProvider;
17use datafusion::error::DataFusionError;
18use datafusion::execution::{SendableRecordBatchStream, TaskContext};
19use datafusion::logical_expr::Expr;
20use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
21use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
22use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
23use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
24use datafusion_common::Statistics;
25use datafusion_expr::TableType;
26use parking_lot::Mutex;
27
/// Shared batch slot used by the provider and every `LiveSourceExec` scan.
///
/// Wrapping the batch list in `Arc` means the hot-path `execute()` clone is
/// an O(1) Arc bump, not a full `Vec<RecordBatch>` clone (which would
/// Arc-bump every column of every batch). `swap()` replaces the Arc
/// wholesale while any in-flight reader keeps its prior snapshot alive.
///
/// Layering: the outer `Arc` is the handle shared by provider, execs, and
/// `LiveSourceHandle`s; the `parking_lot::Mutex` guards only the inner
/// pointer swap (never batch data); the inner `Arc<Vec<_>>` is the
/// immutable snapshot readers capture.
type BatchSlot = Arc<Mutex<Arc<Vec<RecordBatch>>>>;
35
36fn new_slot() -> BatchSlot {
37    Arc::new(Mutex::new(Arc::new(Vec::new())))
38}
39
40// ── TableProvider ────────────────────────────────────────────────────
41
/// Swappable `TableProvider` for streaming micro-batch execution.
///
/// `scan()` returns an internal execution plan that reads from the shared
/// batch slot at `execute()` time — enabling physical plan caching.
pub struct LiveSourceProvider {
    // Shared slot read by every scan; refilled out-of-band via `handle()`.
    current: BatchSlot,
    // Schema reported to DataFusion. Swapped-in batches are assumed to
    // match it — `swap()` does not validate (NOTE(review): confirm callers
    // uphold this).
    schema: SchemaRef,
}
50
51impl LiveSourceProvider {
52    /// Creates a provider with the given schema and an empty batch slot.
53    #[must_use]
54    pub fn new(schema: SchemaRef) -> Self {
55        Self {
56            current: new_slot(),
57            schema,
58        }
59    }
60
61    /// Returns a handle for swapping batches into this provider.
62    #[must_use]
63    pub fn handle(&self) -> LiveSourceHandle {
64        LiveSourceHandle {
65            slot: Arc::clone(&self.current),
66        }
67    }
68}
69
70impl std::fmt::Debug for LiveSourceProvider {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        f.debug_struct("LiveSourceProvider")
73            .field("schema_fields", &self.schema.fields().len())
74            .finish_non_exhaustive()
75    }
76}
77
#[async_trait]
impl TableProvider for LiveSourceProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Returns a fresh `LiveSourceExec` leaf over the shared batch slot.
    ///
    /// The projection is forwarded to the exec node and applied per batch
    /// at `execute()` time. Filters and limit are ignored here (not pushed
    /// down), so DataFusion keeps its own filter/limit operators above
    /// this scan.
    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        Ok(Arc::new(LiveSourceExec::new(
            Arc::clone(&self.current),
            self.schema.clone(),
            projection.cloned(),
        )))
    }
}
106
107// ── ExecutionPlan ────────────────────────────────────────────────────
108
/// Leaf `ExecutionPlan` that reads from a shared batch slot at `execute()`
/// time, not at construction time. This enables physical plan caching:
/// the plan tree is built once, and each `execute()` call sees fresh data.
pub(crate) struct LiveSourceExec {
    // Shared slot; snapshotted with an Arc bump on every `execute()`.
    slot: BatchSlot,
    // Output schema — already narrowed when `projection` is present.
    schema: SchemaRef,
    // Column indices pushed down from `scan()`, applied to each batch.
    projection: Option<Vec<usize>>,
    // Cached single-partition, bounded plan properties.
    properties: PlanProperties,
}
118
119impl LiveSourceExec {
120    fn new(slot: BatchSlot, source_schema: SchemaRef, projection: Option<Vec<usize>>) -> Self {
121        let schema = match &projection {
122            Some(indices) => {
123                let fields: Vec<_> = indices
124                    .iter()
125                    .map(|&i| source_schema.field(i).clone())
126                    .collect();
127                Arc::new(arrow::datatypes::Schema::new(fields))
128            }
129            None => source_schema,
130        };
131        let properties = PlanProperties::new(
132            EquivalenceProperties::new(schema.clone()),
133            Partitioning::UnknownPartitioning(1),
134            EmissionType::Final,
135            Boundedness::Bounded,
136        );
137        Self {
138            slot,
139            schema,
140            projection,
141            properties,
142        }
143    }
144}
145
146impl std::fmt::Debug for LiveSourceExec {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        f.debug_struct("LiveSourceExec")
149            .field("schema_fields", &self.schema.fields().len())
150            .finish_non_exhaustive()
151    }
152}
153
154impl DisplayAs for LiveSourceExec {
155    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        match t {
157            DisplayFormatType::Default | DisplayFormatType::Verbose => {
158                write!(f, "LiveSourceExec: schema={}", self.schema.fields().len())
159            }
160            DisplayFormatType::TreeRender => write!(f, "LiveSourceExec"),
161        }
162    }
163}
164
165impl ExecutionPlan for LiveSourceExec {
166    fn name(&self) -> &'static str {
167        "LiveSourceExec"
168    }
169
170    fn as_any(&self) -> &dyn Any {
171        self
172    }
173
174    fn schema(&self) -> SchemaRef {
175        self.schema.clone()
176    }
177
178    fn properties(&self) -> &PlanProperties {
179        &self.properties
180    }
181
182    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
183        vec![]
184    }
185
186    fn with_new_children(
187        self: Arc<Self>,
188        children: Vec<Arc<dyn ExecutionPlan>>,
189    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
190        if children.is_empty() {
191            Ok(self)
192        } else {
193            Err(DataFusionError::Plan(
194                "LiveSourceExec is a leaf node".to_string(),
195            ))
196        }
197    }
198
199    fn execute(
200        &self,
201        partition: usize,
202        _context: Arc<TaskContext>,
203    ) -> Result<SendableRecordBatchStream, DataFusionError> {
204        if partition != 0 {
205            return Err(DataFusionError::Plan(format!(
206                "LiveSourceExec only supports partition 0, got {partition}"
207            )));
208        }
209
210        // O(1) snapshot: bump the Arc under the lock and release.
211        // Any concurrent `swap()` installs a new Arc; this reader keeps
212        // its prior snapshot alive until the returned stream is dropped.
213        let batches_arc: Arc<Vec<RecordBatch>> = Arc::clone(&self.slot.lock());
214        let schema = self.schema.clone();
215        let projection = self.projection.clone();
216
217        let output = futures::stream::iter(if batches_arc.is_empty() {
218            vec![Ok(RecordBatch::new_empty(schema))]
219        } else if let Some(indices) = projection {
220            batches_arc
221                .iter()
222                .map(|batch| batch.project(&indices).map_err(DataFusionError::from))
223                .collect()
224        } else {
225            batches_arc.iter().cloned().map(Ok).collect()
226        });
227
228        Ok(Box::pin(RecordBatchStreamAdapter::new(
229            self.schema.clone(),
230            output,
231        )))
232    }
233
234    fn statistics(&self) -> datafusion_common::Result<Statistics> {
235        Ok(Statistics::default())
236    }
237}
238
// Mirrors the `PlanProperties` built in `LiveSourceExec::new` (single
// unknown partition, no ordering, bounded, final emission) so both views
// agree. NOTE(review): DataFusion also exposes these through
// `ExecutionPlan::properties()` — confirm a manual impl on the concrete
// type is still needed by callers.
impl datafusion::physical_plan::ExecutionPlanProperties for LiveSourceExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&datafusion::physical_expr::LexOrdering> {
        self.properties.output_ordering()
    }

    fn boundedness(&self) -> Boundedness {
        Boundedness::Bounded
    }

    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Final
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
260
261// ── Handle ───────────────────────────────────────────────────────────
262
/// Handle for swapping batches into a [`LiveSourceProvider`].
///
/// Cheap to clone: every clone shares the same underlying slot, so a swap
/// through any clone is visible to all of them.
#[derive(Clone)]
pub struct LiveSourceHandle {
    // Same slot the provider's scans read from.
    slot: BatchSlot,
}
268
269impl LiveSourceHandle {
270    /// Replace current batches. In-flight readers that captured the
271    /// prior snapshot continue to see the prior data; the next scan
272    /// picks up the new one.
273    pub fn swap(&self, batches: Vec<RecordBatch>) {
274        *self.slot.lock() = Arc::new(batches);
275    }
276
277    /// Clear all pending batches.
278    pub fn clear(&self) {
279        *self.slot.lock() = Arc::new(Vec::new());
280    }
281}
282
283impl std::fmt::Debug for LiveSourceHandle {
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        f.debug_struct("LiveSourceHandle").finish()
286    }
287}
288
289// ── Tests ────────────────────────────────────────────────────────────
290
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Shared three-column schema: id (non-null), name, price.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("price", DataType::Float64, true),
        ]))
    }

    /// Builds one batch matching `test_schema()`; the three slices must be
    /// equal length or `try_new` panics via `unwrap`.
    fn make_batch(ids: &[i64], names: &[&str], prices: &[f64]) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(Int64Array::from(ids.to_vec())),
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(prices.to_vec())),
            ],
        )
        .unwrap()
    }

    fn test_ctx() -> datafusion::prelude::SessionContext {
        // Use a plain context (no streaming validator) for unit tests.
        datafusion::prelude::SessionContext::new()
    }

    /// Runs `sql` and sums row counts across all result batches.
    async fn count_rows(ctx: &datafusion::prelude::SessionContext, sql: &str) -> usize {
        let df = ctx.sql(sql).await.unwrap();
        df.collect()
            .await
            .unwrap()
            .iter()
            .map(RecordBatch::num_rows)
            .sum()
    }

    // Cloned handles share one slot: a swap or clear through either handle
    // is observed through the other.
    #[test]
    fn test_handle_swap_and_clear() {
        let provider = LiveSourceProvider::new(test_schema());
        let h1 = provider.handle();
        let h2 = h1.clone();

        h1.swap(vec![make_batch(&[1, 2], &["A", "B"], &[1.0, 2.0])]);
        assert_eq!(h2.slot.lock().len(), 1);

        h2.clear();
        assert_eq!(h1.slot.lock().len(), 0);
    }

    // A registered provider serves the swapped-in batch, and re-running the
    // same query re-reads the slot rather than a stale copy.
    #[tokio::test]
    async fn test_scan_reads_fresh_data_each_execute() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(
            &[1, 2, 3],
            &["A", "B", "C"],
            &[10.0, 20.0, 30.0],
        )]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 3);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 3);
    }

    // An empty slot still yields a well-formed zero-row result.
    #[tokio::test]
    async fn test_scan_empty() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 0);
    }

    // Projection pushdown narrows both the rows and the output schema to
    // the selected columns, in SELECT order.
    #[tokio::test]
    async fn test_projection() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(
            &[1, 2, 3],
            &["A", "B", "C"],
            &[10.0, 20.0, 30.0],
        )]);

        let df = ctx.sql("SELECT id, price FROM t").await.unwrap();
        let result = df.collect().await.unwrap();
        assert_eq!(result.iter().map(RecordBatch::num_rows).sum::<usize>(), 3);
        assert_eq!(result[0].schema().fields().len(), 2);
        assert_eq!(result[0].schema().field(0).name(), "id");
        assert_eq!(result[0].schema().field(1).name(), "price");
    }

    // Swap/clear between queries: each cycle sees only the latest snapshot.
    #[tokio::test]
    async fn test_multi_cycle() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(&[1], &["A"], &[10.0])]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 1);

        handle.swap(vec![make_batch(&[2, 3], &["B", "C"], &[20.0, 30.0])]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 2);

        handle.clear();
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 0);
    }

    // Core guarantee of the module: a physical plan built once keeps
    // returning fresh slot contents on every subsequent execute, without
    // re-planning.
    #[tokio::test]
    async fn test_cached_plan_sees_fresh_data() {
        use datafusion::physical_plan::ExecutionPlanProperties as _;

        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(&[1], &["A"], &[10.0])]);
        let logical = ctx
            .state()
            .create_logical_plan("SELECT * FROM t")
            .await
            .unwrap();
        let physical = ctx.state().create_physical_plan(&logical).await.unwrap();
        assert_eq!(physical.output_partitioning().partition_count(), 1);

        let task_ctx = ctx.task_ctx();
        let r1 = datafusion::physical_plan::collect(physical.clone(), task_ctx.clone())
            .await
            .unwrap();
        assert_eq!(r1.iter().map(RecordBatch::num_rows).sum::<usize>(), 1);

        handle.swap(vec![make_batch(
            &[2, 3, 4],
            &["B", "C", "D"],
            &[20.0, 30.0, 40.0],
        )]);
        let r2 = datafusion::physical_plan::collect(physical.clone(), task_ctx.clone())
            .await
            .unwrap();
        assert_eq!(r2.iter().map(RecordBatch::num_rows).sum::<usize>(), 3);

        handle.clear();
        let r3 = datafusion::physical_plan::collect(physical, task_ctx)
            .await
            .unwrap();
        assert_eq!(r3.iter().map(RecordBatch::num_rows).sum::<usize>(), 0);
    }
}
449}