1use sqlparser::ast::{Expr, OrderByKind, Query, SelectItem, SetExpr, Statement};
7
/// Result of analyzing a statement's ORDER BY for streaming execution.
///
/// Produced by [`analyze_order_by`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderAnalysis {
    /// Sort keys of the top-level ORDER BY that are plain (possibly qualified)
    /// column references, in clause order. Sort keys that are arbitrary
    /// expressions are not represented here.
    pub order_columns: Vec<OrderColumn>,
    /// LIMIT value reported by the analysis, if any. Only unsigned integer
    /// literals are recognized; any other LIMIT expression yields `None`.
    pub limit: Option<usize>,
    /// Whether the analysis reported a windowed aggregation, i.e. a `TUMBLE`,
    /// `HOP` or `SESSION` call in the GROUP BY list.
    pub is_windowed: bool,
    /// Classification of the ordering requirement.
    pub pattern: OrderPattern,
}
20
/// One sort key of an ORDER BY clause.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderColumn {
    /// Column name. For a qualified reference such as `t.price`, only the
    /// final identifier (`price`) is kept when parsed from SQL.
    pub column: String,
    /// `true` for `DESC`; `ASC` or an omitted direction gives `false`.
    pub descending: bool,
    /// When parsed from SQL, `true` only if `NULLS FIRST` is written
    /// explicitly; `NULLS LAST` or an omitted null ordering gives `false`.
    // NOTE(review): many engines default DESC keys to NULLS FIRST; this flag does
    // not model such a default — confirm it matches how `SortColumn` is built.
    pub nulls_first: bool,
}
31
/// How a query's ordering requirement can be executed over a stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OrderPattern {
    /// No ordering requirement detected (for example, the statement is not a
    /// query or has no ORDER BY).
    None,
    /// Ordering is already provided by the source.
    // NOTE(review): `analyze_order_by` never produces this variant; presumably a
    // caller sets it after checking `is_order_satisfied` — confirm.
    SourceSatisfied,
    /// ORDER BY bounded by a LIMIT.
    TopK {
        /// The LIMIT value.
        k: usize,
    },
    /// ORDER BY inside a windowed aggregation (`TUMBLE`/`HOP`/`SESSION` in GROUP BY).
    WindowLocal,
    /// Per-partition top-K driven by a ranking window function.
    PerGroupTopK {
        /// Upper bound on the rank: taken from the outer rank filter (e.g.
        /// `rn <= k`) or, when the ranking function is projected directly,
        /// from the query's LIMIT.
        k: usize,
        /// `PARTITION BY` columns of the ranking function's window.
        partition_columns: Vec<String>,
        /// Which ranking function produced the rank.
        rank_type: RankType,
    },
    /// ORDER BY with neither a LIMIT nor a window; the only pattern that
    /// [`OrderAnalysis::is_streaming_safe`] rejects.
    Unbounded,
}
58
/// Ranking window function recognized in per-group top-K queries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RankType {
    /// `ROW_NUMBER()`.
    RowNumber,
    /// `RANK()`.
    Rank,
    /// `DENSE_RANK()`.
    DenseRank,
}
69
70impl OrderAnalysis {
71 #[must_use]
73 pub fn is_streaming_safe(&self) -> bool {
74 !matches!(self.pattern, OrderPattern::Unbounded)
75 }
76}
77
78#[must_use]
91pub fn analyze_order_by(stmt: &Statement) -> OrderAnalysis {
92 let Statement::Query(query) = stmt else {
93 return OrderAnalysis {
94 order_columns: vec![],
95 limit: None,
96 is_windowed: false,
97 pattern: OrderPattern::None,
98 };
99 };
100
101 let limit = extract_limit(query);
102 let is_windowed = check_is_windowed(query);
103
104 if let Some((k, partition_columns, rank_type)) = detect_row_number_pattern(query) {
108 let order_columns = extract_order_columns(query);
109 return OrderAnalysis {
110 order_columns,
111 limit,
112 is_windowed,
113 pattern: OrderPattern::PerGroupTopK {
114 k,
115 partition_columns,
116 rank_type,
117 },
118 };
119 }
120
121 let order_columns = extract_order_columns(query);
122 if order_columns.is_empty() {
123 return OrderAnalysis {
124 order_columns: vec![],
125 limit: None,
126 is_windowed: false,
127 pattern: OrderPattern::None,
128 };
129 }
130
131 let pattern = if is_windowed {
132 OrderPattern::WindowLocal
133 } else if let Some(k) = limit {
134 OrderPattern::TopK { k }
135 } else {
136 OrderPattern::Unbounded
137 };
138
139 OrderAnalysis {
140 order_columns,
141 limit,
142 is_windowed,
143 pattern,
144 }
145}
146
147#[must_use]
152pub fn is_order_satisfied(
153 required: &[OrderColumn],
154 source: &[crate::datafusion::SortColumn],
155) -> bool {
156 if required.is_empty() {
157 return true;
158 }
159 if source.len() < required.len() {
160 return false;
161 }
162 required.iter().zip(source.iter()).all(|(req, src)| {
163 req.column == src.name
164 && req.descending == src.descending
165 && req.nulls_first == src.nulls_first
166 })
167}
168
169fn extract_order_columns(query: &Query) -> Vec<OrderColumn> {
171 let Some(order_by) = &query.order_by else {
172 return vec![];
173 };
174
175 let OrderByKind::Expressions(exprs) = &order_by.kind else {
176 return vec![]; };
178
179 exprs
180 .iter()
181 .filter_map(|ob_expr| {
182 let column = extract_column_name(&ob_expr.expr)?;
183 let descending = !ob_expr.options.asc.unwrap_or(true);
184 let nulls_first = ob_expr.options.nulls_first.unwrap_or(false);
185 Some(OrderColumn {
186 column,
187 descending,
188 nulls_first,
189 })
190 })
191 .collect()
192}
193
194fn extract_limit(query: &Query) -> Option<usize> {
196 use sqlparser::ast::LimitClause;
197
198 let limit_clause = query.limit_clause.as_ref()?;
199 match limit_clause {
200 LimitClause::LimitOffset { limit, .. } => {
201 let expr = limit.as_ref()?;
202 expr_to_usize(expr)
203 }
204 LimitClause::OffsetCommaLimit { limit, .. } => expr_to_usize(limit),
205 }
206}
207
208fn check_is_windowed(query: &Query) -> bool {
210 if let SetExpr::Select(select) = query.body.as_ref() {
211 use sqlparser::ast::GroupByExpr;
212 match &select.group_by {
213 GroupByExpr::Expressions(exprs, _modifiers) => {
214 exprs.iter().any(is_window_function_call)
215 }
216 GroupByExpr::All(_) => false,
217 }
218 } else {
219 false
220 }
221}
222
223fn detect_row_number_pattern(query: &Query) -> Option<(usize, Vec<String>, RankType)> {
229 if let SetExpr::Select(select) = query.body.as_ref() {
231 for item in &select.projection {
232 if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = item {
233 if let Some((partition_cols, _order_cols, rank_type)) =
234 extract_row_number_info(expr)
235 {
236 if let Some(k) = extract_limit(query) {
238 return Some((k, partition_cols, rank_type));
239 }
240 }
241 }
242 }
243
244 for from in &select.from {
246 if let sqlparser::ast::TableFactor::Derived { subquery, .. } = &from.relation {
247 if let SetExpr::Select(inner_select) = subquery.body.as_ref() {
248 for item in &inner_select.projection {
249 if let SelectItem::ExprWithAlias { expr, alias } = item {
250 if let Some((partition_cols, _order_cols, rank_type)) =
251 extract_row_number_info(expr)
252 {
253 if let Some(k) =
256 extract_rn_filter_limit(select.selection.as_ref(), &alias.value)
257 {
258 return Some((k, partition_cols, rank_type));
259 }
260 }
261 }
262 }
263 }
264 }
265 }
266 }
267 None
268}
269
270fn extract_row_number_info(expr: &Expr) -> Option<(Vec<String>, Vec<String>, RankType)> {
274 if let Expr::Function(func) = expr {
275 let name = func.name.to_string().to_uppercase();
276 let rank_type = match name.as_str() {
277 "ROW_NUMBER" => RankType::RowNumber,
278 "RANK" => RankType::Rank,
279 "DENSE_RANK" => RankType::DenseRank,
280 _ => return None,
281 };
282 if let Some(ref window_spec) = func.over {
283 match window_spec {
284 sqlparser::ast::WindowType::WindowSpec(spec) => {
285 let partition_cols: Vec<String> = spec
286 .partition_by
287 .iter()
288 .filter_map(extract_column_name)
289 .collect();
290 let order_cols: Vec<String> = spec
291 .order_by
292 .iter()
293 .filter_map(|ob| extract_column_name(&ob.expr))
294 .collect();
295 return Some((partition_cols, order_cols, rank_type));
296 }
297 sqlparser::ast::WindowType::NamedWindow(_) => {}
298 }
299 }
300 }
301 None
302}
303
304fn extract_rn_filter_limit(selection: Option<&Expr>, alias: &str) -> Option<usize> {
306 let where_expr = selection?;
307 if let Expr::BinaryOp { left, op, right } = where_expr {
308 use sqlparser::ast::BinaryOperator;
309 match op {
310 BinaryOperator::LtEq if extract_column_name(left)? == alias => {
311 return expr_to_usize(right);
313 }
314 BinaryOperator::Lt if extract_column_name(left)? == alias => {
315 return expr_to_usize(right).map(|n| n.saturating_sub(1));
317 }
318 _ => {}
319 }
320 }
321 None
322}
323
324fn is_window_function_call(expr: &Expr) -> bool {
326 if let Expr::Function(func) = expr {
327 let name = func.name.to_string().to_uppercase();
328 matches!(name.as_str(), "TUMBLE" | "HOP" | "SESSION")
329 } else {
330 false
331 }
332}
333
334fn extract_column_name(expr: &Expr) -> Option<String> {
336 match expr {
337 Expr::Identifier(ident) => Some(ident.value.clone()),
338 Expr::CompoundIdentifier(parts) => {
339 parts.last().map(|p| p.value.clone())
341 }
342 _ => None,
343 }
344}
345
346fn expr_to_usize(expr: &Expr) -> Option<usize> {
348 match expr {
349 Expr::Value(value_with_span) => match &value_with_span.value {
350 sqlparser::ast::Value::Number(n, _) => n.parse::<usize>().ok(),
351 _ => None,
352 },
353 _ => None,
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::SortColumn;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_stmt(sql: &str) -> Statement {
        let mut statements = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        statements.remove(0)
    }

    /// Parses `sql` and runs [`analyze_order_by`] on it.
    fn analyze(sql: &str) -> OrderAnalysis {
        analyze_order_by(&parse_stmt(sql))
    }

    /// Expected `PerGroupTopK` pattern for the given bound, partitions and rank type.
    fn per_group(k: usize, partition_columns: &[&str], rank_type: RankType) -> OrderPattern {
        OrderPattern::PerGroupTopK {
            k,
            partition_columns: partition_columns.iter().map(|c| (*c).to_string()).collect(),
            rank_type,
        }
    }

    /// A single required sort key on `column` with `nulls_first == false`.
    fn required_key(column: &str, descending: bool) -> Vec<OrderColumn> {
        vec![OrderColumn {
            column: column.to_string(),
            descending,
            nulls_first: false,
        }]
    }

    // ---- plain ORDER BY -------------------------------------------------

    #[test]
    fn test_analyze_simple_order_by() {
        let analysis = analyze("SELECT id, value FROM events ORDER BY id");
        assert_eq!(analysis.order_columns.len(), 1);
        let key = &analysis.order_columns[0];
        assert_eq!(key.column, "id");
        assert!(!key.descending);
        assert_eq!(analysis.pattern, OrderPattern::Unbounded);
    }

    #[test]
    fn test_analyze_order_by_desc() {
        let analysis = analyze("SELECT * FROM events ORDER BY price DESC");
        assert_eq!(analysis.order_columns.len(), 1);
        assert!(analysis.order_columns[0].descending);
    }

    #[test]
    fn test_analyze_order_by_nulls_first() {
        let analysis = analyze("SELECT * FROM events ORDER BY value ASC NULLS FIRST");
        assert_eq!(analysis.order_columns.len(), 1);
        let key = &analysis.order_columns[0];
        assert!(!key.descending);
        assert!(key.nulls_first);
    }

    #[test]
    fn test_analyze_order_by_multiple_columns() {
        let analysis =
            analyze("SELECT * FROM events ORDER BY category ASC, price DESC NULLS LAST");
        assert_eq!(analysis.order_columns.len(), 2);
        let (first, second) = (&analysis.order_columns[0], &analysis.order_columns[1]);
        assert_eq!(first.column, "category");
        assert!(!first.descending);
        assert_eq!(second.column, "price");
        assert!(second.descending);
    }

    #[test]
    fn test_analyze_order_by_with_limit() {
        let analysis = analyze("SELECT * FROM events ORDER BY price DESC LIMIT 10");
        assert_eq!(analysis.limit, Some(10));
        assert_eq!(analysis.pattern, OrderPattern::TopK { k: 10 });
    }

    #[test]
    fn test_analyze_order_by_without_limit() {
        let analysis = analyze("SELECT * FROM events ORDER BY id");
        assert!(analysis.limit.is_none());
        assert_eq!(analysis.pattern, OrderPattern::Unbounded);
        assert!(!analysis.is_streaming_safe());
    }

    #[test]
    fn test_analyze_no_order_by() {
        let analysis = analyze("SELECT * FROM events");
        assert_eq!(analysis.pattern, OrderPattern::None);
        assert!(analysis.order_columns.is_empty());
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_analyze_select_star() {
        let analysis = analyze("SELECT * FROM events WHERE id > 5");
        assert_eq!(analysis.pattern, OrderPattern::None);
    }

    // ---- ranking window functions ---------------------------------------

    #[test]
    fn test_detect_row_number_pattern() {
        let analysis = analyze(
            "SELECT * FROM (
                SELECT *, ROW_NUMBER() OVER (PARTITION BY category ORDER BY price DESC) AS rn
                FROM trades
            ) sub WHERE rn <= 5",
        );
        assert_eq!(
            analysis.pattern,
            per_group(5, &["category"], RankType::RowNumber)
        );
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_detect_row_number_with_partition() {
        let analysis = analyze(
            "SELECT * FROM (
                SELECT *, ROW_NUMBER() OVER (PARTITION BY category ORDER BY price DESC) AS rn
                FROM trades
            ) sub WHERE rn <= 3 ORDER BY category LIMIT 100",
        );
        assert_eq!(
            analysis.pattern,
            per_group(3, &["category"], RankType::RowNumber)
        );
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_detect_row_number_without_filter() {
        let analysis =
            analyze("SELECT *, ROW_NUMBER() OVER (ORDER BY price DESC) AS rn FROM trades");
        assert_eq!(analysis.pattern, OrderPattern::None);
    }

    #[test]
    fn test_row_number_subquery_no_outer_order() {
        let analysis = analyze(
            "SELECT * FROM (
                SELECT *, ROW_NUMBER() OVER (PARTITION BY symbol ORDER BY ts DESC) AS rn
                FROM trades
            ) sub WHERE rn <= 10",
        );
        assert_eq!(
            analysis.pattern,
            per_group(10, &["symbol"], RankType::RowNumber)
        );
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_row_number_direct_with_limit() {
        let analysis = analyze(
            "SELECT *, ROW_NUMBER() OVER (PARTITION BY cat ORDER BY val DESC) AS rn
                FROM events LIMIT 5",
        );
        assert_eq!(analysis.pattern, per_group(5, &["cat"], RankType::RowNumber));
    }

    #[test]
    fn test_detect_rank_pattern() {
        let analysis = analyze(
            "SELECT * FROM (
                SELECT *, RANK() OVER (PARTITION BY category ORDER BY price DESC) AS rn
                FROM trades
            ) sub WHERE rn <= 3",
        );
        assert_eq!(analysis.pattern, per_group(3, &["category"], RankType::Rank));
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_detect_dense_rank_pattern() {
        let analysis = analyze(
            "SELECT * FROM (
                SELECT *, DENSE_RANK() OVER (PARTITION BY region ORDER BY revenue DESC) AS rn
                FROM sales
            ) sub WHERE rn <= 5",
        );
        assert_eq!(
            analysis.pattern,
            per_group(5, &["region"], RankType::DenseRank)
        );
    }

    #[test]
    fn test_rank_multiple_partition_columns() {
        let analysis = analyze(
            "SELECT * FROM (
                SELECT *, RANK() OVER (PARTITION BY region, category ORDER BY sales DESC) AS rn
                FROM revenue
            ) sub WHERE rn <= 3",
        );
        let OrderPattern::PerGroupTopK {
            k,
            partition_columns,
            rank_type,
        } = &analysis.pattern
        else {
            panic!("Expected PerGroupTopK, got {:?}", analysis.pattern);
        };
        assert_eq!(*k, 3);
        assert_eq!(
            partition_columns,
            &["region".to_string(), "category".to_string()]
        );
        assert_eq!(*rank_type, RankType::Rank);
    }

    #[test]
    fn test_rank_extracts_order_columns() {
        let analysis = analyze(
            "SELECT *, RANK() OVER (PARTITION BY cat ORDER BY price DESC, ts ASC) AS rn
                FROM trades LIMIT 10",
        );
        assert!(matches!(
            analysis.pattern,
            OrderPattern::PerGroupTopK {
                rank_type: RankType::Rank,
                ..
            }
        ));
    }

    #[test]
    fn test_rank_pattern_is_streaming_safe() {
        let analysis = analyze(
            "SELECT * FROM (
                SELECT *, DENSE_RANK() OVER (PARTITION BY cat ORDER BY val) AS rn
                FROM events
            ) sub WHERE rn <= 5",
        );
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_no_ranking_function_none() {
        let analysis = analyze("SELECT id, name FROM events WHERE id > 5");
        assert_eq!(analysis.pattern, OrderPattern::None);
    }

    // ---- source order satisfaction --------------------------------------

    #[test]
    fn test_order_satisfied_exact_match() {
        let source = [SortColumn::ascending("event_time")];
        assert!(is_order_satisfied(&required_key("event_time", false), &source));
    }

    #[test]
    fn test_order_satisfied_prefix_match() {
        let source = [
            SortColumn::ascending("event_time"),
            SortColumn::ascending("id"),
        ];
        assert!(is_order_satisfied(&required_key("event_time", false), &source));
    }

    #[test]
    fn test_order_not_satisfied_different_direction() {
        let source = [SortColumn::ascending("event_time")];
        assert!(!is_order_satisfied(&required_key("event_time", true), &source));
    }

    #[test]
    fn test_order_not_satisfied_different_columns() {
        let source = [SortColumn::ascending("event_time")];
        assert!(!is_order_satisfied(&required_key("id", false), &source));
    }

    // ---- streaming safety -----------------------------------------------

    #[test]
    fn test_topk_pattern_streaming_safe() {
        let analysis = analyze("SELECT * FROM trades ORDER BY price DESC LIMIT 5");
        assert!(analysis.is_streaming_safe());
        assert_eq!(analysis.pattern, OrderPattern::TopK { k: 5 });
    }

    #[test]
    fn test_unbounded_pattern_not_streaming_safe() {
        let analysis = analyze("SELECT * FROM trades ORDER BY price DESC");
        assert!(!analysis.is_streaming_safe());
        assert_eq!(analysis.pattern, OrderPattern::Unbounded);
    }

    #[test]
    fn test_no_order_by_streaming_safe() {
        let analysis = analyze("SELECT * FROM trades");
        assert!(analysis.is_streaming_safe());
    }

    #[test]
    fn test_windowed_order_by() {
        let analysis = analyze(
            "SELECT COUNT(*) FROM events GROUP BY TUMBLE(event_time, INTERVAL '5' MINUTE) ORDER BY event_time",
        );
        assert_eq!(analysis.pattern, OrderPattern::WindowLocal);
        assert!(analysis.is_windowed);
        assert!(analysis.is_streaming_safe());
    }
}