1#[allow(clippy::disallowed_types)] use std::collections::HashMap;
5
/// Channel kind selected for a source, based on how many materialized views read it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivedChannelType {
    /// Single-producer, single-consumer channel.
    Spsc,

    /// Fan-out channel shared by several consumers.
    Broadcast {
        /// How many consumers the broadcast channel serves.
        consumer_count: usize,
    },
}

impl DerivedChannelType {
    /// Returns `true` only for the [`DerivedChannelType::Broadcast`] variant.
    #[must_use]
    pub fn is_broadcast(&self) -> bool {
        match self {
            Self::Spsc => false,
            Self::Broadcast { .. } => true,
        }
    }

    /// Number of consumers this channel serves.
    ///
    /// `Spsc` always reports `1`; `Broadcast` reports its stored count.
    /// Note: `Spsc` reports `1` even when it was derived for a source that
    /// nothing consumes.
    #[must_use]
    pub fn consumer_count(&self) -> usize {
        if let Self::Broadcast { consumer_count } = *self {
            consumer_count
        } else {
            1
        }
    }
}
44
/// A declared source that materialized views can read from.
#[derive(Debug, Clone)]
pub struct SourceDefinition {
    /// Source name; a view references this source by listing this name.
    pub name: String,
    /// Name of the watermark column, or `None` if none was given.
    pub watermark_column: Option<String>,
}

impl SourceDefinition {
    /// Creates a source named `name` with no watermark column.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            watermark_column: None,
        }
    }

    /// Creates a source named `name` whose watermark column is `watermark_column`.
    #[must_use]
    pub fn with_watermark(name: impl Into<String>, watermark_column: impl Into<String>) -> Self {
        let mut source = Self::new(name);
        source.watermark_column = Some(watermark_column.into());
        source
    }
}
75
/// A materialized view together with the names of the sources it reads.
#[derive(Debug, Clone)]
pub struct MvDefinition {
    /// Name of the materialized view.
    pub name: String,
    /// Names of the sources this view reads. During channel derivation each
    /// entry counts as one consumer of the named source.
    /// NOTE(review): a name listed twice here is counted twice — confirm intended.
    pub source_refs: Vec<String>,
}

impl MvDefinition {
    /// Creates a view named `name` that reads every source in `source_refs`.
    #[must_use]
    pub fn new(name: impl Into<String>, source_refs: Vec<String>) -> Self {
        let name = name.into();
        Self { name, source_refs }
    }

