Skip to main content

laminar_db/
catalog.rs

1//! Source and sink catalog for tracking registered streaming objects.
2#![allow(clippy::disallowed_types)] // cold path
3
4use std::collections::HashMap;
5use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
6use std::sync::Arc;
7use std::time::Duration;
8
9use arrow::array::RecordBatch;
10use arrow::datatypes::SchemaRef;
11use parking_lot::RwLock;
12use tokio::sync::Notify;
13
14use laminar_core::streaming::{self, BackpressureStrategy, SourceConfig, WaitStrategy};
15
16/// Record type for Arrow-based streaming subscriptions.
17#[derive(Clone, Debug)]
18pub struct ArrowRecord {
19    pub(crate) batch: RecordBatch,
20}
21
22impl laminar_core::streaming::Record for ArrowRecord {
23    fn schema() -> SchemaRef {
24        // This is a placeholder; the actual schema is on the SourceEntry.
25        // ArrowRecord is only used as a type parameter; push_arrow bypasses this.
26        Arc::new(arrow::datatypes::Schema::empty())
27    }
28
29    fn to_record_batch(&self) -> RecordBatch {
30        self.batch.clone()
31    }
32}
33
34/// Bounded ring buffer for snapshot batches.
35///
36/// Concurrent `push()` calls each get a unique slot via atomic `fetch_add`.
37/// Per-slot mutex protects the actual read/write.
38struct SnapshotRing {
39    slots: Box<[parking_lot::Mutex<Option<RecordBatch>>]>,
40    tail: AtomicUsize,
41    capacity: usize,
42}
43
44impl SnapshotRing {
45    fn new(capacity: usize) -> Self {
46        let cap = capacity.max(1);
47        let slots: Vec<_> = (0..cap).map(|_| parking_lot::Mutex::new(None)).collect();
48        Self {
49            slots: slots.into_boxed_slice(),
50            tail: AtomicUsize::new(0),
51            capacity: cap,
52        }
53    }
54
55    fn push(&self, batch: RecordBatch) {
56        // fetch_add is atomic — concurrent pushers each get a unique slot.
57        let idx = self.tail.fetch_add(1, Ordering::Relaxed) % self.capacity;
58        *self.slots[idx].lock() = Some(batch);
59    }
60
61    fn snapshot(&self) -> Vec<RecordBatch> {
62        let tail = self.tail.load(Ordering::Acquire);
63        let count = tail.min(self.capacity);
64        // Read the most recent `count` slots, oldest first.
65        let start = if tail <= self.capacity {
66            0
67        } else {
68            tail % self.capacity
69        };
70        let mut result = Vec::with_capacity(count);
71        for i in 0..count {
72            let idx = (start + i) % self.capacity;
73            if let Some(batch) = self.slots[idx].lock().as_ref() {
74                result.push(batch.clone());
75            }
76        }
77        result
78    }
79}
80
81/// A registered source in the catalog.
82pub struct SourceEntry {
83    /// Source name.
84    pub name: String,
85    /// Arrow schema.
86    pub schema: SchemaRef,
87    /// Watermark column name, if configured.
88    pub watermark_column: Option<String>,
89    /// Maximum out-of-orderness for watermark generation.
90    pub max_out_of_orderness: Option<Duration>,
91    /// Whether this source uses `PROCTIME()` watermarks.
92    pub is_processing_time: std::sync::atomic::AtomicBool,
93    pub(crate) source: streaming::Source<ArrowRecord>,
94    pub(crate) sink: streaming::Sink<ArrowRecord>,
95    buffer: SnapshotRing,
96    /// Wakeup handle for `db.insert()` event-driven notification.
97    data_notify: Arc<Notify>,
98}
99
100impl SourceEntry {
101    /// Push a batch to both the channel and the snapshot ring.
102    pub(crate) fn push_and_buffer(
103        &self,
104        batch: RecordBatch,
105    ) -> Result<(), laminar_core::streaming::StreamingError> {
106        self.source.push_arrow(batch.clone())?;
107        self.buffer.push(batch);
108        self.data_notify.notify_one();
109        Ok(())
110    }
111
112    pub(crate) fn snapshot(&self) -> Vec<RecordBatch> {
113        self.buffer.snapshot()
114    }
115
116    pub(crate) fn data_notify(&self) -> Arc<Notify> {
117        Arc::clone(&self.data_notify)
118    }
119}
120
/// Catalog record for a registered sink.
pub(crate) struct SinkEntry {
    /// Name of the source/table that feeds this sink.
    pub(crate) input: String,
}
124
/// Catalog record for a registered query.
pub(crate) struct QueryEntry {
    /// Identifier assigned by the catalog when the query was registered.
    pub(crate) id: u64,
    /// SQL text the query was registered with.
    pub(crate) sql: String,
    /// `true` until the query is deactivated.
    pub(crate) active: bool,
}
130
131pub(crate) struct StreamEntry {
132    pub(crate) name: String,
133    pub(crate) source: streaming::Source<ArrowRecord>,
134    pub(crate) sink: streaming::Sink<ArrowRecord>,
135}
136
137/// Central registry of sources, sinks, streams, and queries.
138pub struct SourceCatalog {
139    sources: RwLock<HashMap<String, Arc<SourceEntry>>>,
140    sinks: RwLock<HashMap<String, SinkEntry>>,
141    streams: RwLock<HashMap<String, Arc<StreamEntry>>>,
142    queries: RwLock<HashMap<u64, QueryEntry>>,
143    next_query_id: AtomicU64,
144    default_buffer_size: usize,
145    default_backpressure: BackpressureStrategy,
146}
147
148impl SourceCatalog {
149    /// Create a catalog with the given defaults for new sources.
150    #[must_use]
151    pub fn new(buffer_size: usize, backpressure: BackpressureStrategy) -> Self {
152        Self {
153            sources: RwLock::new(HashMap::new()),
154            sinks: RwLock::new(HashMap::new()),
155            streams: RwLock::new(HashMap::new()),
156            queries: RwLock::new(HashMap::new()),
157            next_query_id: AtomicU64::new(1),
158            default_buffer_size: buffer_size,
159            default_backpressure: backpressure,
160        }
161    }
162
163    #[allow(clippy::too_many_arguments)]
164    pub(crate) fn register_source(
165        &self,
166        name: &str,
167        schema: SchemaRef,
168        watermark_column: Option<String>,
169        max_out_of_orderness: Option<Duration>,
170        buffer_size: Option<usize>,
171        backpressure: Option<BackpressureStrategy>,
172    ) -> Result<Arc<SourceEntry>, crate::DbError> {
173        let mut sources = self.sources.write();
174        if sources.contains_key(name) {
175            return Err(crate::DbError::SourceAlreadyExists(name.to_string()));
176        }
177
178        let buf_size = buffer_size.unwrap_or(self.default_buffer_size);
179        let bp = backpressure.unwrap_or(self.default_backpressure);
180
181        // Channel buffer is at least 1024 to avoid blocking on small snapshot rings.
182        let channel_buf = buf_size.max(1024);
183        let config = SourceConfig {
184            channel: streaming::ChannelConfig {
185                buffer_size: channel_buf,
186                backpressure: bp,
187                wait_strategy: WaitStrategy::SpinYield,
188                track_stats: false,
189            },
190            name: Some(name.to_string()),
191        };
192
193        let (source, sink) = streaming::create_with_config::<ArrowRecord>(config);
194
195        let entry = Arc::new(SourceEntry {
196            name: name.to_string(),
197            schema,
198            watermark_column,
199            max_out_of_orderness,
200            is_processing_time: std::sync::atomic::AtomicBool::new(false),
201            source,
202            sink,
203            buffer: SnapshotRing::new(buf_size),
204            data_notify: Arc::new(Notify::new()),
205        });
206
207        sources.insert(name.to_string(), Arc::clone(&entry));
208        Ok(entry)
209    }
210
211    pub(crate) fn register_source_or_replace(
212        &self,
213        name: &str,
214        schema: SchemaRef,
215        watermark_column: Option<String>,
216        max_out_of_orderness: Option<Duration>,
217        buffer_size: Option<usize>,
218        backpressure: Option<BackpressureStrategy>,
219    ) -> Arc<SourceEntry> {
220        // Remove existing if present
221        self.sources.write().remove(name);
222        // Safe to unwrap since we just removed any conflict
223        self.register_source(
224            name,
225            schema,
226            watermark_column,
227            max_out_of_orderness,
228            buffer_size,
229            backpressure,
230        )
231        .unwrap()
232    }
233
234    /// Look up a registered source by name.
235    pub fn get_source(&self, name: &str) -> Option<Arc<SourceEntry>> {
236        self.sources.read().get(name).cloned()
237    }
238
239    /// Returns `true` if the source existed.
240    pub fn drop_source(&self, name: &str) -> bool {
241        self.sources.write().remove(name).is_some()
242    }
243
244    pub(crate) fn register_sink(&self, name: &str, input: &str) -> Result<(), crate::DbError> {
245        let mut sinks = self.sinks.write();
246        if sinks.contains_key(name) {
247            return Err(crate::DbError::SinkAlreadyExists(name.to_string()));
248        }
249        sinks.insert(
250            name.to_string(),
251            SinkEntry {
252                input: input.to_string(),
253            },
254        );
255        Ok(())
256    }
257
258    /// Returns `true` if the sink existed.
259    pub fn drop_sink(&self, name: &str) -> bool {
260        self.sinks.write().remove(name).is_some()
261    }
262
263    pub(crate) fn register_stream(&self, name: &str) -> Result<(), crate::DbError> {
264        let mut streams = self.streams.write();
265        if streams.contains_key(name) {
266            return Err(crate::DbError::StreamAlreadyExists(name.to_string()));
267        }
268
269        let config = SourceConfig {
270            channel: streaming::ChannelConfig {
271                buffer_size: self.default_buffer_size,
272                backpressure: self.default_backpressure,
273                wait_strategy: WaitStrategy::SpinYield,
274                track_stats: false,
275            },
276            name: Some(name.to_string()),
277        };
278
279        let (source, sink) = streaming::create_with_config::<ArrowRecord>(config);
280
281        streams.insert(
282            name.to_string(),
283            Arc::new(StreamEntry {
284                name: name.to_string(),
285                source,
286                sink,
287            }),
288        );
289        Ok(())
290    }
291
292    pub(crate) fn get_stream_subscription(
293        &self,
294        name: &str,
295    ) -> Option<streaming::Subscription<ArrowRecord>> {
296        self.streams
297            .read()
298            .get(name)
299            .map(|entry| entry.sink.subscribe())
300    }
301
302    pub(crate) fn get_stream_entry(&self, name: &str) -> Option<Arc<StreamEntry>> {
303        self.streams.read().get(name).cloned()
304    }
305
306    pub(crate) fn get_stream_source(&self, name: &str) -> Option<streaming::Source<ArrowRecord>> {
307        self.streams
308            .read()
309            .get(name)
310            .map(|entry| entry.source.clone())
311    }
312
313    /// Returns `true` if the stream existed.
314    pub fn drop_stream(&self, name: &str) -> bool {
315        self.streams.write().remove(name).is_some()
316    }
317
318    /// All registered stream names.
319    pub fn list_streams(&self) -> Vec<String> {
320        self.streams.read().keys().cloned().collect()
321    }
322
323    /// All registered source names.
324    pub fn list_sources(&self) -> Vec<String> {
325        self.sources.read().keys().cloned().collect()
326    }
327
328    /// All registered sink names.
329    pub fn list_sinks(&self) -> Vec<String> {
330        self.sinks.read().keys().cloned().collect()
331    }
332
333    /// Input source/table name for a sink, if registered.
334    pub fn get_sink_input(&self, name: &str) -> Option<String> {
335        self.sinks.read().get(name).map(|e| e.input.clone())
336    }
337
338    pub(crate) fn register_query(&self, sql: &str) -> u64 {
339        let id = self.next_query_id.fetch_add(1, Ordering::Relaxed);
340        let mut queries = self.queries.write();
341        queries.insert(
342            id,
343            QueryEntry {
344                id,
345                sql: sql.to_string(),
346                active: true,
347            },
348        );
349        id
350    }
351
352    pub(crate) fn deactivate_query(&self, id: u64) -> bool {
353        // Cap retained deactivated queries so finished SELECTs can't accumulate
354        // unboundedly; over the cap, the oldest (lowest id) is dropped.
355        const MAX_INACTIVE_QUERIES: usize = 100;
356        let mut queries = self.queries.write();
357        if let Some(entry) = queries.get_mut(&id) {
358            let was_active = entry.active;
359            entry.active = false;
360            if was_active {
361                let inactive_count = queries.values().filter(|q| !q.active).count();
362                if inactive_count > MAX_INACTIVE_QUERIES {
363                    let oldest_inactive_id =
364                        queries.values().filter(|q| !q.active).map(|q| q.id).min();
365                    if let Some(oldest_id) = oldest_inactive_id {
366                        queries.remove(&oldest_id);
367                    }
368                }
369            }
370            true
371        } else {
372            false
373        }
374    }
375
376    pub(crate) fn list_queries(&self) -> Vec<(u64, String, bool)> {
377        self.queries
378            .read()
379            .values()
380            .map(|q| (q.id, q.sql.clone(), q.active))
381            .collect()
382    }
383
384    /// Schema for DESCRIBE queries.
385    pub fn describe_source(&self, name: &str) -> Option<SchemaRef> {
386        self.sources.read().get(name).map(|e| e.schema.clone())
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Two-column schema (`id: Int64`, `value: Float64`) shared by the tests.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Catalog with a default buffer size of 1024 and blocking backpressure.
    fn blocking_catalog() -> SourceCatalog {
        SourceCatalog::new(1024, BackpressureStrategy::Block)
    }

    /// Registers `name` with the test schema and no optional settings.
    fn add_source(
        catalog: &SourceCatalog,
        name: &str,
    ) -> Result<Arc<SourceEntry>, crate::DbError> {
        catalog.register_source(name, test_schema(), None, None, None, None)
    }

    /// Single-row batch whose columns match `test_schema()`.
    fn one_row_batch(schema: &SchemaRef, id: i64, value: f64) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(vec![id])),
                Arc::new(Float64Array::from(vec![value])),
            ],
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_register_source() {
        let catalog = blocking_catalog();
        assert!(add_source(&catalog, "test").is_ok());
        assert!(catalog.get_source("test").is_some());
    }

    #[tokio::test]
    async fn test_register_duplicate_source() {
        let catalog = blocking_catalog();
        add_source(&catalog, "test").unwrap();
        let second = add_source(&catalog, "test");
        assert!(matches!(
            second,
            Err(crate::DbError::SourceAlreadyExists(_))
        ));
    }

    #[tokio::test]
    async fn test_drop_source() {
        let catalog = blocking_catalog();
        add_source(&catalog, "test").unwrap();
        assert!(catalog.drop_source("test"));
        assert!(catalog.get_source("test").is_none());
    }

    #[tokio::test]
    async fn test_list_sources() {
        let catalog = blocking_catalog();
        for name in ["a", "b"] {
            add_source(&catalog, name).unwrap();
        }
        let mut names = catalog.list_sources();
        names.sort();
        assert_eq!(names, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn test_register_sink() {
        let catalog = blocking_catalog();
        assert!(catalog.register_sink("output", "events").is_ok());
        assert_eq!(catalog.list_sinks(), vec!["output"]);
    }

    #[tokio::test]
    async fn test_register_query() {
        let catalog = blocking_catalog();
        let id = catalog.register_query("SELECT * FROM events");
        assert_eq!(id, 1);
        let queries = catalog.list_queries();
        assert_eq!(queries.len(), 1);
        // A freshly registered query is active.
        assert!(queries[0].2);
    }

    #[tokio::test]
    async fn test_deactivate_query() {
        let catalog = blocking_catalog();
        let id = catalog.register_query("SELECT * FROM events");
        catalog.deactivate_query(id);
        let queries = catalog.list_queries();
        // The query is still listed, but as inactive.
        assert!(!queries[0].2);
    }

    #[tokio::test]
    async fn test_deactivate_query_limit() {
        let catalog = blocking_catalog();
        let ids: Vec<u64> = (0..105)
            .map(|i| catalog.register_query(&format!("SELECT * FROM events_{i}")))
            .collect();
        for id in &ids {
            catalog.deactivate_query(*id);
        }

        let queries = catalog.list_queries();
        assert_eq!(queries.len(), 100);

        let remaining_ids: std::collections::HashSet<u64> =
            queries.iter().map(|(id, _, _)| *id).collect();
        for id in 1..=5 {
            assert!(
                !remaining_ids.contains(&id),
                "Query {id} should have been evicted"
            );
        }
        for id in 6..=105 {
            assert!(
                remaining_ids.contains(&id),
                "Query {id} should be remaining"
            );
        }
    }

    #[tokio::test]
    async fn test_describe_source() {
        let catalog = blocking_catalog();
        add_source(&catalog, "test").unwrap();
        let described = catalog.describe_source("test");
        assert!(described.is_some());
        assert_eq!(described.unwrap().fields().len(), 2);
    }

    #[tokio::test]
    async fn test_or_replace() {
        let catalog = blocking_catalog();
        add_source(&catalog, "test").unwrap();
        let entry = catalog.register_source_or_replace(
            "test",
            test_schema(),
            Some("ts".into()),
            None,
            None,
            None,
        );
        assert_eq!(entry.watermark_column, Some("ts".to_string()));
    }

    #[tokio::test]
    async fn test_push_and_buffer_snapshot() {
        let catalog = blocking_catalog();
        let schema = test_schema();
        let entry = catalog
            .register_source("test", Arc::clone(&schema), None, None, None, None)
            .unwrap();

        entry
            .push_and_buffer(one_row_batch(&schema, 1, 1.5))
            .unwrap();

        let snap = entry.snapshot();
        assert_eq!(snap.len(), 1);
        assert_eq!(snap[0].num_rows(), 1);
    }

    #[tokio::test]
    async fn test_buffer_capacity_drops_oldest() {
        // Ring capacity is 2; the channel gets a larger buffer so pushes don't block.
        let catalog = SourceCatalog::new(2, BackpressureStrategy::DropOldest);
        let schema = test_schema();
        let entry = catalog
            .register_source("test", Arc::clone(&schema), None, None, None, None)
            .unwrap();

        let rows: [(i64, f64); 3] = [(0, 1.0), (1, 2.0), (2, 3.0)];
        for (id, value) in rows {
            entry
                .push_and_buffer(one_row_batch(&schema, id, value))
                .unwrap();
        }

        let snap = entry.snapshot();
        // With a capacity of 2, only the last two batches remain.
        assert_eq!(snap.len(), 2);
        let ids = snap[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        // Batch 0 was overwritten, so the oldest retained id is 1.
        assert_eq!(ids.value(0), 1);
    }
}
585}