1use std::fmt::{Debug, Formatter};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use datafusion::physical_plan::RecordBatchStream;
15use datafusion_common::DataFusionError;
16use datafusion_expr::Expr;
17use futures::Stream;
18use parking_lot::Mutex;
19
20use super::bridge::{BridgeSender, StreamBridge};
21use super::source::{SortColumn, StreamSource};
22
/// Capacity passed to `StreamBridge::new` by `ChannelStreamSource::new`.
const DEFAULT_CHANNEL_CAPACITY: usize = 1024;

/// A `StreamSource` whose batches are pushed in through a `BridgeSender`.
///
/// The receiving side is handed out once by `stream()`; `reset()` installs a new
/// bridge/sender pair so another stream can be created.
pub struct ChannelStreamSource {
    // Full (unprojected) schema; returned by `StreamSource::schema` and used to build bridges.
    schema: SchemaRef,
    // Bridge not yet converted into a stream; `None` once `stream()` has taken it.
    bridge: Mutex<Option<StreamBridge>>,
    // Sender retained by the source; `None` once `take_sender()` has removed it.
    sender: Mutex<Option<BridgeSender>>,
    // Capacity passed to `StreamBridge::new`, reused by `reset()`.
    capacity: usize,
    // Ordering reported by `output_ordering`; `None` unless set via `with_ordering`.
    ordering: Option<Vec<SortColumn>>,
}
41
42impl ChannelStreamSource {
43 #[must_use]
45 pub fn new(schema: SchemaRef) -> Self {
46 Self::with_capacity(schema, DEFAULT_CHANNEL_CAPACITY)
47 }
48
49 #[must_use]
51 pub fn with_capacity(schema: SchemaRef, capacity: usize) -> Self {
52 let bridge = StreamBridge::new(Arc::clone(&schema), capacity);
53 let sender = bridge.sender();
54 Self {
55 schema,
56 bridge: Mutex::new(Some(bridge)),
57 sender: Mutex::new(Some(sender)),
58 capacity,
59 ordering: None,
60 }
61 }
62
63 #[must_use]
67 pub fn with_ordering(mut self, ordering: Vec<SortColumn>) -> Self {
68 self.ordering = Some(ordering);
69 self
70 }
71
72 #[must_use]
82 pub fn take_sender(&self) -> Option<BridgeSender> {
83 self.sender.lock().take()
84 }
85
86 #[must_use]
92 pub fn sender(&self) -> Option<BridgeSender> {
93 self.sender.lock().as_ref().map(BridgeSender::clone)
94 }
95
96 pub fn reset(&self) -> BridgeSender {
104 let bridge = StreamBridge::new(Arc::clone(&self.schema), self.capacity);
105 let sender = bridge.sender();
106 *self.bridge.lock() = Some(bridge);
107 *self.sender.lock() = Some(sender.clone());
108 sender
109 }
110}
111
112impl Debug for ChannelStreamSource {
113 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("ChannelStreamSource")
115 .field("schema", &self.schema)
116 .field("capacity", &self.capacity)
117 .finish_non_exhaustive()
118 }
119}
120
121#[async_trait]
122impl StreamSource for ChannelStreamSource {
123 fn schema(&self) -> SchemaRef {
124 Arc::clone(&self.schema)
125 }
126
127 fn output_ordering(&self) -> Option<Vec<SortColumn>> {
128 self.ordering.clone()
129 }
130
131 fn stream(
132 &self,
133 projection: Option<Vec<usize>>,
134 _filters: Vec<Expr>,
135 ) -> Result<datafusion::physical_plan::SendableRecordBatchStream, DataFusionError> {
136 let mut bridge_guard = self.bridge.lock();
137 let bridge = bridge_guard.take().ok_or_else(|| {
138 DataFusionError::Execution(
139 "Stream already taken; call reset() to create a new bridge".to_string(),
140 )
141 })?;
142
143 let inner_stream = bridge.into_stream();
144
145 let stream: datafusion::physical_plan::SendableRecordBatchStream =
147 if let Some(indices) = projection {
148 let projected_schema = {
149 let fields: Vec<_> = indices
150 .iter()
151 .map(|&i| self.schema.field(i).clone())
152 .collect();
153 Arc::new(arrow_schema::Schema::new(fields))
154 };
155 Box::pin(ProjectingStream::new(
156 inner_stream,
157 projected_schema,
158 indices,
159 ))
160 } else {
161 Box::pin(inner_stream)
162 };
163
164 Ok(stream)
165 }
166}
167
/// Stream adapter that projects every batch produced by `inner` onto `indices`.
struct ProjectingStream<S> {
    // Upstream stream of record batches.
    inner: S,
    // Schema returned by `RecordBatchStream::schema` for this adapter.
    schema: SchemaRef,
    // Column indices passed to `RecordBatch::project` for each batch.
    indices: Vec<usize>,
}
174
175impl<S> ProjectingStream<S> {
176 fn new(inner: S, schema: SchemaRef, indices: Vec<usize>) -> Self {
177 Self {
178 inner,
179 schema,
180 indices,
181 }
182 }
183}
184
185impl<S> Debug for ProjectingStream<S> {
186 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
187 f.debug_struct("ProjectingStream")
188 .field("schema", &self.schema)
189 .field("indices", &self.indices)
190 .finish_non_exhaustive()
191 }
192}
193
194impl<S> Stream for ProjectingStream<S>
195where
196 S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
197{
198 type Item = Result<RecordBatch, DataFusionError>;
199
200 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
201 match Pin::new(&mut self.inner).poll_next(cx) {
202 Poll::Ready(Some(Ok(batch))) => {
203 let projected = batch.project(&self.indices).map_err(|e| {
205 DataFusionError::ArrowError(Box::new(e), Some("projection failed".to_string()))
206 });
207 Poll::Ready(Some(projected))
208 }
209 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
210 Poll::Ready(None) => Poll::Ready(None),
211 Poll::Pending => Poll::Pending,
212 }
213 }
214}
215
216impl<S> RecordBatchStream for ProjectingStream<S>
217where
218 S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
219{
220 fn schema(&self) -> SchemaRef {
221 Arc::clone(&self.schema)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Schema with two non-nullable Int64 columns: `id` and `value`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ]))
    }

    /// Builds a batch for `schema` from parallel `ids` / `values` columns.
    fn test_batch(schema: &SchemaRef, ids: Vec<i64>, values: Vec<i64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_channel_source_schema() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert_eq!(source.schema(), schema);
    }

    #[tokio::test]
    async fn test_channel_source_stream() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        let mut stream = source.stream(None, vec![]).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![10, 20]))
            .await
            .unwrap();
        drop(sender);

        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);

        // The only sender has been dropped, so the stream must end.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_channel_source_projection() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        let mut stream = source.stream(Some(vec![1]), vec![]).unwrap();

        // The declared schema reflects the projection before any batch arrives.
        let declared = stream.schema();
        assert_eq!(declared.fields().len(), 1);
        assert_eq!(declared.field(0).name(), "value");

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100, 200]))
            .await
            .unwrap();
        drop(sender);

        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.schema().field(0).name(), "value");
        // Emitted batches must agree with the schema the stream advertises.
        assert_eq!(batch.schema(), stream.schema());

        let values = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 100);
        assert_eq!(values.value(1), 200);
    }

    #[tokio::test]
    async fn test_channel_source_stream_already_taken() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let _stream = source.stream(None, vec![]).unwrap();

        let result = source.stream(None, vec![]);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_channel_source_multiple_batches() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();
        let mut stream = source.stream(None, vec![]).unwrap();

        for i in 0..5i64 {
            sender
                .send(test_batch(&schema, vec![i], vec![i * 10]))
                .await
                .unwrap();
        }
        drop(sender);

        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }
        assert_eq!(count, 5);
    }

    #[tokio::test]
    async fn test_channel_source_take_sender_once() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let sender = source.take_sender();
        assert!(sender.is_some());

        let sender2 = source.take_sender();
        assert!(sender2.is_none());
    }

    #[tokio::test]
    async fn test_channel_source_sender_clone_keeps_stored_sender() {
        let source = ChannelStreamSource::new(test_schema());

        // Cloning via `sender()` must not remove the stored sender.
        assert!(source.sender().is_some());
        assert!(source.sender().is_some());

        let taken = source.take_sender();
        assert!(taken.is_some());
        assert!(source.sender().is_none());
    }

    #[tokio::test]
    async fn test_channel_source_reset() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let _sender = source.take_sender().unwrap();
        let _stream = source.stream(None, vec![]).unwrap();

        let new_sender = source.reset();
        let mut new_stream = source.stream(None, vec![]).unwrap();

        new_sender
            .send(test_batch(&schema, vec![1], vec![10]))
            .await
            .unwrap();
        drop(new_sender);

        let batch = new_stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    #[tokio::test]
    async fn test_channel_source_reset_restores_stored_sender() {
        let source = ChannelStreamSource::new(test_schema());

        let _taken = source.take_sender().unwrap();
        assert!(source.sender().is_none());

        let _new_sender = source.reset();
        assert!(source.sender().is_some());
    }

    #[test]
    fn test_channel_source_debug() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let debug_str = format!("{source:?}");
        assert!(debug_str.contains("ChannelStreamSource"));
        assert!(debug_str.contains("capacity"));
    }

    #[test]
    fn test_channel_source_default_no_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert!(source.output_ordering().is_none());
    }

    #[test]
    fn test_channel_source_with_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema))
            .with_ordering(vec![SortColumn::ascending("id")]);

        let ordering = source.output_ordering();
        assert!(ordering.is_some());
        let cols = ordering.unwrap();
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "id");
        assert!(!cols[0].descending);
    }
}