
laminar_sql/datafusion/mod.rs

//! DataFusion integration for SQL processing.

mod bridge;
mod channel_source;
/// Cross-instance hash repartition for distributed GROUP BY. Gated on
/// `cluster-unstable` because it pulls in the shuffle transport.
#[cfg(feature = "cluster-unstable")]
pub mod cluster_repartition;
/// Lambda higher-order functions for arrays and maps (F-SCHEMA-015 Tier 3)
pub mod complex_type_lambda;
/// Array, Struct, and Map scalar UDFs (F-SCHEMA-015)
pub mod complex_type_udf;
mod exec;
/// End-to-end streaming SQL execution
pub mod execute;
/// Format bridge UDFs for inline format conversion
pub mod format_bridge_udf;
/// LaminarDB streaming JSON extension UDFs (F-SCHEMA-013)
pub mod json_extensions;
/// SQL/JSON path query compiler and scalar UDFs
pub mod json_path;
/// JSON table-valued functions (array/object expansion)
pub mod json_tvf;
/// JSONB binary format types for JSON UDF evaluation
pub mod json_types;
/// PostgreSQL-compatible JSON aggregate UDAFs
pub mod json_udaf;
/// PostgreSQL-compatible JSON scalar UDFs
pub mod json_udf;
/// Live source provider for streaming execution with plan caching
pub mod live_source;
/// Lookup join plan node for DataFusion.
pub mod lookup_join;
/// Physical execution plan and extension planner for lookup joins.
pub mod lookup_join_exec;
/// Processing-time UDF for `PROCTIME()` support
pub mod proctime_udf;
mod source;
mod table_provider;
/// Dynamic watermark filter for scan-level late-data pruning
/// Watermark UDF for current watermark access
pub mod watermark_udf;
/// Window function UDFs (TUMBLE, HOP, SESSION, CUMULATE)
pub mod window_udf;

pub use bridge::{BridgeSendError, BridgeSender, BridgeStream, BridgeTrySendError, StreamBridge};
pub use channel_source::ChannelStreamSource;
pub use complex_type_lambda::{
    register_lambda_functions, ArrayFilter, ArrayReduce, ArrayTransform, MapFilter,
    MapTransformValues,
};
pub use complex_type_udf::{
    register_complex_type_functions, MapContainsKey, MapFromArrays, MapKeys, MapValues, StructDrop,
    StructExtract, StructMerge, StructRename, StructSet,
};
pub use exec::StreamingScanExec;
pub use execute::{execute_streaming_sql, DdlResult, QueryResult, StreamingSqlResult};
pub use format_bridge_udf::{FromJsonUdf, ParseEpochUdf, ParseTimestampUdf, ToJsonUdf};
pub use json_extensions::{
    register_json_extensions, JsonInferSchema, JsonToColumns, JsonbDeepMerge, JsonbExcept,
    JsonbFlatten, JsonbMerge, JsonbPick, JsonbRenameKeys, JsonbStripNulls, JsonbUnflatten,
};
pub use json_path::{CompiledJsonPath, JsonPathStep, JsonbPathExistsUdf, JsonbPathMatchUdf};
pub use json_tvf::{
    register_json_table_functions, JsonbArrayElementsTextTvf, JsonbArrayElementsTvf,
    JsonbEachTextTvf, JsonbEachTvf, JsonbObjectKeysTvf,
};
pub use json_udaf::{JsonAgg, JsonObjectAgg};
pub use json_udf::{
    JsonBuildArray, JsonBuildObject, JsonTypeof, JsonbContainedBy, JsonbContains, JsonbExists,
    JsonbExistsAll, JsonbExistsAny, JsonbGet, JsonbGetIdx, JsonbGetPath, JsonbGetPathText,
    JsonbGetText, JsonbGetTextIdx, ToJsonb,
};
pub use live_source::{LiveSourceHandle, LiveSourceProvider};
pub use lookup_join_exec::{
    LookupJoinExec, LookupJoinExtensionPlanner, LookupSnapshot, LookupTableRegistry,
    PartialLookupJoinExec, PartialLookupState, RegisteredLookup, VersionedLookupJoinExec,
    VersionedLookupState,
};
pub use proctime_udf::ProcTimeUdf;
pub use source::{SortColumn, StreamSource, StreamSourceRef};
pub use table_provider::StreamingTableProvider;
pub use watermark_udf::WatermarkUdf;
pub use window_udf::{CumulateWindowStart, HopWindowStart, SessionWindowStart, TumbleWindowStart};

use std::sync::atomic::AtomicI64;
use std::sync::Arc;

use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::*;
use datafusion_expr::ScalarUDF;

use crate::planner::streaming_optimizer::{StreamingPhysicalValidator, StreamingValidatorMode};

/// Returns a base `SessionConfig` with identifier normalization disabled.
///
/// DataFusion's default behaviour lowercases all unquoted SQL identifiers
/// (per the SQL standard). LaminarDB disables this so that mixed-case
/// column names from external sources (Kafka, CDC, WebSocket) can be
/// referenced without double-quoting.
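///
/// # Example
///
/// A minimal sketch; the `Events` table and `userId` column are
/// placeholders, not part of this crate:
///
/// ```rust,ignore
/// let ctx = SessionContext::new_with_config(base_session_config());
/// // `userId` keeps its casing because identifier normalization is off,
/// // so no double-quoting is required.
/// let df = ctx.sql("SELECT userId FROM Events").await?;
/// ```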
#[must_use]
pub fn base_session_config() -> SessionConfig {
    let mut config = SessionConfig::new();
    config.options_mut().sql_parser.enable_ident_normalization = false;
    // Single partition for streaming micro-batch execution. Multi-partition
    // plans contain stateful operators (RepartitionExec) that cannot be
    // reused across cycles, causing panics on cached physical plans.
    config = config.with_target_partitions(1);
    config
}

/// Creates a `DataFusion` session context with identifier normalization
/// disabled.
///
/// Suitable for ad-hoc / non-streaming queries (filters, lookups).
/// For streaming workloads prefer [`create_streaming_context`].
#[must_use]
pub fn create_session_context() -> SessionContext {
    SessionContext::new_with_config(base_session_config())
}

/// Creates a `DataFusion` session context configured for streaming queries.
///
/// The context is configured with:
/// - Batch size of 8192 (balanced for streaming throughput)
/// - Single partition (streaming sources are typically not partitioned)
/// - Identifier normalization disabled (mixed-case columns work unquoted)
/// - All streaming UDFs registered (TUMBLE, HOP, SESSION, WATERMARK)
/// - `StreamingPhysicalValidator` in `Reject` mode (blocks unsafe plans)
///
/// The watermark UDF is initialized with no watermark set (returns NULL).
/// Use [`register_streaming_functions_with_watermark`] to provide a live
/// watermark source.
///
/// # Example
///
/// ```rust,ignore
/// let ctx = create_streaming_context();
/// ctx.register_table("events", provider)?;
/// let df = ctx.sql("SELECT * FROM events").await?;
/// ```
#[must_use]
pub fn create_streaming_context() -> SessionContext {
    create_streaming_context_with_validator(StreamingValidatorMode::Reject)
}

/// Creates a streaming context with a configurable validator mode.
///
/// Same as [`create_streaming_context`] but allows choosing how the
/// [`StreamingPhysicalValidator`] handles plan violations.
///
/// Use [`StreamingValidatorMode::Off`] to get the previous behaviour
/// (no plan-time validation).
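///
/// # Example
///
/// A minimal sketch, turning plan-time validation off:
///
/// ```rust,ignore
/// // Equivalent to the behaviour before the validator existed.
/// let ctx = create_streaming_context_with_validator(StreamingValidatorMode::Off);
/// ```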
#[must_use]
pub fn create_streaming_context_with_validator(mode: StreamingValidatorMode) -> SessionContext {
    let config = base_session_config().with_batch_size(8192);

    let ctx = if matches!(mode, StreamingValidatorMode::Off) {
        SessionContext::new_with_config(config)
    } else {
        // Build a default state to get the standard optimizer rules, then
        // prepend our streaming validator so it fires before DataFusion's
        // built-in SanityCheckPlan (which produces generic error messages).
        let default_state = SessionStateBuilder::new()
            .with_config(config.clone())
            .with_default_features()
            .build();
        let mut rules: Vec<
            Arc<dyn datafusion::physical_optimizer::PhysicalOptimizerRule + Send + Sync>,
        > = vec![Arc::new(StreamingPhysicalValidator::new(mode))];
        rules.extend(default_state.physical_optimizers().iter().cloned());

        let state = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rules(rules)
            .build();
        SessionContext::new_with_state(state)
    };

    register_streaming_functions(&ctx);
    ctx
}


/// Registers `LaminarDB` streaming UDFs with a session context.
///
/// Registers the following scalar functions:
/// - `tumble(timestamp, interval)` — tumbling window start
/// - `hop(timestamp, slide, size)` — hopping window start
/// - `session(timestamp, gap)` — session window pass-through
/// - `watermark()` — current watermark (returns NULL, no live source)
///
/// It also registers `cumulate`, `proctime()`, and the JSON, JSON-extension,
/// complex-type, and lambda function families.
///
/// Use [`register_streaming_functions_with_watermark`] to provide a
/// live watermark source from Ring 0.
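///
/// # Example
///
/// A minimal sketch; the `events` table and `event_time` column are
/// placeholders:
///
/// ```rust,ignore
/// let ctx = create_session_context();
/// register_streaming_functions(&ctx);
/// let df = ctx
///     .sql("SELECT tumble(event_time, INTERVAL '5' MINUTE) FROM events")
///     .await?;
/// ```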
pub fn register_streaming_functions(ctx: &SessionContext) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(CumulateWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::unset()));
    ctx.register_udf(ScalarUDF::new_from_impl(ProcTimeUdf::new()));
    register_json_functions(ctx);
    register_json_extensions(ctx);
    register_complex_type_functions(ctx);
    register_lambda_functions(ctx);
}

/// Registers streaming UDFs with a live watermark source.
///
/// Same as [`register_streaming_functions`] but connects the `watermark()`
/// UDF to a shared atomic value that Ring 0 updates in real time.
///
/// # Arguments
///
/// * `ctx` - `DataFusion` session context
/// * `watermark_ms` - Shared atomic holding the current watermark in
///   milliseconds since epoch. Values < 0 mean "no watermark" (returns NULL).
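///
/// # Example
///
/// A minimal sketch; in practice Ring 0 owns and advances the atomic:
///
/// ```rust,ignore
/// use std::sync::atomic::{AtomicI64, Ordering};
///
/// let watermark_ms = Arc::new(AtomicI64::new(-1)); // no watermark yet
/// let ctx = create_session_context();
/// register_streaming_functions_with_watermark(&ctx, Arc::clone(&watermark_ms));
/// // Later, the runtime advances the watermark and `watermark()` reflects it.
/// watermark_ms.store(200_000, Ordering::Relaxed);
/// ```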
pub fn register_streaming_functions_with_watermark(
    ctx: &SessionContext,
    watermark_ms: Arc<AtomicI64>,
) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(CumulateWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::new(watermark_ms)));
    ctx.register_udf(ScalarUDF::new_from_impl(ProcTimeUdf::new()));
    register_json_functions(ctx);
    register_json_extensions(ctx);
    register_complex_type_functions(ctx);
    register_lambda_functions(ctx);
}

/// Registers all PostgreSQL-compatible JSON UDFs and UDAFs
/// with the given `SessionContext`.
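///
/// # Example
///
/// A minimal sketch; the SQL-level name `json_typeof` and the `payload`
/// column are illustrative assumptions, not verified against the UDF impls:
///
/// ```rust,ignore
/// let ctx = create_session_context();
/// register_json_functions(&ctx);
/// let df = ctx.sql("SELECT json_typeof(payload) FROM events").await?;
/// ```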
pub fn register_json_functions(ctx: &SessionContext) {
    // Extraction operators
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGet::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetIdx::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetText::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetTextIdx::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetPath::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetPathText::new()));

    // Existence operators
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExists::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExistsAny::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExistsAll::new()));

    // Containment operators
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbContains::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbContainedBy::new()));

    // Interrogation / construction
    ctx.register_udf(ScalarUDF::new_from_impl(JsonTypeof::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonBuildObject::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonBuildArray::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ToJsonb::new()));

    // Aggregates
    ctx.register_udaf(datafusion_expr::AggregateUDF::new_from_impl(JsonAgg::new()));
    ctx.register_udaf(datafusion_expr::AggregateUDF::new_from_impl(
        JsonObjectAgg::new(),
    ));

    // Format bridge functions
    ctx.register_udf(ScalarUDF::new_from_impl(ParseEpochUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ParseTimestampUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ToJsonUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(FromJsonUdf::new()));

    // JSON path query functions (scalar)
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbPathExistsUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbPathMatchUdf::new()));

    // JSON table-valued functions
    register_json_table_functions(ctx);
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::FunctionRegistry;
    use futures::StreamExt;
    use std::sync::Arc;

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]))
    }

    /// Take the sender from a `ChannelStreamSource`, panicking if already taken.
    fn take_test_sender(source: &ChannelStreamSource) -> super::bridge::BridgeSender {
        source.take_sender().expect("sender already taken")
    }

    fn test_batch(schema: &Arc<Schema>, ids: Vec<i64>, values: Vec<f64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Float64Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_create_streaming_context() {
        let ctx = create_streaming_context();
        let state = ctx.state();
        let config = state.config();

        assert_eq!(config.batch_size(), 8192);
        assert_eq!(config.target_partitions(), 1);
    }

    #[tokio::test]
    async fn test_full_query_pipeline() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        // Create source and take the sender (important for channel closure)
        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Send test data
        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        sender
            .send(test_batch(&schema, vec![4, 5], vec![40.0, 50.0]))
            .await
            .unwrap();
        drop(sender); // Close the channel

        // Execute query
        let df = ctx.sql("SELECT * FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        // Verify results
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 5);
    }

    #[tokio::test]
    async fn test_query_with_projection() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100.0, 200.0]))
            .await
            .unwrap();
        drop(sender);

        // Query only the id column
        let df = ctx.sql("SELECT id FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_columns(), 1);
        assert_eq!(batches[0].schema().field(0).name(), "id");
    }

    #[tokio::test]
    async fn test_query_with_filter() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(
                &schema,
                vec![1, 2, 3, 4, 5],
                vec![10.0, 20.0, 30.0, 40.0, 50.0],
            ))
            .await
            .unwrap();
        drop(sender);

        // Filter for value > 25
        let df = ctx
            .sql("SELECT * FROM events WHERE value > 25")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3); // 30, 40, 50
    }

    #[tokio::test]
    async fn test_unbounded_aggregation_rejected() {
        // Aggregations on unbounded streams should be rejected by `DataFusion`.
        // Streaming aggregations require windows and are handled by Ring 0.
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        drop(sender);

        // Aggregate query on unbounded stream should fail at execution
        let df = ctx.sql("SELECT COUNT(*) as cnt FROM events").await.unwrap();

        // Execution should fail because we can't aggregate an infinite stream
        let result = df.collect().await;
        assert!(
            result.is_err(),
            "Aggregation on unbounded stream should fail"
        );
    }

    #[tokio::test]
    async fn test_query_with_order_by() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![3, 1, 2], vec![30.0, 10.0, 20.0]))
            .await
            .unwrap();
        drop(sender);

        // Query id and value (no ORDER BY; row order from a streaming source is not guaranteed)
        let df = ctx.sql("SELECT id, value FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        // Verify we got results (ordering may vary due to streaming nature)
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_bridge_throughput() {
        // Benchmark-style test for bridge performance
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10000);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let batch_count = 1000;
        let batch = test_batch(&schema, vec![1, 2, 3, 4, 5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        // Spawn sender task
        let send_task = tokio::spawn(async move {
            for _ in 0..batch_count {
                sender.send(batch.clone()).await.unwrap();
            }
        });

        // Receive all batches
        let mut received = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            received += 1;
            if received == batch_count {
                break;
            }
        }

        send_task.await.unwrap();
        assert_eq!(received, batch_count);
    }

    // ── Integration Tests ──────────────────────────────────────────

    #[test]
    fn test_streaming_functions_registered() {
        let ctx = create_streaming_context();
        // Verify the core window and watermark UDFs are registered
        assert!(ctx.udf("tumble").is_ok(), "tumble UDF not registered");
        assert!(ctx.udf("hop").is_ok(), "hop UDF not registered");
        assert!(ctx.udf("session").is_ok(), "session UDF not registered");
        assert!(ctx.udf("watermark").is_ok(), "watermark UDF not registered");
    }

    #[test]
    fn test_streaming_functions_with_watermark() {
        use std::sync::atomic::AtomicI64;

        let ctx = create_session_context();
        let wm = Arc::new(AtomicI64::new(42_000));
        register_streaming_functions_with_watermark(&ctx, wm);

        assert!(ctx.udf("tumble").is_ok());
        assert!(ctx.udf("watermark").is_ok());
    }

    #[tokio::test]
    async fn test_tumble_udf_via_datafusion() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        // Create schema with timestamp and value columns
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Send events across two 5-minute windows:
        // Window [0, 300_000): timestamps 60_000, 120_000
        // Window [300_000, 600_000): timestamp 360_000
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    60_000i64, 120_000, 360_000,
                ])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        // Verify the tumble UDF computes correct window starts via DataFusion
        // (GROUP BY aggregation and ORDER BY on unbounded streams are handled by Ring 0)
        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as window_start, \
                 value \
                 FROM events",
            )
            .await
            .unwrap();

        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);

        // Verify the window_start values (single batch, order preserved)
        let ws_col = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .expect("window_start should be TimestampMillisecond");
        // 60_000 and 120_000 → window [0, 300_000), start = 0
        assert_eq!(ws_col.value(0), 0);
        assert_eq!(ws_col.value(1), 0);
        // 360_000 → window [300_000, 600_000), start = 300_000
        assert_eq!(ws_col.value(2), 300_000);
    }

    #[tokio::test]
    async fn test_logical_plan_from_windowed_query() {
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(schema));
        let _sender = source.take_sender();
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Create a LogicalPlan for a windowed query
        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as w, \
                 COUNT(*) as cnt \
                 FROM events \
                 GROUP BY tumble(event_time, INTERVAL '5' MINUTE)",
            )
            .await;

        // Should succeed in creating the logical plan (UDFs are registered)
        assert!(df.is_ok(), "Failed to create logical plan: {df:?}");
    }

    #[tokio::test]
    async fn test_end_to_end_execute_streaming_sql() {
        use crate::planner::StreamingPlanner;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("items", source);
        ctx.register_table("items", Arc::new(provider)).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(arrow_array::StringArray::from(vec!["a", "b", "c"])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let mut planner = StreamingPlanner::new();
        let result = execute_streaming_sql("SELECT id FROM items WHERE id > 1", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                let mut stream = qr.stream;
                let mut total = 0;
                while let Some(batch) = stream.next().await {
                    total += batch.unwrap().num_rows();
                }
                assert_eq!(total, 2); // id=2, id=3
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

    #[tokio::test]
    async fn test_watermark_function_in_filter() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;
        use std::sync::atomic::AtomicI64;

        // Create context with a specific watermark value
        let config = base_session_config()
            .with_batch_size(8192)
            .with_target_partitions(1);
        let ctx = SessionContext::new_with_config(config);
        let wm = Arc::new(AtomicI64::new(200_000)); // watermark at 200s
        register_streaming_functions_with_watermark(&ctx, wm);

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Events: 100s, 200s, 300s - watermark is at 200s
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    100_000i64, 200_000, 300_000,
                ])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        // Filter events after watermark
        let df = ctx
            .sql("SELECT value FROM events WHERE event_time > watermark()")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        // Only event at 300s is after watermark (200s)
        assert_eq!(total_rows, 1);
    }

    #[tokio::test]
    async fn test_date_trunc_available() {
        let ctx = create_streaming_context();
        let df = ctx
            .sql("SELECT date_trunc('hour', TIMESTAMP '2026-01-15 14:30:00')")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }

    #[tokio::test]
    async fn test_date_bin_available() {
        let ctx = create_streaming_context();
        let df = ctx
            .sql(
                "SELECT date_bin(\
                 INTERVAL '15 minutes', \
                 TIMESTAMP '2026-01-15 14:32:00', \
                 TIMESTAMP '2026-01-01 00:00:00')",
            )
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }

    #[tokio::test]
    async fn test_unnest_literal_array() {
        let ctx = create_streaming_context();
        let df = ctx
            .sql("SELECT unnest(make_array(1, 2, 3)) AS val")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_unnest_from_table_with_array_col() {
        let ctx = create_streaming_context();
        // Register a table with an array column
        ctx.sql(
            "CREATE TABLE arr_table (id INT, tags INT[]) \
             AS VALUES (1, make_array(10, 20)), (2, make_array(30))",
        )
        .await
        .unwrap();
        let df = ctx
            .sql("SELECT id, unnest(tags) AS tag FROM arr_table")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        // Row 1: [10,20] → 2 rows, Row 2: [30] → 1 row = 3 total
        assert_eq!(total_rows, 3);
    }
777}