1use std::fmt::{Debug, Formatter};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use datafusion::physical_plan::RecordBatchStream;
15use datafusion_common::DataFusionError;
16use datafusion_expr::Expr;
17use futures::Stream;
18use parking_lot::Mutex;
19
20use super::bridge::{BridgeSender, StreamBridge};
21use super::source::{SortColumn, StreamSource};
22
/// Default number of in-flight `RecordBatch`es buffered by the channel bridge.
const DEFAULT_CHANNEL_CAPACITY: usize = 1024;

/// A [`StreamSource`] backed by an in-memory channel ([`StreamBridge`]).
///
/// Producers push batches through a [`BridgeSender`] while consumers pull them
/// out via [`StreamSource::stream`]. The bridge is single-use: `stream`
/// consumes it, and [`ChannelStreamSource::reset`] installs a fresh one.
pub struct ChannelStreamSource {
    // Schema of every batch flowing through the channel.
    schema: SchemaRef,
    // Receiving half; taken (consumed) by the first call to `stream`.
    bridge: Mutex<Option<StreamBridge>>,
    // Sending half, handed out via `take_sender`/`sender`.
    sender: Mutex<Option<BridgeSender>>,
    // Channel capacity used when constructing (or resetting) the bridge.
    capacity: usize,
    // Optional sort order reported through `output_ordering`.
    ordering: Option<Vec<SortColumn>>,
}
73
74impl ChannelStreamSource {
75 #[must_use]
81 pub fn new(schema: SchemaRef) -> Self {
82 Self::with_capacity(schema, DEFAULT_CHANNEL_CAPACITY)
83 }
84
85 #[must_use]
92 pub fn with_capacity(schema: SchemaRef, capacity: usize) -> Self {
93 let bridge = StreamBridge::new(Arc::clone(&schema), capacity);
94 let sender = bridge.sender();
95 Self {
96 schema,
97 bridge: Mutex::new(Some(bridge)),
98 sender: Mutex::new(Some(sender)),
99 capacity,
100 ordering: None,
101 }
102 }
103
104 #[must_use]
113 pub fn with_ordering(mut self, ordering: Vec<SortColumn>) -> Self {
114 self.ordering = Some(ordering);
115 self
116 }
117
118 #[must_use]
128 pub fn take_sender(&self) -> Option<BridgeSender> {
129 self.sender.lock().take()
130 }
131
132 #[must_use]
138 pub fn sender(&self) -> Option<BridgeSender> {
139 self.sender.lock().as_ref().map(BridgeSender::clone)
140 }
141
142 pub fn reset(&self) -> BridgeSender {
150 let bridge = StreamBridge::new(Arc::clone(&self.schema), self.capacity);
151 let sender = bridge.sender();
152 *self.bridge.lock() = Some(bridge);
153 *self.sender.lock() = Some(sender.clone());
154 sender
155 }
156}
157
158impl Debug for ChannelStreamSource {
159 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("ChannelStreamSource")
161 .field("schema", &self.schema)
162 .field("capacity", &self.capacity)
163 .finish_non_exhaustive()
164 }
165}
166
167#[async_trait]
168impl StreamSource for ChannelStreamSource {
169 fn schema(&self) -> SchemaRef {
170 Arc::clone(&self.schema)
171 }
172
173 fn output_ordering(&self) -> Option<Vec<SortColumn>> {
174 self.ordering.clone()
175 }
176
177 fn stream(
178 &self,
179 projection: Option<Vec<usize>>,
180 _filters: Vec<Expr>,
181 ) -> Result<datafusion::physical_plan::SendableRecordBatchStream, DataFusionError> {
182 let mut bridge_guard = self.bridge.lock();
183 let bridge = bridge_guard.take().ok_or_else(|| {
184 DataFusionError::Execution(
185 "Stream already taken; call reset() to create a new bridge".to_string(),
186 )
187 })?;
188
189 let inner_stream = bridge.into_stream();
190
191 let stream: datafusion::physical_plan::SendableRecordBatchStream =
193 if let Some(indices) = projection {
194 let projected_schema = {
195 let fields: Vec<_> = indices
196 .iter()
197 .map(|&i| self.schema.field(i).clone())
198 .collect();
199 Arc::new(arrow_schema::Schema::new(fields))
200 };
201 Box::pin(ProjectingStream::new(
202 inner_stream,
203 projected_schema,
204 indices,
205 ))
206 } else {
207 Box::pin(inner_stream)
208 };
209
210 Ok(stream)
211 }
212}
213
/// Stream adapter that applies a column projection to every batch of `inner`.
struct ProjectingStream<S> {
    // Upstream batch stream.
    inner: S,
    // Schema of the batches after projection.
    schema: SchemaRef,
    // Column indices (into the source schema) to keep.
    indices: Vec<usize>,
}
220
221impl<S> ProjectingStream<S> {
222 fn new(inner: S, schema: SchemaRef, indices: Vec<usize>) -> Self {
223 Self {
224 inner,
225 schema,
226 indices,
227 }
228 }
229}
230
231impl<S> Debug for ProjectingStream<S> {
232 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233 f.debug_struct("ProjectingStream")
234 .field("schema", &self.schema)
235 .field("indices", &self.indices)
236 .finish_non_exhaustive()
237 }
238}
239
240impl<S> Stream for ProjectingStream<S>
241where
242 S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
243{
244 type Item = Result<RecordBatch, DataFusionError>;
245
246 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
247 match Pin::new(&mut self.inner).poll_next(cx) {
248 Poll::Ready(Some(Ok(batch))) => {
249 let projected = batch.project(&self.indices).map_err(|e| {
251 DataFusionError::ArrowError(Box::new(e), Some("projection failed".to_string()))
252 });
253 Poll::Ready(Some(projected))
254 }
255 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
256 Poll::Ready(None) => Poll::Ready(None),
257 Poll::Pending => Poll::Pending,
258 }
259 }
260}
261
262impl<S> RecordBatchStream for ProjectingStream<S>
263where
264 S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
265{
266 fn schema(&self) -> SchemaRef {
267 Arc::clone(&self.schema)
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Two-column (`id`, `value`) non-nullable Int64 schema shared by the tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ]))
    }

    /// Builds a batch matching [`test_schema`] from parallel id/value vectors.
    fn test_batch(schema: &SchemaRef, ids: Vec<i64>, values: Vec<i64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    // The source reports exactly the schema it was constructed with.
    #[test]
    fn test_channel_source_schema() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert_eq!(source.schema(), schema);
    }

    // A batch sent through the bridge arrives unchanged on the stream.
    #[tokio::test]
    async fn test_channel_source_stream() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        let mut stream = source.stream(None, vec![]).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![10, 20]))
            .await
            .unwrap();
        // Dropping the sender closes the channel so the stream terminates.
        drop(sender);

        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);
    }

    // Projection to column 1 keeps only the "value" column and its data.
    #[tokio::test]
    async fn test_channel_source_projection() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        let mut stream = source.stream(Some(vec![1]), vec![]).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100, 200]))
            .await
            .unwrap();
        drop(sender);

        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.schema().field(0).name(), "value");

        let values = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 100);
        assert_eq!(values.value(1), 200);
    }

    // The bridge is single-use: a second stream() without reset() errors.
    #[tokio::test]
    async fn test_channel_source_stream_already_taken() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let _stream = source.stream(None, vec![]).unwrap();

        let result = source.stream(None, vec![]);
        assert!(result.is_err());
    }

    // All batches sent before the sender drops are delivered in order.
    #[tokio::test]
    async fn test_channel_source_multiple_batches() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();
        let mut stream = source.stream(None, vec![]).unwrap();

        for i in 0..5i64 {
            sender
                .send(test_batch(&schema, vec![i], vec![i * 10]))
                .await
                .unwrap();
        }
        drop(sender);

        let mut count = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            count += 1;
        }
        assert_eq!(count, 5);
    }

    // take_sender consumes the sender; a second call returns None.
    #[tokio::test]
    async fn test_channel_source_take_sender_once() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let sender = source.take_sender();
        assert!(sender.is_some());

        let sender2 = source.take_sender();
        assert!(sender2.is_none());
    }

    // reset() installs a fresh bridge/sender pair that works end to end
    // even after the originals were consumed.
    #[tokio::test]
    async fn test_channel_source_reset() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let _sender = source.take_sender().unwrap();
        let _stream = source.stream(None, vec![]).unwrap();

        let new_sender = source.reset();
        let mut new_stream = source.stream(None, vec![]).unwrap();

        new_sender
            .send(test_batch(&schema, vec![1], vec![10]))
            .await
            .unwrap();
        drop(new_sender);

        let batch = new_stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    // Debug output names the type and includes the capacity field.
    #[test]
    fn test_channel_source_debug() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        let debug_str = format!("{source:?}");
        assert!(debug_str.contains("ChannelStreamSource"));
        assert!(debug_str.contains("capacity"));
    }

    // Without with_ordering, no output ordering is advertised.
    #[test]
    fn test_channel_source_default_no_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        assert!(source.output_ordering().is_none());
    }

    // with_ordering is reflected verbatim by output_ordering.
    #[test]
    fn test_channel_source_with_ordering() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema))
            .with_ordering(vec![SortColumn::ascending("id")]);

        let ordering = source.output_ordering();
        assert!(ordering.is_some());
        let cols = ordering.unwrap();
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "id");
        assert!(!cols[0].descending);
    }
}