Skip to main content

laminar_sql/planner/
channel_derivation.rs

1//! Channel type derivation from query plan analysis.
2
3#[allow(clippy::disallowed_types)] // cold path: query planning
4use std::collections::HashMap;
5
/// Channel flavour chosen for a source by query analysis.
///
/// The variant is picked automatically from the number of downstream
/// consumers (materialized views) that read the source.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivedChannelType {
    /// Single-consumer channel (SPSC).
    ///
    /// Chosen when a source has at most one downstream MV; with a single
    /// reader there is no need to clone values per consumer.
    Spsc,

    /// Multi-consumer channel (broadcast).
    ///
    /// Chosen when a source feeds several downstream MVs; each value is
    /// cloned once per consumer.
    Broadcast {
        /// Number of downstream consumers.
        consumer_count: usize,
    },
}

impl DerivedChannelType {
    /// Reports whether values must be fanned out to several consumers.
    ///
    /// Returns `true` only for [`DerivedChannelType::Broadcast`].
    #[must_use]
    pub fn is_broadcast(&self) -> bool {
        match self {
            Self::Spsc => false,
            Self::Broadcast { .. } => true,
        }
    }

    /// Number of consumers this channel serves.
    ///
    /// `Broadcast` reports its stored `consumer_count`. `Spsc` always reports
    /// `1`; note that `derive_channel_types` also assigns `Spsc` to sources
    /// with zero consumers, so for those this is an upper bound.
    #[must_use]
    pub fn consumer_count(&self) -> usize {
        if let Self::Broadcast { consumer_count } = *self {
            consumer_count
        } else {
            1
        }
    }
}
44
/// A registered streaming source, as seen by channel derivation.
///
/// Sources are what materialized views consume from.
#[derive(Debug, Clone)]
pub struct SourceDefinition {
    /// Source name (e.g., "trades", "orders").
    pub name: String,
    /// Optional watermark column for event time processing.
    pub watermark_column: Option<String>,
}

impl SourceDefinition {
    /// Creates a source definition with no watermark column.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            watermark_column: None,
        }
    }

    /// Creates a source definition whose event time is tracked through
    /// `watermark_column`.
    #[must_use]
    pub fn with_watermark(name: impl Into<String>, watermark_column: impl Into<String>) -> Self {
        let mut source = Self::new(name);
        source.watermark_column = Some(watermark_column.into());
        source
    }
}
75
/// A materialized view as seen by channel derivation.
///
/// Describes a continuous query by its name and the sources it reads.
#[derive(Debug, Clone)]
pub struct MvDefinition {
    /// MV name (e.g., "vwap", "max_price").
    pub name: String,
    /// Sources this MV reads from.
    pub source_refs: Vec<String>,
}

impl MvDefinition {
    /// Creates an MV definition reading from the given `source_refs`.
    #[must_use]
    pub fn new(name: impl Into<String>, source_refs: Vec<String>) -> Self {
        let name = name.into();
        Self { name, source_refs }
    }

