// laminar_sql/datafusion/table_provider.rs
1//! Streaming table provider for `DataFusion` integration
2//!
3//! This module provides `StreamingTableProvider` which implements `DataFusion`'s
4//! `TableProvider` trait, allowing streaming sources to be registered as
5//! tables in a `SessionContext` and queried with SQL.
6
7use std::any::Any;
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use datafusion::catalog::Session;
13use datafusion::datasource::TableProvider;
14use datafusion::physical_plan::ExecutionPlan;
15use datafusion_common::DataFusionError;
16use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
17
18use super::exec::StreamingScanExec;
19use super::source::StreamSourceRef;
20
/// A `DataFusion` table provider backed by a streaming source.
///
/// This allows streaming sources to be registered as tables in `DataFusion`'s
/// `SessionContext` and queried using SQL. The provider handles:
///
/// - Schema exposure to `DataFusion`'s catalog
/// - Projection pushdown to the source
/// - Filter pushdown when supported by the source
///
/// # Usage
///
/// ```rust,ignore
/// let source = Arc::new(ChannelStreamSource::new(schema));
/// let provider = StreamingTableProvider::new("events", source);
/// ctx.register_table("events", Arc::new(provider))?;
///
/// let df = ctx.sql("SELECT * FROM events WHERE id > 100").await?;
/// ```
#[derive(Debug)]
pub struct StreamingTableProvider {
    /// Table name (informational; registration in the catalog is done by the
    /// caller, so this name and the registered name may differ)
    name: String,
    /// The underlying streaming source; shared `Arc` handle so each scan can
    /// hold its own reference
    source: StreamSourceRef,
}
46
47impl StreamingTableProvider {
48    /// Creates a new streaming table provider.
49    #[must_use]
50    pub fn new(name: impl Into<String>, source: StreamSourceRef) -> Self {
51        Self {
52            name: name.into(),
53            source,
54        }
55    }
56
57    /// Returns the table name.
58    #[must_use]
59    pub fn name(&self) -> &str {
60        &self.name
61    }
62
63    /// Returns the underlying streaming source.
64    #[must_use]
65    pub fn source(&self) -> &StreamSourceRef {
66        &self.source
67    }
68}
69
70#[async_trait]
71impl TableProvider for StreamingTableProvider {
72    fn as_any(&self) -> &dyn Any {
73        self
74    }
75
76    fn schema(&self) -> SchemaRef {
77        self.source.schema()
78    }
79
80    fn table_type(&self) -> TableType {
81        // Streaming tables behave like base tables but are read-only
82        TableType::Base
83    }
84
85    fn supports_filters_pushdown(
86        &self,
87        filters: &[&Expr],
88    ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
89        // Ask the source which filters it can handle
90        let expr_refs: Vec<Expr> = filters.iter().map(|e| (*e).clone()).collect();
91        let supported = self.source.supports_filters(&expr_refs);
92
93        Ok(supported
94            .into_iter()
95            .map(|s| {
96                if s {
97                    TableProviderFilterPushDown::Exact
98                } else {
99                    TableProviderFilterPushDown::Unsupported
100                }
101            })
102            .collect())
103    }
104
105    async fn scan(
106        &self,
107        _state: &dyn Session,
108        projection: Option<&Vec<usize>>,
109        filters: &[Expr],
110        _limit: Option<usize>,
111    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
112        // DataFusion only passes filters here that `supports_filters_pushdown`
113        // already claimed as `Exact`/`Inexact` — no need to re-ask the source
114        // which ones it can handle. Forwarding all of them was the old path's
115        // cost (two Vec<Expr> clones per scan for nothing).
116        Ok(Arc::new(StreamingScanExec::new(
117            Arc::clone(&self.source),
118            projection.cloned(),
119            filters.to_vec(),
120        )))
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::source::StreamSource;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::SendableRecordBatchStream;

    /// Minimal `StreamSource` for exercising the provider: exposes a fixed
    /// schema, cannot actually stream, and reports filter support according
    /// to the `supports_eq_filter` toggle.
    #[derive(Debug)]
    struct MockSource {
        schema: SchemaRef,
        // When true, claims support for equality filters only.
        supports_eq_filter: bool,
    }

    #[async_trait]
    impl StreamSource for MockSource {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        // Streaming is intentionally unimplemented — these tests only touch
        // planning-time behavior (schema, pushdown, exec construction).
        fn stream(
            &self,
            _projection: Option<Vec<usize>>,
            _filters: Vec<Expr>,
        ) -> Result<SendableRecordBatchStream, DataFusionError> {
            Err(DataFusionError::NotImplemented("mock".to_string()))
        }

        fn supports_filters(&self, filters: &[Expr]) -> Vec<bool> {
            filters
                .iter()
                .map(|f| {
                    if self.supports_eq_filter {
                        // Only support equality filters for testing
                        matches!(f, Expr::BinaryExpr(e) if e.op == datafusion_expr::Operator::Eq)
                    } else {
                        false
                    }
                })
                .collect()
        }
    }

    /// Two-column schema shared by all tests: `id: Int64 NOT NULL`, `name: Utf8`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    // Provider must surface the source's schema and remember its name.
    #[test]
    fn test_table_provider_schema() {
        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema: Arc::clone(&schema),
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test_table", source);

        assert_eq!(provider.schema(), schema);
        assert_eq!(provider.name(), "test_table");
    }

    // Streaming tables are presented to the planner as base tables.
    #[test]
    fn test_table_provider_type() {
        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema,
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test", source);

        assert_eq!(provider.table_type(), TableType::Base);
    }

    // When the source declines a filter, the provider reports `Unsupported`.
    #[test]
    fn test_filter_pushdown_unsupported() {
        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema,
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test", source);

        let filter = Expr::Literal(datafusion_common::ScalarValue::Int64(Some(1)), None);
        let result = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(result.len(), 1);
        assert!(matches!(
            result[0],
            TableProviderFilterPushDown::Unsupported
        ));
    }

    // When the source accepts a filter, the provider reports `Exact`.
    #[test]
    fn test_filter_pushdown_supported() {
        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema,
            supports_eq_filter: true,
        });
        let provider = StreamingTableProvider::new("test", source);

        // Create an equality filter: id = 1
        let filter = Expr::BinaryExpr(datafusion_expr::BinaryExpr {
            left: Box::new(Expr::Column(datafusion_common::Column::new_unqualified(
                "id",
            ))),
            op: datafusion_expr::Operator::Eq,
            right: Box::new(Expr::Literal(
                datafusion_common::ScalarValue::Int64(Some(1)),
                None,
            )),
        });
        let result = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], TableProviderFilterPushDown::Exact));
    }

    // `scan` should build a `StreamingScanExec` exposing the full schema
    // when no projection is supplied.
    #[tokio::test]
    async fn test_scan_creates_exec() {
        use crate::datafusion::create_session_context;

        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema: Arc::clone(&schema),
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test", source);

        let ctx = create_session_context();
        let session_state = ctx.state();

        let exec = provider
            .scan(&session_state, None, &[], None)
            .await
            .unwrap();

        // Verify it's a StreamingScanExec
        assert!(exec.as_any().is::<StreamingScanExec>());
        assert_eq!(exec.schema(), schema);
    }

    // A projection passed to `scan` should narrow the exec's output schema.
    #[tokio::test]
    async fn test_scan_with_projection() {
        use crate::datafusion::create_session_context;

        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema,
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test", source);

        let ctx = create_session_context();
        let session_state = ctx.state();

        let projection = vec![0]; // Only id column
        let exec = provider
            .scan(&session_state, Some(&projection), &[], None)
            .await
            .unwrap();

        let output_schema = exec.schema();
        assert_eq!(output_schema.fields().len(), 1);
        assert_eq!(output_schema.field(0).name(), "id");
    }
}