    /// Creates a view named `name` that reads exactly one source.
    #[must_use]
    pub fn from_source(name: impl Into<String>, source: impl Into<String>) -> Self {
        // Convert `name` first to keep the original conversion order.
        let name: String = name.into();
        Self::new(name, vec![source.into()])
    }
}
106
107#[must_use]
109pub fn derive_channel_types(
110 sources: &[SourceDefinition],
111 mvs: &[MvDefinition],
112) -> HashMap<String, DerivedChannelType> {
113 let consumer_counts = count_consumers_per_source(mvs);
114
115 sources
116 .iter()
117 .map(|source| {
118 let count = consumer_counts.get(&source.name).copied().unwrap_or(0);
119 let channel_type = if count <= 1 {
120 DerivedChannelType::Spsc
121 } else {
122 DerivedChannelType::Broadcast {
123 consumer_count: count,
124 }
125 };
126 (source.name.clone(), channel_type)
127 })
128 .collect()
129}
130
131fn count_consumers_per_source(mvs: &[MvDefinition]) -> HashMap<String, usize> {
133 let mut counts: HashMap<String, usize> = HashMap::with_capacity(mvs.len());
134
135 for mv in mvs {
136 for source_ref in &mv.source_refs {
137 *counts.entry(source_ref.clone()).or_insert(0) += 1;
138 }
139 }
140
141 counts
142}
143
144#[must_use]
159pub fn analyze_mv_sources(mv_name: &str, source_tables: &[&str]) -> MvDefinition {
160 MvDefinition::new(
161 mv_name.to_string(),
162 source_tables.iter().map(|s| (*s).to_string()).collect(),
163 )
164}
165
/// Output of [`derive_channel_types_detailed`]: the per-source channel map
/// plus summary statistics about it.
#[derive(Debug, Clone)]
pub struct ChannelDerivationResult {
    /// Channel type chosen for each declared source, keyed by source name.
    pub channel_types: HashMap<String, DerivedChannelType>,
    /// Declared sources that no materialized view references.
    pub orphaned_sources: Vec<String>,
    /// Number of entries in `channel_types` that are broadcast channels.
    pub broadcast_count: usize,
    /// Number of entries in `channel_types` that are SPSC channels
    /// (`channel_types.len() - broadcast_count`).
    pub spsc_count: usize,
}
178
179#[must_use]
184pub fn derive_channel_types_detailed(
185 sources: &[SourceDefinition],
186 mvs: &[MvDefinition],
187) -> ChannelDerivationResult {
188 let channel_types = derive_channel_types(sources, mvs);
189
190 let orphaned_sources: Vec<String> = channel_types
191 .iter()
192 .filter(|(_, ct)| matches!(ct, DerivedChannelType::Spsc))
193 .filter(|(name, _)| {
194 !mvs.iter().any(|mv| mv.source_refs.contains(*name))
196 })
197 .map(|(name, _)| name.clone())
198 .collect();
199
200 let broadcast_count = channel_types
201 .values()
202 .filter(|ct| ct.is_broadcast())
203 .count();
204
205 let spsc_count = channel_types.len() - broadcast_count;
206
207 ChannelDerivationResult {
208 channel_types,
209 orphaned_sources,
210 broadcast_count,
211 spsc_count,
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_derive_single_consumer_spsc() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![MvDefinition::from_source("vwap", "trades")];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_multiple_consumers_broadcast() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );
    }

    #[test]
    fn test_derive_mixed_sources() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("trades"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 2 })
        );
        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_no_consumers() {
        let sources = vec![SourceDefinition::new("orphan")];
        let mvs: Vec<MvDefinition> = vec![];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.get("orphan"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_derive_mv_with_multiple_sources() {
        let sources = vec![
            SourceDefinition::new("orders"),
            SourceDefinition::new("payments"),
        ];
        let mvs = vec![MvDefinition::new(
            "order_payments",
            vec!["orders".to_string(), "payments".to_string()],
        )];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.get("orders"), Some(&DerivedChannelType::Spsc));
        assert_eq!(
            channel_types.get("payments"),
            Some(&DerivedChannelType::Spsc)
        );
    }

    #[test]
    fn test_derived_channel_type_methods() {
        let spsc = DerivedChannelType::Spsc;
        assert!(!spsc.is_broadcast());
        assert_eq!(spsc.consumer_count(), 1);

        let broadcast = DerivedChannelType::Broadcast { consumer_count: 3 };
        assert!(broadcast.is_broadcast());
        assert_eq!(broadcast.consumer_count(), 3);
    }

    #[test]
    fn test_source_definition() {
        let source = SourceDefinition::new("trades");
        assert_eq!(source.name, "trades");
        assert!(source.watermark_column.is_none());

        let source_wm = SourceDefinition::with_watermark("trades", "event_time");
        assert_eq!(source_wm.name, "trades");
        assert_eq!(source_wm.watermark_column, Some("event_time".to_string()));
    }

    #[test]
    fn test_mv_definition() {
        let mv = MvDefinition::from_source("vwap", "trades");
        assert_eq!(mv.name, "vwap");
        assert_eq!(mv.source_refs, vec!["trades"]);

        let mv_multi = MvDefinition::new(
            "join_result",
            vec!["orders".to_string(), "payments".to_string()],
        );
        assert_eq!(mv_multi.name, "join_result");
        assert_eq!(mv_multi.source_refs.len(), 2);
    }

    #[test]
    fn test_analyze_mv_sources() {
        let mv = analyze_mv_sources("my_mv", &["table1", "table2"]);
        assert_eq!(mv.name, "my_mv");
        assert_eq!(mv.source_refs, vec!["table1", "table2"]);
    }

    #[test]
    fn test_detailed_derivation() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("orders"),
            SourceDefinition::new("orphan"),
        ];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("max_price", "trades"),
            MvDefinition::from_source("order_count", "orders"),
        ];

        let result = derive_channel_types_detailed(&sources, &mvs);

        assert_eq!(result.channel_types.len(), 3);
        assert_eq!(result.broadcast_count, 1);
        assert_eq!(result.spsc_count, 2);
        // "orders" is SPSC but consumed, so it must NOT be reported as orphaned.
        assert_eq!(result.orphaned_sources, vec!["orphan".to_string()]);
    }

    #[test]
    fn test_detailed_multiple_orphans() {
        let sources = vec![
            SourceDefinition::new("b_orphan"),
            SourceDefinition::new("used"),
            SourceDefinition::new("a_orphan"),
        ];
        let mvs = vec![MvDefinition::from_source("mv", "used")];

        let result = derive_channel_types_detailed(&sources, &mvs);

        // Membership is the contract checked here, so compare order-insensitively.
        let mut orphans = result.orphaned_sources.clone();
        orphans.sort();
        assert_eq!(
            orphans,
            vec!["a_orphan".to_string(), "b_orphan".to_string()]
        );
        assert_eq!(result.broadcast_count, 0);
        assert_eq!(result.spsc_count, 3);
    }

    #[test]
    fn test_undeclared_source_refs_are_ignored() {
        let sources = vec![SourceDefinition::new("trades")];
        let mvs = vec![
            MvDefinition::from_source("vwap", "trades"),
            MvDefinition::from_source("ghost_a", "ghost"),
            MvDefinition::from_source("ghost_b", "ghost"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(channel_types.len(), 1);
        assert!(!channel_types.contains_key("ghost"));
        assert_eq!(channel_types.get("trades"), Some(&DerivedChannelType::Spsc));
    }

    #[test]
    fn test_duplicate_source_declarations_collapse() {
        let sources = vec![
            SourceDefinition::new("trades"),
            SourceDefinition::new("trades"),
        ];
        let mvs: Vec<MvDefinition> = vec![];

        let result = derive_channel_types_detailed(&sources, &mvs);

        assert_eq!(result.channel_types.len(), 1);
        assert_eq!(result.spsc_count, 1);
        assert_eq!(result.orphaned_sources, vec!["trades".to_string()]);
    }

    #[test]
    fn test_three_consumers() {
        let sources = vec![SourceDefinition::new("events")];
        let mvs = vec![
            MvDefinition::from_source("mv1", "events"),
            MvDefinition::from_source("mv2", "events"),
            MvDefinition::from_source("mv3", "events"),
        ];

        let channel_types = derive_channel_types(&sources, &mvs);

        assert_eq!(
            channel_types.get("events"),
            Some(&DerivedChannelType::Broadcast { consumer_count: 3 })
        );
    }
}