// laminar_db/catalog.rs
1//! Source and sink catalog for tracking registered streaming objects.
2#![allow(clippy::disallowed_types)] // cold path
3
4use std::collections::HashMap;
5use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
6use std::sync::Arc;
7use std::time::Duration;
8
9use arrow::array::RecordBatch;
10use arrow::datatypes::SchemaRef;
11use parking_lot::RwLock;
12use tokio::sync::Notify;
13
14use laminar_core::streaming::{self, BackpressureStrategy, SourceConfig, WaitStrategy};
15
/// Record type for Arrow-based streaming subscriptions.
///
/// Thin wrapper around an Arrow [`RecordBatch`] so batches can flow through
/// the generic `laminar_core::streaming` channel machinery.
#[derive(Clone, Debug)]
pub struct ArrowRecord {
    // The wrapped Arrow batch.
    pub(crate) batch: RecordBatch,
}
21
22impl laminar_core::streaming::Record for ArrowRecord {
23    fn schema() -> SchemaRef {
24        // This is a placeholder; the actual schema is on the SourceEntry.
25        // ArrowRecord is only used as a type parameter; push_arrow bypasses this.
26        Arc::new(arrow::datatypes::Schema::empty())
27    }
28
29    fn to_record_batch(&self) -> RecordBatch {
30        self.batch.clone()
31    }
32}
33
/// Bounded ring buffer for snapshot batches.
///
/// Concurrent `push()` calls each get a unique slot via atomic `fetch_add`.
/// Per-slot mutex protects the actual read/write.
struct SnapshotRing {
    // One mutex-guarded slot per ring position; `None` until first written.
    slots: Box<[parking_lot::Mutex<Option<RecordBatch>>]>,
    // Monotonic push counter; a push lands in slot `tail % capacity`.
    tail: AtomicUsize,
    // Number of slots; always >= 1 (enforced in `new`).
    capacity: usize,
}
43
44impl SnapshotRing {
45    fn new(capacity: usize) -> Self {
46        let cap = capacity.max(1);
47        let slots: Vec<_> = (0..cap).map(|_| parking_lot::Mutex::new(None)).collect();
48        Self {
49            slots: slots.into_boxed_slice(),
50            tail: AtomicUsize::new(0),
51            capacity: cap,
52        }
53    }
54
55    fn push(&self, batch: RecordBatch) {
56        // fetch_add is atomic — concurrent pushers each get a unique slot.
57        let idx = self.tail.fetch_add(1, Ordering::Relaxed) % self.capacity;
58        *self.slots[idx].lock() = Some(batch);
59    }
60
61    fn snapshot(&self) -> Vec<RecordBatch> {
62        let tail = self.tail.load(Ordering::Acquire);
63        let count = tail.min(self.capacity);
64        // Read the most recent `count` slots, oldest first.
65        let start = if tail <= self.capacity {
66            0
67        } else {
68            tail % self.capacity
69        };
70        let mut result = Vec::with_capacity(count);
71        for i in 0..count {
72            let idx = (start + i) % self.capacity;
73            if let Some(batch) = self.slots[idx].lock().as_ref() {
74                result.push(batch.clone());
75            }
76        }
77        result
78    }
79}
80
/// A registered source in the catalog.
pub struct SourceEntry {
    /// Source name.
    pub name: String,
    /// Arrow schema.
    pub schema: SchemaRef,
    /// Watermark column name, if configured.
    pub watermark_column: Option<String>,
    /// Maximum out-of-orderness for watermark generation.
    pub max_out_of_orderness: Option<Duration>,
    /// Whether this source uses `PROCTIME()` watermarks.
    pub is_processing_time: std::sync::atomic::AtomicBool,
    // Producer half of the streaming channel; `push_and_buffer` writes here.
    pub(crate) source: streaming::Source<ArrowRecord>,
    // Consumer half of the streaming channel.
    pub(crate) sink: streaming::Sink<ArrowRecord>,
    // Ring of recent batches served by `snapshot()`.
    buffer: SnapshotRing,
    /// Wakeup handle for `db.insert()` event-driven notification.
    data_notify: Arc<Notify>,
}
99
100impl SourceEntry {
101    /// Push a batch to both the channel and the snapshot ring.
102    pub(crate) fn push_and_buffer(
103        &self,
104        batch: RecordBatch,
105    ) -> Result<(), laminar_core::streaming::StreamingError> {
106        self.source.push_arrow(batch.clone())?;
107        self.buffer.push(batch);
108        self.data_notify.notify_one();
109        Ok(())
110    }
111
112    pub(crate) fn snapshot(&self) -> Vec<RecordBatch> {
113        self.buffer.snapshot()
114    }
115
116    pub(crate) fn data_notify(&self) -> Arc<Notify> {
117        Arc::clone(&self.data_notify)
118    }
119}
120
/// A registered sink: records which source/table feeds it.
pub(crate) struct SinkEntry {
    // Name of the input source/table this sink reads from.
    pub(crate) input: String,
}
124
/// A registered continuous query.
pub(crate) struct QueryEntry {
    // Catalog-assigned id (monotonic, starting at 1).
    pub(crate) id: u64,
    // Original SQL text as submitted.
    pub(crate) sql: String,
    // Set to `false` by `deactivate_query`.
    pub(crate) active: bool,
}
130
/// A registered named stream with its channel endpoints.
pub(crate) struct StreamEntry {
    // Stream name (same as the catalog key).
    pub(crate) name: String,
    // Producer half of the stream's channel.
    pub(crate) source: streaming::Source<ArrowRecord>,
    // Consumer half; `get_stream_subscription` subscribes here.
    pub(crate) sink: streaming::Sink<ArrowRecord>,
}
136
/// Central registry of sources, sinks, streams, and queries.
pub struct SourceCatalog {
    // Name -> source entry; Arc so lookups can outlive the lock.
    sources: RwLock<HashMap<String, Arc<SourceEntry>>>,
    // Name -> sink metadata.
    sinks: RwLock<HashMap<String, SinkEntry>>,
    // Name -> stream entry.
    streams: RwLock<HashMap<String, Arc<StreamEntry>>>,
    // Query id -> query metadata.
    queries: RwLock<HashMap<u64, QueryEntry>>,
    // Next id handed out by `register_query` (starts at 1).
    next_query_id: AtomicU64,
    // Defaults applied when a source/stream doesn't specify its own.
    default_buffer_size: usize,
    default_backpressure: BackpressureStrategy,
}
147
148impl SourceCatalog {
149    /// Create a catalog with the given defaults for new sources.
150    #[must_use]
151    pub fn new(buffer_size: usize, backpressure: BackpressureStrategy) -> Self {
152        Self {
153            sources: RwLock::new(HashMap::new()),
154            sinks: RwLock::new(HashMap::new()),
155            streams: RwLock::new(HashMap::new()),
156            queries: RwLock::new(HashMap::new()),
157            next_query_id: AtomicU64::new(1),
158            default_buffer_size: buffer_size,
159            default_backpressure: backpressure,
160        }
161    }
162
163    #[allow(clippy::too_many_arguments)]
164    pub(crate) fn register_source(
165        &self,
166        name: &str,
167        schema: SchemaRef,
168        watermark_column: Option<String>,
169        max_out_of_orderness: Option<Duration>,
170        buffer_size: Option<usize>,
171        backpressure: Option<BackpressureStrategy>,
172    ) -> Result<Arc<SourceEntry>, crate::DbError> {
173        let mut sources = self.sources.write();
174        if sources.contains_key(name) {
175            return Err(crate::DbError::SourceAlreadyExists(name.to_string()));
176        }
177
178        let buf_size = buffer_size.unwrap_or(self.default_buffer_size);
179        let bp = backpressure.unwrap_or(self.default_backpressure);
180
181        // Channel buffer is at least 1024 to avoid blocking on small snapshot rings.
182        let channel_buf = buf_size.max(1024);
183        let config = SourceConfig {
184            channel: streaming::ChannelConfig {
185                buffer_size: channel_buf,
186                backpressure: bp,
187                wait_strategy: WaitStrategy::SpinYield,
188                track_stats: false,
189            },
190            name: Some(name.to_string()),
191        };
192
193        let (source, sink) = streaming::create_with_config::<ArrowRecord>(config);
194
195        let entry = Arc::new(SourceEntry {
196            name: name.to_string(),
197            schema,
198            watermark_column,
199            max_out_of_orderness,
200            is_processing_time: std::sync::atomic::AtomicBool::new(false),
201            source,
202            sink,
203            buffer: SnapshotRing::new(buf_size),
204            data_notify: Arc::new(Notify::new()),
205        });
206
207        sources.insert(name.to_string(), Arc::clone(&entry));
208        Ok(entry)
209    }
210
211    pub(crate) fn register_source_or_replace(
212        &self,
213        name: &str,
214        schema: SchemaRef,
215        watermark_column: Option<String>,
216        max_out_of_orderness: Option<Duration>,
217        buffer_size: Option<usize>,
218        backpressure: Option<BackpressureStrategy>,
219    ) -> Arc<SourceEntry> {
220        // Remove existing if present
221        self.sources.write().remove(name);
222        // Safe to unwrap since we just removed any conflict
223        self.register_source(
224            name,
225            schema,
226            watermark_column,
227            max_out_of_orderness,
228            buffer_size,
229            backpressure,
230        )
231        .unwrap()
232    }
233
234    /// Look up a registered source by name.
235    pub fn get_source(&self, name: &str) -> Option<Arc<SourceEntry>> {
236        self.sources.read().get(name).cloned()
237    }
238
239    /// Returns `true` if the source existed.
240    pub fn drop_source(&self, name: &str) -> bool {
241        self.sources.write().remove(name).is_some()
242    }
243
244    pub(crate) fn register_sink(&self, name: &str, input: &str) -> Result<(), crate::DbError> {
245        let mut sinks = self.sinks.write();
246        if sinks.contains_key(name) {
247            return Err(crate::DbError::SinkAlreadyExists(name.to_string()));
248        }
249        sinks.insert(
250            name.to_string(),
251            SinkEntry {
252                input: input.to_string(),
253            },
254        );
255        Ok(())
256    }
257
258    /// Returns `true` if the sink existed.
259    pub fn drop_sink(&self, name: &str) -> bool {
260        self.sinks.write().remove(name).is_some()
261    }
262
263    pub(crate) fn register_stream(&self, name: &str) -> Result<(), crate::DbError> {
264        let mut streams = self.streams.write();
265        if streams.contains_key(name) {
266            return Err(crate::DbError::StreamAlreadyExists(name.to_string()));
267        }
268
269        let config = SourceConfig {
270            channel: streaming::ChannelConfig {
271                buffer_size: self.default_buffer_size,
272                backpressure: self.default_backpressure,
273                wait_strategy: WaitStrategy::SpinYield,
274                track_stats: false,
275            },
276            name: Some(name.to_string()),
277        };
278
279        let (source, sink) = streaming::create_with_config::<ArrowRecord>(config);
280
281        streams.insert(
282            name.to_string(),
283            Arc::new(StreamEntry {
284                name: name.to_string(),
285                source,
286                sink,
287            }),
288        );
289        Ok(())
290    }
291
292    pub(crate) fn get_stream_subscription(
293        &self,
294        name: &str,
295    ) -> Option<streaming::Subscription<ArrowRecord>> {
296        self.streams
297            .read()
298            .get(name)
299            .map(|entry| entry.sink.subscribe())
300    }
301
302    pub(crate) fn get_stream_entry(&self, name: &str) -> Option<Arc<StreamEntry>> {
303        self.streams.read().get(name).cloned()
304    }
305
306    pub(crate) fn get_stream_source(&self, name: &str) -> Option<streaming::Source<ArrowRecord>> {
307        self.streams
308            .read()
309            .get(name)
310            .map(|entry| entry.source.clone())
311    }
312
313    /// Returns `true` if the stream existed.
314    pub fn drop_stream(&self, name: &str) -> bool {
315        self.streams.write().remove(name).is_some()
316    }
317
318    /// All registered stream names.
319    pub fn list_streams(&self) -> Vec<String> {
320        self.streams.read().keys().cloned().collect()
321    }
322
323    /// All registered source names.
324    pub fn list_sources(&self) -> Vec<String> {
325        self.sources.read().keys().cloned().collect()
326    }
327
328    /// All registered sink names.
329    pub fn list_sinks(&self) -> Vec<String> {
330        self.sinks.read().keys().cloned().collect()
331    }
332
333    /// Input source/table name for a sink, if registered.
334    pub fn get_sink_input(&self, name: &str) -> Option<String> {
335        self.sinks.read().get(name).map(|e| e.input.clone())
336    }
337
338    pub(crate) fn register_query(&self, sql: &str) -> u64 {
339        let id = self.next_query_id.fetch_add(1, Ordering::Relaxed);
340        let mut queries = self.queries.write();
341        queries.insert(
342            id,
343            QueryEntry {
344                id,
345                sql: sql.to_string(),
346                active: true,
347            },
348        );
349        id
350    }
351
352    pub(crate) fn deactivate_query(&self, id: u64) -> bool {
353        if let Some(entry) = self.queries.write().get_mut(&id) {
354            entry.active = false;
355            true
356        } else {
357            false
358        }
359    }
360
361    pub(crate) fn list_queries(&self) -> Vec<(u64, String, bool)> {
362        self.queries
363            .read()
364            .values()
365            .map(|q| (q.id, q.sql.clone(), q.active))
366            .collect()
367    }
368
369    /// Schema for DESCRIBE queries.
370    pub fn describe_source(&self, name: &str) -> Option<SchemaRef> {
371        self.sources.read().get(name).map(|e| e.schema.clone())
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};

    // Two-column (id: Int64, value: Float64) schema shared by the tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, false),
        ]))
    }

    // Registering a fresh source succeeds and makes it retrievable by name.
    #[tokio::test]
    async fn test_register_source() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        let result = catalog.register_source("test", test_schema(), None, None, None, None);
        assert!(result.is_ok());
        assert!(catalog.get_source("test").is_some());
    }

    // A second registration under the same name fails with SourceAlreadyExists.
    #[tokio::test]
    async fn test_register_duplicate_source() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        catalog
            .register_source("test", test_schema(), None, None, None, None)
            .unwrap();
        let result = catalog.register_source("test", test_schema(), None, None, None, None);
        assert!(matches!(
            result,
            Err(crate::DbError::SourceAlreadyExists(_))
        ));
    }

    // Dropping an existing source returns true and removes the entry.
    #[tokio::test]
    async fn test_drop_source() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        catalog
            .register_source("test", test_schema(), None, None, None, None)
            .unwrap();
        assert!(catalog.drop_source("test"));
        assert!(catalog.get_source("test").is_none());
    }

    // list_sources returns every registered name (order is unspecified, so sort).
    #[tokio::test]
    async fn test_list_sources() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        catalog
            .register_source("a", test_schema(), None, None, None, None)
            .unwrap();
        catalog
            .register_source("b", test_schema(), None, None, None, None)
            .unwrap();
        let mut names = catalog.list_sources();
        names.sort();
        assert_eq!(names, vec!["a", "b"]);
    }

    // Sinks register under their own namespace and show up in list_sinks.
    #[tokio::test]
    async fn test_register_sink() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        assert!(catalog.register_sink("output", "events").is_ok());
        assert_eq!(catalog.list_sinks(), vec!["output"]);
    }

    // Query ids start at 1 and new queries are listed as active.
    #[tokio::test]
    async fn test_register_query() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        let id = catalog.register_query("SELECT * FROM events");
        assert_eq!(id, 1);
        let queries = catalog.list_queries();
        assert_eq!(queries.len(), 1);
        assert!(queries[0].2); // active
    }

    // deactivate_query flips the active flag without removing the entry.
    #[tokio::test]
    async fn test_deactivate_query() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        let id = catalog.register_query("SELECT * FROM events");
        catalog.deactivate_query(id);
        let queries = catalog.list_queries();
        assert!(!queries[0].2); // inactive
    }

    // describe_source exposes the registered schema for DESCRIBE queries.
    #[tokio::test]
    async fn test_describe_source() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        let schema = test_schema();
        catalog
            .register_source("test", schema.clone(), None, None, None, None)
            .unwrap();
        let result = catalog.describe_source("test");
        assert!(result.is_some());
        assert_eq!(result.unwrap().fields().len(), 2);
    }

    // register_source_or_replace swaps the entry even when the name is taken.
    #[tokio::test]
    async fn test_or_replace() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        catalog
            .register_source("test", test_schema(), None, None, None, None)
            .unwrap();
        let entry = catalog.register_source_or_replace(
            "test",
            test_schema(),
            Some("ts".into()),
            None,
            None,
            None,
        );
        assert_eq!(entry.watermark_column, Some("ts".to_string()));
    }

    // A pushed batch is visible via the snapshot ring.
    #[tokio::test]
    async fn test_push_and_buffer_snapshot() {
        let catalog = SourceCatalog::new(1024, BackpressureStrategy::Block);
        let schema = test_schema();
        let entry = catalog
            .register_source("test", schema.clone(), None, None, None, None)
            .unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(arrow::array::Int64Array::from(vec![1])),
                Arc::new(arrow::array::Float64Array::from(vec![1.5])),
            ],
        )
        .unwrap();

        entry.push_and_buffer(batch).unwrap();
        let snap = entry.snapshot();
        assert_eq!(snap.len(), 1);
        assert_eq!(snap[0].num_rows(), 1);
    }

    // When more batches arrive than the ring holds, the oldest is overwritten.
    #[tokio::test]
    async fn test_buffer_capacity_drops_oldest() {
        // SnapshotRing capacity=2; channel gets a larger buffer so pushes don't block.
        let catalog = SourceCatalog::new(2, BackpressureStrategy::DropOldest);
        let schema = test_schema();
        let entry = catalog
            .register_source("test", schema.clone(), None, None, None, None)
            .unwrap();

        let values: [(i64, f64); 3] = [(0, 1.0), (1, 2.0), (2, 3.0)];
        for (id, val) in values {
            let batch = RecordBatch::try_new(
                schema.clone(),
                vec![
                    Arc::new(arrow::array::Int64Array::from(vec![id])),
                    Arc::new(arrow::array::Float64Array::from(vec![val])),
                ],
            )
            .unwrap();
            entry.push_and_buffer(batch).unwrap();
        }

        let snap = entry.snapshot();
        // SnapshotRing capacity=2, so only the last 2 batches remain
        assert_eq!(snap.len(), 2);
        let col = snap[0]
            .column(0)
            .as_any()
            .downcast_ref::<arrow::array::Int64Array>()
            .unwrap();
        assert_eq!(col.value(0), 1); // batch 0 was dropped
    }
}