laminar_sql/datafusion/
table_provider.rs1use std::any::Any;
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use datafusion::catalog::Session;
13use datafusion::datasource::TableProvider;
14use datafusion::physical_plan::ExecutionPlan;
15use datafusion_common::DataFusionError;
16use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
17
18use super::exec::StreamingScanExec;
19use super::source::StreamSourceRef;
20
21#[derive(Debug)]
40pub struct StreamingTableProvider {
41 name: String,
43 source: StreamSourceRef,
45}
46
47impl StreamingTableProvider {
48 #[must_use]
50 pub fn new(name: impl Into<String>, source: StreamSourceRef) -> Self {
51 Self {
52 name: name.into(),
53 source,
54 }
55 }
56
57 #[must_use]
59 pub fn name(&self) -> &str {
60 &self.name
61 }
62
63 #[must_use]
65 pub fn source(&self) -> &StreamSourceRef {
66 &self.source
67 }
68}
69
70#[async_trait]
71impl TableProvider for StreamingTableProvider {
72 fn as_any(&self) -> &dyn Any {
73 self
74 }
75
76 fn schema(&self) -> SchemaRef {
77 self.source.schema()
78 }
79
80 fn table_type(&self) -> TableType {
81 TableType::Base
83 }
84
85 fn supports_filters_pushdown(
86 &self,
87 filters: &[&Expr],
88 ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
89 let expr_refs: Vec<Expr> = filters.iter().map(|e| (*e).clone()).collect();
91 let supported = self.source.supports_filters(&expr_refs);
92
93 Ok(supported
94 .into_iter()
95 .map(|s| {
96 if s {
97 TableProviderFilterPushDown::Exact
98 } else {
99 TableProviderFilterPushDown::Unsupported
100 }
101 })
102 .collect())
103 }
104
105 async fn scan(
106 &self,
107 _state: &dyn Session,
108 projection: Option<&Vec<usize>>,
109 filters: &[Expr],
110 _limit: Option<usize>,
111 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
112 Ok(Arc::new(StreamingScanExec::new(
117 Arc::clone(&self.source),
118 projection.cloned(),
119 filters.to_vec(),
120 )))
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::datafusion::source::StreamSource;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::SendableRecordBatchStream;

    /// Source stub with a fixed schema.
    ///
    /// It never streams. It accepts only `=` comparisons, and only when
    /// `supports_eq_filter` is set.
    #[derive(Debug)]
    struct MockSource {
        schema: SchemaRef,
        supports_eq_filter: bool,
    }

    #[async_trait]
    impl StreamSource for MockSource {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn stream(
            &self,
            _projection: Option<Vec<usize>>,
            _filters: Vec<Expr>,
        ) -> Result<SendableRecordBatchStream, DataFusionError> {
            Err(DataFusionError::NotImplemented("mock".to_string()))
        }

        fn supports_filters(&self, filters: &[Expr]) -> Vec<bool> {
            filters
                .iter()
                .map(|f| {
                    self.supports_eq_filter
                        && matches!(
                            f,
                            Expr::BinaryExpr(e) if e.op == datafusion_expr::Operator::Eq
                        )
                })
                .collect()
        }
    }

    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    /// Builds a provider named `name` over a `MockSource` using `test_schema()`.
    fn mock_provider(name: &str, supports_eq_filter: bool) -> StreamingTableProvider {
        let source: StreamSourceRef = Arc::new(MockSource {
            schema: test_schema(),
            supports_eq_filter,
        });
        StreamingTableProvider::new(name, source)
    }

    /// The expression `id = 1`, the only filter shape `MockSource` can accept.
    fn id_eq_one() -> Expr {
        Expr::BinaryExpr(datafusion_expr::BinaryExpr {
            left: Box::new(Expr::Column(datafusion_common::Column::new_unqualified(
                "id",
            ))),
            op: datafusion_expr::Operator::Eq,
            right: Box::new(Expr::Literal(
                datafusion_common::ScalarValue::Int64(Some(1)),
                None,
            )),
        })
    }

    /// A bare literal, which `MockSource` never accepts.
    fn literal_one() -> Expr {
        Expr::Literal(datafusion_common::ScalarValue::Int64(Some(1)), None)
    }

    #[test]
    fn test_table_provider_schema() {
        let provider = mock_provider("test_table", false);

        assert_eq!(provider.schema(), test_schema());
        assert_eq!(provider.name(), "test_table");
    }

    #[test]
    fn test_table_provider_type() {
        let provider = mock_provider("test", false);

        assert_eq!(provider.table_type(), TableType::Base);
    }

    #[test]
    fn test_filter_pushdown_unsupported() {
        let provider = mock_provider("test", false);

        let filter = literal_one();
        let result = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(result, vec![TableProviderFilterPushDown::Unsupported]);
    }

    #[test]
    fn test_filter_pushdown_supported() {
        let provider = mock_provider("test", true);

        let filter = id_eq_one();
        let result = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(result, vec![TableProviderFilterPushDown::Exact]);
    }

    #[test]
    fn test_filter_pushdown_preserves_filter_order() {
        let provider = mock_provider("test", true);

        let unsupported = literal_one();
        let supported = id_eq_one();
        let result = provider
            .supports_filters_pushdown(&[&unsupported, &supported, &unsupported])
            .unwrap();

        assert_eq!(
            result,
            vec![
                TableProviderFilterPushDown::Unsupported,
                TableProviderFilterPushDown::Exact,
                TableProviderFilterPushDown::Unsupported,
            ]
        );
    }

    #[test]
    fn test_filter_pushdown_no_filters() {
        let provider = mock_provider("test", true);

        let result = provider.supports_filters_pushdown(&[]).unwrap();

        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_scan_creates_exec() {
        let provider = mock_provider("test", false);
        let ctx = create_session_context();
        let session_state = ctx.state();

        let exec = provider
            .scan(&session_state, None, &[], None)
            .await
            .unwrap();

        assert!(exec.as_any().is::<StreamingScanExec>());
        assert_eq!(exec.schema(), test_schema());
    }

    #[tokio::test]
    async fn test_scan_with_projection() {
        let provider = mock_provider("test", false);
        let ctx = create_session_context();
        let session_state = ctx.state();

        let projection = vec![0];
        let exec = provider
            .scan(&session_state, Some(&projection), &[], None)
            .await
            .unwrap();

        let output_schema = exec.schema();
        assert_eq!(output_schema.fields().len(), 1);
        assert_eq!(output_schema.field(0).name(), "id");
    }
}