    /// Creates an MV definition that reads from exactly one source.
    #[must_use]
    pub fn from_source(name: impl Into<String>, source: impl Into<String>) -> Self {
        // Convert `name` first so conversions run in the same order as the
        // struct-literal form.
        let name: String = name.into();
        Self::new(name, vec![source.into()])
    }
}
106
107/// Maps each source to SPSC (single consumer) or Broadcast (multi-consumer).
108#[must_use]
109pub fn derive_channel_types(
110    sources: &[SourceDefinition],
111    mvs: &[MvDefinition],
112) -> HashMap<String, DerivedChannelType> {
113    let consumer_counts = count_consumers_per_source(mvs);
114
115    sources
116        .iter()
117        .map(|source| {
118            let count = consumer_counts.get(&source.name).copied().unwrap_or(0);
119            let channel_type = if count <= 1 {
120                DerivedChannelType::Spsc
121            } else {
122                DerivedChannelType::Broadcast {
123                    consumer_count: count,
124                }
125            };
126            (source.name.clone(), channel_type)
127        })
128        .collect()
129}
130
131/// Counts how many MVs read from each source.
132fn count_consumers_per_source(mvs: &[MvDefinition]) -> HashMap<String, usize> {
133    let mut counts: HashMap<String, usize> = HashMap::with_capacity(mvs.len());
134
135    for mv in mvs {
136        for source_ref in &mv.source_refs {
137            *counts.entry(source_ref.clone()).or_insert(0) += 1;
138        }
139    }
140
141    counts
142}
143
144/// Analyzes a single MV to extract its source references.
145///
146/// This is a helper for parsing SQL queries to find referenced sources.
147/// In practice, this would integrate with the SQL parser to extract
148/// table references from FROM clauses.
149///
150/// # Arguments
151///
152/// * `mv_name` - The MV name
153/// * `source_tables` - Tables referenced in the query
154///
155/// # Returns
156///
157/// An `MvDefinition` with the extracted source references.
158#[must_use]
159pub fn analyze_mv_sources(mv_name: &str, source_tables: &[&str]) -> MvDefinition {
160    MvDefinition::new(
161        mv_name.to_string(),
162        source_tables.iter().map(|s| (*s).to_string()).collect(),
163    )
164}
165
166/// Channel derivation result with additional metadata.
167#[derive(Debug, Clone)]
168pub struct ChannelDerivationResult {
169    /// Derived channel types per source.
170    pub channel_types: HashMap<String, DerivedChannelType>,
171    /// Sources with no consumers (orphaned).
172    pub orphaned_sources: Vec<String>,
173    /// Total broadcast channels needed.
174    pub broadcast_count: usize,
175    /// Total SPSC channels needed.
176    pub spsc_count: usize,
177}
178
179/// Derives channel types with additional analysis metadata.
180///
181/// Returns a result that includes orphaned sources (sources with no consumers)
182/// and counts of each channel type.
183#[must_use]
184pub fn derive_channel_types_detailed(
185    sources: &[SourceDefinition],
186    mvs: &[MvDefinition],
187) -> ChannelDerivationResult {
188    let channel_types = derive_channel_types(sources, mvs);
189
190    let orphaned_sources: Vec<String> = channel_types
191        .iter()
192        .filter(|(_, ct)| matches!(ct, DerivedChannelType::Spsc))
193        .filter(|(name, _)| {
194            // Check if this source actually has any consumers
195            !mvs.iter().any(|mv| mv.source_refs.contains(*name))
196        })
197        .map(|(name, _)| name.clone())
198        .collect();
199
200    let broadcast_count = channel_types
201        .values()
202        .filter(|ct| ct.is_broadcast())
203        .count();
204
205    let spsc_count = channel_types.len() - broadcast_count;
206
207    ChannelDerivationResult {
208        channel_types,
209        orphaned_sources,
210        broadcast_count,
211        spsc_count,
212    }
213}
214
215// ===========================================================================
216// Tests
217// ===========================================================================
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds watermark-less source definitions from bare names.
    fn sources_named(names: &[&str]) -> Vec<SourceDefinition> {
        names.iter().map(|&n| SourceDefinition::new(n)).collect()
    }

    /// Builds single-source MVs from `(mv_name, source_name)` pairs.
    fn single_source_mvs(pairs: &[(&str, &str)]) -> Vec<MvDefinition> {
        pairs
            .iter()
            .map(|&(mv, src)| MvDefinition::from_source(mv, src))
            .collect()
    }

    /// Shorthand for the expected broadcast variant.
    fn broadcast(consumer_count: usize) -> DerivedChannelType {
        DerivedChannelType::Broadcast { consumer_count }
    }

    #[test]
    fn test_derive_single_consumer_spsc() {
        let derived = derive_channel_types(
            &sources_named(&["trades"]),
            &single_source_mvs(&[("vwap", "trades")]),
        );

        assert_eq!(derived.get("trades"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_multiple_consumers_broadcast() {
        let derived = derive_channel_types(
            &sources_named(&["trades"]),
            &single_source_mvs(&[("vwap", "trades"), ("max_price", "trades")]),
        );

        assert_eq!(derived.get("trades"), Some(&broadcast(2)));
    }

    #[test]
    fn test_derive_mixed_sources() {
        let derived = derive_channel_types(
            &sources_named(&["trades", "orders"]),
            &single_source_mvs(&[
                ("vwap", "trades"),
                ("max_price", "trades"),
                ("order_count", "orders"),
            ]),
        );

        // Two readers of `trades` force a broadcast; `orders` has just one.
        assert_eq!(derived.get("trades"), Some(&broadcast(2)));
        assert_eq!(derived.get("orders"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_no_consumers() {
        let derived = derive_channel_types(&sources_named(&["orphan"]), &[]);

        // A source nobody reads still defaults to SPSC.
        assert_eq!(derived.get("orphan"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_mv_with_multiple_sources() {
        let join = MvDefinition::new(
            "order_payments",
            vec!["orders".to_string(), "payments".to_string()],
        );

        let derived = derive_channel_types(&sources_named(&["orders", "payments"]), &[join]);

        // Each input of the join has exactly one consumer.
        for name in &["orders", "payments"] {
            assert_eq!(derived.get(*name), Some(&DerivedChannelType::Spsc));
        }
    }

    #[test]
    fn test_derived_channel_type_methods() {
        let spsc = DerivedChannelType::Spsc;
        assert!(!spsc.is_broadcast());
        assert_eq!(spsc.consumer_count(), 1);

        let fan_out = broadcast(3);
        assert!(fan_out.is_broadcast());
        assert_eq!(fan_out.consumer_count(), 3);
    }

    #[test]
    fn test_source_definition() {
        let plain = SourceDefinition::new("trades");
        assert_eq!(plain.name, "trades");
        assert_eq!(plain.watermark_column, None);

        let with_wm = SourceDefinition::with_watermark("trades", "event_time");
        assert_eq!(with_wm.name, "trades");
        assert_eq!(with_wm.watermark_column.as_deref(), Some("event_time"));
    }

    #[test]
    fn test_mv_definition() {
        let single = MvDefinition::from_source("vwap", "trades");
        assert_eq!(single.name, "vwap");
        assert_eq!(single.source_refs, vec!["trades"]);

        let join = MvDefinition::new(
            "join_result",
            vec!["orders".to_string(), "payments".to_string()],
        );
        assert_eq!(join.name, "join_result");
        assert_eq!(join.source_refs.len(), 2);
    }

    #[test]
    fn test_analyze_mv_sources() {
        let analyzed = analyze_mv_sources("my_mv", &["table1", "table2"]);
        assert_eq!(analyzed.name, "my_mv");
        assert_eq!(analyzed.source_refs, vec!["table1", "table2"]);
    }

    #[test]
    fn test_detailed_derivation() {
        let result = derive_channel_types_detailed(
            &sources_named(&["trades", "orders", "orphan"]),
            &single_source_mvs(&[
                ("vwap", "trades"),
                ("max_price", "trades"),
                ("order_count", "orders"),
            ]),
        );

        assert_eq!(result.broadcast_count, 1); // trades
        assert_eq!(result.spsc_count, 2); // orders + orphan
        assert!(result.orphaned_sources.iter().any(|s| s == "orphan"));
    }

    #[test]
    fn test_three_consumers() {
        let derived = derive_channel_types(
            &sources_named(&["events"]),
            &single_source_mvs(&[("mv1", "events"), ("mv2", "events"), ("mv3", "events")]),
        );

        assert_eq!(derived.get("events"), Some(&broadcast(3)));
    }
}