Skip to main content

laminar_sql/datafusion/
channel_source.rs

//! Channel-based streaming source implementation.
//!
//! This module provides `ChannelStreamSource`, the primary integration point
//! between `LaminarDB`'s Reactor and `DataFusion`'s query engine.
5
6use std::fmt::{Debug, Formatter};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use arrow_array::RecordBatch;
12use arrow_schema::SchemaRef;
13use async_trait::async_trait;
14use datafusion::physical_plan::RecordBatchStream;
15use datafusion_common::DataFusionError;
16use datafusion_expr::Expr;
17use futures::Stream;
18use parking_lot::Mutex;
19
20use super::bridge::{BridgeSender, StreamBridge};
21use super::source::{SortColumn, StreamSource};
22
/// Capacity handed to the bridge by [`ChannelStreamSource::new`] when the
/// caller does not pick one via [`ChannelStreamSource::with_capacity`].
const DEFAULT_CHANNEL_CAPACITY: usize = 1 << 10;
25
26/// Bridges LaminarDB's push-based Reactor and DataFusion's pull-based
27/// query execution. The sender is `take`-once (not cloned) so dropping
28/// it closes the channel and lets the query terminate.
29pub struct ChannelStreamSource {
30    /// Schema of the data
31    schema: SchemaRef,
32    /// The bridge connecting sender and receivers
33    bridge: Mutex<Option<StreamBridge>>,
34    /// Sender for pushing data - must be taken, not cloned
35    sender: Mutex<Option<BridgeSender>>,
36    /// Channel capacity
37    capacity: usize,
38    /// Declared output ordering (for ORDER BY elision)
39    ordering: Option<Vec<SortColumn>>,
40}
41
42impl ChannelStreamSource {
43    /// Creates a new channel stream source with default capacity.
44    #[must_use]
45    pub fn new(schema: SchemaRef) -> Self {
46        Self::with_capacity(schema, DEFAULT_CHANNEL_CAPACITY)
47    }
48
49    /// Creates a new channel stream source with the given channel capacity.
50    #[must_use]
51    pub fn with_capacity(schema: SchemaRef, capacity: usize) -> Self {
52        let bridge = StreamBridge::new(Arc::clone(&schema), capacity);
53        let sender = bridge.sender();
54        Self {
55            schema,
56            bridge: Mutex::new(Some(bridge)),
57            sender: Mutex::new(Some(sender)),
58            capacity,
59            ordering: None,
60        }
61    }
62
63    /// Declares that this source produces data in the given sort order.
64    /// When set, `DataFusion` can elide `SortExec` for ORDER BY queries
65    /// that match the declared ordering.
66    #[must_use]
67    pub fn with_ordering(mut self, ordering: Vec<SortColumn>) -> Self {
68        self.ordering = Some(ordering);
69        self
70    }
71
72    /// Takes the sender for pushing batches into this source.
73    ///
74    /// This method can only be called once. The sender is moved out of
75    /// the source to ensure the caller has full ownership and can close
76    /// the channel by dropping the sender.
77    ///
78    /// The returned sender can be cloned to allow multiple producers.
79    ///
80    /// Returns `None` if the sender was already taken.
81    #[must_use]
82    pub fn take_sender(&self) -> Option<BridgeSender> {
83        self.sender.lock().take()
84    }
85
86    /// Returns a clone of the sender if it hasn't been taken yet.
87    ///
88    /// **Warning**: Using this method can lead to channel leak issues if
89    /// the original sender is never dropped. Prefer `take_sender()` for
90    /// proper channel lifecycle management.
91    #[must_use]
92    pub fn sender(&self) -> Option<BridgeSender> {
93        self.sender.lock().as_ref().map(BridgeSender::clone)
94    }
95
96    /// Resets the source with a new bridge and sender.
97    ///
98    /// This is useful when you need to reuse the source after the previous
99    /// stream has been consumed. Any data sent before the reset but not
100    /// yet consumed will be lost.
101    ///
102    /// Returns the new sender.
103    pub fn reset(&self) -> BridgeSender {
104        let bridge = StreamBridge::new(Arc::clone(&self.schema), self.capacity);
105        let sender = bridge.sender();
106        *self.bridge.lock() = Some(bridge);
107        *self.sender.lock() = Some(sender.clone());
108        sender
109    }
110}
111
112impl Debug for ChannelStreamSource {
113    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("ChannelStreamSource")
115            .field("schema", &self.schema)
116            .field("capacity", &self.capacity)
117            .finish_non_exhaustive()
118    }
119}
120
121#[async_trait]
122impl StreamSource for ChannelStreamSource {
123    fn schema(&self) -> SchemaRef {
124        Arc::clone(&self.schema)
125    }
126
127    fn output_ordering(&self) -> Option<Vec<SortColumn>> {
128        self.ordering.clone()
129    }
130
131    fn stream(
132        &self,
133        projection: Option<Vec<usize>>,
134        _filters: Vec<Expr>,
135    ) -> Result<datafusion::physical_plan::SendableRecordBatchStream, DataFusionError> {
136        let mut bridge_guard = self.bridge.lock();
137        let bridge = bridge_guard.take().ok_or_else(|| {
138            DataFusionError::Execution(
139                "Stream already taken; call reset() to create a new bridge".to_string(),
140            )
141        })?;
142
143        let inner_stream = bridge.into_stream();
144
145        // Apply projection if specified
146        let stream: datafusion::physical_plan::SendableRecordBatchStream =
147            if let Some(indices) = projection {
148                let projected_schema = {
149                    let fields: Vec<_> = indices
150                        .iter()
151                        .map(|&i| self.schema.field(i).clone())
152                        .collect();
153                    Arc::new(arrow_schema::Schema::new(fields))
154                };
155                Box::pin(ProjectingStream::new(
156                    inner_stream,
157                    projected_schema,
158                    indices,
159                ))
160            } else {
161                Box::pin(inner_stream)
162            };
163
164        Ok(stream)
165    }
166}
167
168/// A stream that applies column projection to record batches.
169struct ProjectingStream<S> {
170    inner: S,
171    schema: SchemaRef,
172    indices: Vec<usize>,
173}
174
175impl<S> ProjectingStream<S> {
176    fn new(inner: S, schema: SchemaRef, indices: Vec<usize>) -> Self {
177        Self {
178            inner,
179            schema,
180            indices,
181        }
182    }
183}
184
185impl<S> Debug for ProjectingStream<S> {
186    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
187        f.debug_struct("ProjectingStream")
188            .field("schema", &self.schema)
189            .field("indices", &self.indices)
190            .finish_non_exhaustive()
191    }
192}
193
194impl<S> Stream for ProjectingStream<S>
195where
196    S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
197{
198    type Item = Result<RecordBatch, DataFusionError>;
199
200    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
201        match Pin::new(&mut self.inner).poll_next(cx) {
202            Poll::Ready(Some(Ok(batch))) => {
203                // Project columns using built-in projection (avoids intermediate Vec alloc)
204                let projected = batch.project(&self.indices).map_err(|e| {
205                    DataFusionError::ArrowError(Box::new(e), Some("projection failed".to_string()))
206                });
207                Poll::Ready(Some(projected))
208            }
209            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
210            Poll::Ready(None) => Poll::Ready(None),
211            Poll::Pending => Poll::Pending,
212        }
213    }
214}
215
216impl<S> RecordBatchStream for ProjectingStream<S>
217where
218    S: Stream<Item = Result<RecordBatch, DataFusionError>> + Unpin,
219{
220    fn schema(&self) -> SchemaRef {
221        Arc::clone(&self.schema)
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;

    /// Two non-nullable Int64 columns: `id` (index 0) and `value` (index 1).
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a batch matching `test_schema()` from parallel id/value vectors.
    fn test_batch(schema: &SchemaRef, ids: Vec<i64>, values: Vec<i64>) -> RecordBatch {
        let columns: Vec<arrow_array::ArrayRef> = vec![
            Arc::new(Int64Array::from(ids)),
            Arc::new(Int64Array::from(values)),
        ];
        RecordBatch::try_new(Arc::clone(schema), columns).unwrap()
    }

    #[test]
    fn test_channel_source_schema() {
        let expected = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&expected));
        assert_eq!(source.schema(), expected);
    }

    #[tokio::test]
    async fn test_channel_source_stream() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().expect("fresh source has a sender");
        let mut stream = source.stream(None, vec![]).expect("fresh source yields a stream");

        let outgoing = test_batch(&schema, vec![1, 2], vec![10, 20]);
        sender.send(outgoing).await.expect("send should succeed");
        // Dropping the only sender closes the channel.
        drop(sender);

        let received = stream
            .next()
            .await
            .expect("one batch expected")
            .expect("batch should be Ok");
        assert_eq!(received.num_rows(), 2);
        assert_eq!(received.num_columns(), 2);
    }

    #[tokio::test]
    async fn test_channel_source_projection() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();

        // Keep only column 1 ("value").
        let mut stream = source.stream(Some(vec![1]), vec![]).unwrap();

        let outgoing = test_batch(&schema, vec![1, 2], vec![100, 200]);
        sender.send(outgoing).await.unwrap();
        drop(sender);

        let projected = stream.next().await.unwrap().unwrap();
        assert_eq!(projected.num_columns(), 1);
        assert_eq!(projected.schema().field(0).name(), "value");

        let value_col = projected
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!((value_col.value(0), value_col.value(1)), (100, 200));
    }

    #[tokio::test]
    async fn test_channel_source_stream_already_taken() {
        let source = ChannelStreamSource::new(test_schema());

        // The first call consumes the bridge ...
        let _first = source.stream(None, vec![]).unwrap();

        // ... so a second call must fail.
        assert!(source.stream(None, vec![]).is_err(), "second stream() must fail");
    }

    #[tokio::test]
    async fn test_channel_source_multiple_batches() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));
        let sender = source.take_sender().unwrap();
        let stream = source.stream(None, vec![]).unwrap();

        for i in 0..5i64 {
            let outgoing = test_batch(&schema, vec![i], vec![i * 10]);
            sender.send(outgoing).await.unwrap();
        }
        drop(sender);

        let batches: Vec<RecordBatch> = stream
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .collect::<Result<_, _>>()
            .expect("every batch should be Ok");
        assert_eq!(batches.len(), 5);
    }

    #[tokio::test]
    async fn test_channel_source_take_sender_once() {
        let source = ChannelStreamSource::new(test_schema());

        let first = source.take_sender();
        assert!(first.is_some(), "first take should succeed");

        let second = source.take_sender();
        assert!(second.is_none(), "second take should yield None");
    }

    #[tokio::test]
    async fn test_channel_source_reset() {
        let schema = test_schema();
        let source = ChannelStreamSource::new(Arc::clone(&schema));

        // Consume both halves of the original bridge.
        let _old_sender = source.take_sender().unwrap();
        let _old_stream = source.stream(None, vec![]).unwrap();

        // After a reset, a fresh sender/stream pair is available.
        let new_sender = source.reset();
        let mut new_stream = source.stream(None, vec![]).unwrap();

        let outgoing = test_batch(&schema, vec![1], vec![10]);
        new_sender.send(outgoing).await.unwrap();
        drop(new_sender);

        let received = new_stream.next().await.unwrap().unwrap();
        assert_eq!(received.num_rows(), 1);
    }

    #[test]
    fn test_channel_source_debug() {
        let rendered = format!("{:?}", ChannelStreamSource::new(test_schema()));
        for needle in ["ChannelStreamSource", "capacity"] {
            assert!(rendered.contains(needle), "missing {needle} in {rendered}");
        }
    }

    #[test]
    fn test_channel_source_default_no_ordering() {
        let source = ChannelStreamSource::new(test_schema());
        let ordering = source.output_ordering();
        assert!(ordering.is_none(), "no ordering declared by default");
    }

    #[test]
    fn test_channel_source_with_ordering() {
        let source = ChannelStreamSource::new(test_schema())
            .with_ordering(vec![SortColumn::ascending("id")]);

        let cols = source.output_ordering().expect("ordering was declared");
        assert_eq!(cols.len(), 1);

        let only = &cols[0];
        assert_eq!(only.name, "id");
        assert!(!only.descending);
    }
}