1use std::any::Any;
10use std::sync::Arc;
11
12use arrow::array::RecordBatch;
13use arrow::datatypes::SchemaRef;
14use async_trait::async_trait;
15use datafusion::catalog::Session;
16use datafusion::datasource::TableProvider;
17use datafusion::error::DataFusionError;
18use datafusion::execution::{SendableRecordBatchStream, TaskContext};
19use datafusion::logical_expr::Expr;
20use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
21use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
22use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
23use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
24use datafusion_common::Statistics;
25use datafusion_expr::TableType;
26use parking_lot::Mutex;
27
/// Shared, swappable snapshot of record batches.
///
/// The outer `Arc<Mutex<..>>` is shared between the provider, its handles,
/// and every `LiveSourceExec` built from it. The inner `Arc<Vec<RecordBatch>>`
/// is an immutable snapshot that writers replace wholesale (`LiveSourceHandle::swap`),
/// so readers clone the inner `Arc` under a brief lock and stream without
/// holding the mutex.
type BatchSlot = Arc<Mutex<Arc<Vec<RecordBatch>>>>;

/// Creates a slot containing an empty snapshot (no batches).
fn new_slot() -> BatchSlot {
    Arc::new(Mutex::new(Arc::new(Vec::new())))
}
39
/// A DataFusion [`TableProvider`] whose contents can be replaced at runtime.
///
/// Each `execute` reads the snapshot current at that moment, so even a cached
/// physical plan observes data swapped in after planning.
pub struct LiveSourceProvider {
    /// Shared slot holding the current batch snapshot.
    current: BatchSlot,
    /// Fixed table schema; swapped-in batches are assumed to match it
    /// (not verified here — TODO confirm callers uphold this).
    schema: SchemaRef,
}
50
51impl LiveSourceProvider {
52 #[must_use]
54 pub fn new(schema: SchemaRef) -> Self {
55 Self {
56 current: new_slot(),
57 schema,
58 }
59 }
60
61 #[must_use]
63 pub fn handle(&self) -> LiveSourceHandle {
64 LiveSourceHandle {
65 slot: Arc::clone(&self.current),
66 }
67 }
68}
69
70impl std::fmt::Debug for LiveSourceProvider {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 f.debug_struct("LiveSourceProvider")
73 .field("schema_fields", &self.schema.fields().len())
74 .finish_non_exhaustive()
75 }
76}
77
78#[async_trait]
79impl TableProvider for LiveSourceProvider {
80 fn as_any(&self) -> &dyn Any {
81 self
82 }
83
84 fn schema(&self) -> SchemaRef {
85 self.schema.clone()
86 }
87
88 fn table_type(&self) -> TableType {
89 TableType::Base
90 }
91
92 async fn scan(
93 &self,
94 _state: &dyn Session,
95 projection: Option<&Vec<usize>>,
96 _filters: &[Expr],
97 _limit: Option<usize>,
98 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
99 Ok(Arc::new(LiveSourceExec::new(
100 Arc::clone(&self.current),
101 self.schema.clone(),
102 projection.cloned(),
103 )))
104 }
105}
106
/// Leaf [`ExecutionPlan`] that streams the slot's snapshot taken at
/// `execute` time (not at planning time).
pub(crate) struct LiveSourceExec {
    /// Shared slot snapshotted on every `execute` call.
    slot: BatchSlot,
    /// Output schema — already projected when `projection` is `Some`.
    schema: SchemaRef,
    /// Column indices into the source schema, if the scan was projected.
    projection: Option<Vec<usize>>,
    /// Cached properties: one unknown partition, bounded, final emission.
    properties: PlanProperties,
}
118
119impl LiveSourceExec {
120 fn new(slot: BatchSlot, source_schema: SchemaRef, projection: Option<Vec<usize>>) -> Self {
121 let schema = match &projection {
122 Some(indices) => {
123 let fields: Vec<_> = indices
124 .iter()
125 .map(|&i| source_schema.field(i).clone())
126 .collect();
127 Arc::new(arrow::datatypes::Schema::new(fields))
128 }
129 None => source_schema,
130 };
131 let properties = PlanProperties::new(
132 EquivalenceProperties::new(schema.clone()),
133 Partitioning::UnknownPartitioning(1),
134 EmissionType::Final,
135 Boundedness::Bounded,
136 );
137 Self {
138 slot,
139 schema,
140 projection,
141 properties,
142 }
143 }
144}
145
146impl std::fmt::Debug for LiveSourceExec {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("LiveSourceExec")
149 .field("schema_fields", &self.schema.fields().len())
150 .finish_non_exhaustive()
151 }
152}
153
154impl DisplayAs for LiveSourceExec {
155 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 match t {
157 DisplayFormatType::Default | DisplayFormatType::Verbose => {
158 write!(f, "LiveSourceExec: schema={}", self.schema.fields().len())
159 }
160 DisplayFormatType::TreeRender => write!(f, "LiveSourceExec"),
161 }
162 }
163}
164
165impl ExecutionPlan for LiveSourceExec {
166 fn name(&self) -> &'static str {
167 "LiveSourceExec"
168 }
169
170 fn as_any(&self) -> &dyn Any {
171 self
172 }
173
174 fn schema(&self) -> SchemaRef {
175 self.schema.clone()
176 }
177
178 fn properties(&self) -> &PlanProperties {
179 &self.properties
180 }
181
182 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
183 vec![]
184 }
185
186 fn with_new_children(
187 self: Arc<Self>,
188 children: Vec<Arc<dyn ExecutionPlan>>,
189 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
190 if children.is_empty() {
191 Ok(self)
192 } else {
193 Err(DataFusionError::Plan(
194 "LiveSourceExec is a leaf node".to_string(),
195 ))
196 }
197 }
198
199 fn execute(
200 &self,
201 partition: usize,
202 _context: Arc<TaskContext>,
203 ) -> Result<SendableRecordBatchStream, DataFusionError> {
204 if partition != 0 {
205 return Err(DataFusionError::Plan(format!(
206 "LiveSourceExec only supports partition 0, got {partition}"
207 )));
208 }
209
210 let batches_arc: Arc<Vec<RecordBatch>> = Arc::clone(&self.slot.lock());
214 let schema = self.schema.clone();
215 let projection = self.projection.clone();
216
217 let output = futures::stream::iter(if batches_arc.is_empty() {
218 vec![Ok(RecordBatch::new_empty(schema))]
219 } else if let Some(indices) = projection {
220 batches_arc
221 .iter()
222 .map(|batch| batch.project(&indices).map_err(DataFusionError::from))
223 .collect()
224 } else {
225 batches_arc.iter().cloned().map(Ok).collect()
226 });
227
228 Ok(Box::pin(RecordBatchStreamAdapter::new(
229 self.schema.clone(),
230 output,
231 )))
232 }
233
234 fn statistics(&self) -> datafusion_common::Result<Statistics> {
235 Ok(Statistics::default())
236 }
237}
238
239impl datafusion::physical_plan::ExecutionPlanProperties for LiveSourceExec {
240 fn output_partitioning(&self) -> &Partitioning {
241 self.properties.output_partitioning()
242 }
243
244 fn output_ordering(&self) -> Option<&datafusion::physical_expr::LexOrdering> {
245 self.properties.output_ordering()
246 }
247
248 fn boundedness(&self) -> Boundedness {
249 Boundedness::Bounded
250 }
251
252 fn pipeline_behavior(&self) -> EmissionType {
253 EmissionType::Final
254 }
255
256 fn equivalence_properties(&self) -> &EquivalenceProperties {
257 self.properties.equivalence_properties()
258 }
259}
260
/// Cloneable writer handle that replaces the batches served by the
/// [`LiveSourceProvider`] it was created from.
#[derive(Clone)]
pub struct LiveSourceHandle {
    /// Slot shared with the provider and any in-flight execs.
    slot: BatchSlot,
}
268
269impl LiveSourceHandle {
270 pub fn swap(&self, batches: Vec<RecordBatch>) {
274 *self.slot.lock() = Arc::new(batches);
275 }
276
277 pub fn clear(&self) {
279 *self.slot.lock() = Arc::new(Vec::new());
280 }
281}
282
283impl std::fmt::Debug for LiveSourceHandle {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 f.debug_struct("LiveSourceHandle").finish()
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Three-column schema (`id`, `name`, `price`) shared by every test.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("price", DataType::Float64, true),
        ]))
    }

    /// Builds a batch matching `test_schema`; the three slices must have
    /// equal lengths or `try_new` panics via `unwrap`.
    fn make_batch(ids: &[i64], names: &[&str], prices: &[f64]) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(Int64Array::from(ids.to_vec())),
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(prices.to_vec())),
            ],
        )
        .unwrap()
    }

    /// Fresh default DataFusion session for each test.
    fn test_ctx() -> datafusion::prelude::SessionContext {
        datafusion::prelude::SessionContext::new()
    }

    /// Runs `sql` and returns the total row count across all result batches.
    async fn count_rows(ctx: &datafusion::prelude::SessionContext, sql: &str) -> usize {
        let df = ctx.sql(sql).await.unwrap();
        df.collect()
            .await
            .unwrap()
            .iter()
            .map(RecordBatch::num_rows)
            .sum()
    }

    /// Cloned handles share one slot: a swap through one handle is visible
    /// through the other, and vice versa for clear.
    #[test]
    fn test_handle_swap_and_clear() {
        let provider = LiveSourceProvider::new(test_schema());
        let h1 = provider.handle();
        let h2 = h1.clone();

        h1.swap(vec![make_batch(&[1, 2], &["A", "B"], &[1.0, 2.0])]);
        assert_eq!(h2.slot.lock().len(), 1);

        h2.clear();
        assert_eq!(h1.slot.lock().len(), 0);
    }

    /// Running the same query twice returns the data both times — the scan
    /// does not consume or drain the slot.
    #[tokio::test]
    async fn test_scan_reads_fresh_data_each_execute() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(
            &[1, 2, 3],
            &["A", "B", "C"],
            &[10.0, 20.0, 30.0],
        )]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 3);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 3);
    }

    /// A never-populated provider yields zero rows (the exec emits a single
    /// empty batch).
    #[tokio::test]
    async fn test_scan_empty() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 0);
    }

    /// A projected SELECT returns only the requested columns, in order.
    #[tokio::test]
    async fn test_projection() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(
            &[1, 2, 3],
            &["A", "B", "C"],
            &[10.0, 20.0, 30.0],
        )]);

        let df = ctx.sql("SELECT id, price FROM t").await.unwrap();
        let result = df.collect().await.unwrap();
        assert_eq!(result.iter().map(RecordBatch::num_rows).sum::<usize>(), 3);
        assert_eq!(result[0].schema().fields().len(), 2);
        assert_eq!(result[0].schema().field(0).name(), "id");
        assert_eq!(result[0].schema().field(1).name(), "price");
    }

    /// Swap → query → swap → query → clear → query: each query reflects the
    /// snapshot most recently installed through the handle.
    #[tokio::test]
    async fn test_multi_cycle() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(&[1], &["A"], &[10.0])]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 1);

        handle.swap(vec![make_batch(&[2, 3], &["B", "C"], &[20.0, 30.0])]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 2);

        handle.clear();
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 0);
    }

    /// Core liveness property: a physical plan created once and re-executed
    /// after later swaps/clears sees the data current at each execution,
    /// because the snapshot is read in `execute`, not at planning time.
    #[tokio::test]
    async fn test_cached_plan_sees_fresh_data() {
        use datafusion::physical_plan::ExecutionPlanProperties as _;

        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(&[1], &["A"], &[10.0])]);
        let logical = ctx
            .state()
            .create_logical_plan("SELECT * FROM t")
            .await
            .unwrap();
        let physical = ctx.state().create_physical_plan(&logical).await.unwrap();
        assert_eq!(physical.output_partitioning().partition_count(), 1);

        let task_ctx = ctx.task_ctx();
        // First execution: one row from the initial snapshot.
        let r1 = datafusion::physical_plan::collect(physical.clone(), task_ctx.clone())
            .await
            .unwrap();
        assert_eq!(r1.iter().map(RecordBatch::num_rows).sum::<usize>(), 1);

        // Same cached plan, new snapshot: three rows.
        handle.swap(vec![make_batch(
            &[2, 3, 4],
            &["B", "C", "D"],
            &[20.0, 30.0, 40.0],
        )]);
        let r2 = datafusion::physical_plan::collect(physical.clone(), task_ctx.clone())
            .await
            .unwrap();
        assert_eq!(r2.iter().map(RecordBatch::num_rows).sum::<usize>(), 3);

        // Same cached plan after clear: zero rows.
        handle.clear();
        let r3 = datafusion::physical_plan::collect(physical, task_ctx)
            .await
            .unwrap();
        assert_eq!(r3.iter().map(RecordBatch::num_rows).sum::<usize>(), 0);
    }
}