Skip to main content

laminar_core/streaming/
checkpoint.rs

1//! Streaming checkpoint support.
2//!
3//! Provides optional, zero-overhead checkpointing for the streaming API.
4//! When disabled (the default), no runtime cost is incurred. When enabled,
5//! captures source sequences, watermarks, and persists checkpoint snapshots.
6//!
7//! ## Architecture
8//!
9//! ```text
10//! Ring 0 (Hot Path): Source.push() -> increment sequence (AtomicU64 Relaxed ~1ns)
11//! Ring 1 (Background): StreamCheckpointManager.trigger() -> capture atomics -> store
12//! Ring 2 (Control):    LaminarDB.checkpoint() -> manual trigger
13//! ```
14
15#![allow(clippy::disallowed_types)] // checkpoint serialization: Ring 1/Ring 2 path, not hot path
16
17use std::collections::HashMap;
18use std::fmt;
19use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
20use std::sync::Arc;
21
22// Configuration
23
/// Durability mode for the checkpoint write-ahead log.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WalMode {
    /// WAL writes are issued asynchronously: faster, but the most recent
    /// entries may be lost if the process crashes.
    Async,
    /// WAL writes are issued synchronously: slower, but durable.
    Sync,
}
32
33/// Configuration for streaming checkpoints.
34///
35/// All fields default to `None`/disabled. Checkpointing is opt-in.
36#[derive(Debug, Clone)]
37pub struct StreamCheckpointConfig {
38    /// Checkpoint interval in milliseconds. `None` = manual only.
39    pub interval_ms: Option<u64>,
40    /// WAL mode. Requires `data_dir` to be set.
41    pub wal_mode: Option<WalMode>,
42    /// Directory for persisting checkpoints/WAL. `None` = in-memory only.
43    pub data_dir: Option<std::path::PathBuf>,
44    /// Changelog buffer capacity. `None` = no changelog buffer.
45    pub changelog_capacity: Option<usize>,
46    /// Maximum number of retained checkpoints. `None` = unlimited.
47    pub max_retained: Option<usize>,
48    /// Overflow policy for the changelog buffer.
49    pub overflow_policy: OverflowPolicy,
50}
51
52impl Default for StreamCheckpointConfig {
53    fn default() -> Self {
54        Self {
55            interval_ms: None,
56            wal_mode: None,
57            data_dir: None,
58            changelog_capacity: None,
59            max_retained: None,
60            overflow_policy: OverflowPolicy::DropNew,
61        }
62    }
63}
64
65impl StreamCheckpointConfig {
66    /// Validates the configuration, returning an error if invalid.
67    ///
68    /// # Errors
69    ///
70    /// Returns `CheckpointError::InvalidConfig` if WAL mode is set without
71    /// `data_dir`, or if `changelog_capacity` is zero.
72    pub fn validate(&self) -> Result<(), CheckpointError> {
73        if self.wal_mode.is_some() && self.data_dir.is_none() {
74            return Err(CheckpointError::InvalidConfig(
75                "WAL mode requires data_dir to be set".into(),
76            ));
77        }
78        if let Some(cap) = self.changelog_capacity {
79            if cap == 0 {
80                return Err(CheckpointError::InvalidConfig(
81                    "changelog_capacity must be > 0".into(),
82                ));
83            }
84        }
85        Ok(())
86    }
87}
88
89// Errors
90
/// Failure modes reported by checkpoint operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CheckpointError {
    /// Checkpointing is turned off.
    Disabled,
    /// The operation needs a data directory, but none is configured.
    DataDirRequired,
    /// WAL mode was requested without checkpointing enabled.
    WalRequiresCheckpoint,
    /// There is no checkpoint to restore from.
    NoCheckpoint,
    /// The operation did not finish in time.
    Timeout,
    /// The configuration is invalid; the payload says why.
    InvalidConfig(String),
    /// An I/O or decoding failure. The cause is kept as a `String` so the
    /// type can stay `Clone` + `PartialEq`.
    IoError(String),
}

impl fmt::Display for CheckpointError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants carrying a detail message are formatted directly; the rest
        // map to a fixed description.
        let text = match self {
            Self::InvalidConfig(msg) => {
                return write!(f, "invalid checkpoint config: {msg}");
            }
            Self::IoError(msg) => return write!(f, "checkpoint I/O error: {msg}"),
            Self::Disabled => "checkpointing is disabled",
            Self::DataDirRequired => "data directory is required",
            Self::WalRequiresCheckpoint => "WAL mode requires checkpointing",
            Self::NoCheckpoint => "no checkpoint available",
            Self::Timeout => "checkpoint operation timed out",
        };
        f.write_str(text)
    }
}

impl std::error::Error for CheckpointError {}
129
130// Overflow policy
131
/// What the changelog buffer does with a push that arrives while it is full.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OverflowPolicy {
    /// Reject the incoming entry and keep the buffered ones.
    DropNew,
    /// Evict the oldest buffered entry to make room for the incoming one.
    OverwriteOldest,
}
140
141// Changelog entry (24 bytes, repr(C))
142
/// Kind of event recorded in a changelog entry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum StreamChangeOp {
    /// A record was pushed.
    Push = 0,
    /// A watermark was emitted.
    Watermark = 1,
    /// A checkpoint barrier.
    Barrier = 2,
}

impl StreamChangeOp {
    /// Every variant, positioned so that `ALL[d]` has discriminant `d`.
    const ALL: [Self; 3] = [Self::Push, Self::Watermark, Self::Barrier];

    /// Maps a raw discriminant back to its variant; unknown bytes give `None`.
    fn from_u8(v: u8) -> Option<Self> {
        Self::ALL.get(usize::from(v)).copied()
    }
}
165
166/// A single changelog entry — fixed 24 bytes, no heap allocation.
167///
168/// Layout (repr(C)):
169/// ```text
170/// [source_id: u16][op: u8][padding: u8][reserved: u32][sequence: u64][watermark: i64]
171/// ```
172#[derive(Debug, Clone, Copy, PartialEq, Eq)]
173#[repr(C)]
174pub struct StreamChangelogEntry {
175    /// Source identifier (compact).
176    pub source_id: u16,
177    /// Operation type.
178    pub op: u8,
179    /// Padding for alignment.
180    _padding: u8,
181    /// Reserved for future use.
182    _reserved: u32,
183    /// Sequence number at time of operation.
184    pub sequence: u64,
185    /// Watermark value at time of operation.
186    pub watermark: i64,
187}
188
189impl StreamChangelogEntry {
190    /// Creates a new changelog entry.
191    #[must_use]
192    pub fn new(source_id: u16, op: StreamChangeOp, sequence: u64, watermark: i64) -> Self {
193        Self {
194            source_id,
195            op: op as u8,
196            _padding: 0,
197            _reserved: 0,
198            sequence,
199            watermark,
200        }
201    }
202
203    /// Returns the operation type.
204    #[must_use]
205    pub fn op_type(&self) -> Option<StreamChangeOp> {
206        StreamChangeOp::from_u8(self.op)
207    }
208}
209
210// Changelog buffer (pre-allocated ring buffer, zero-alloc after init)
211
212/// A pre-allocated ring buffer for changelog entries.
213///
214/// Uses a simple write/read index scheme. Not thread-safe on its own —
215/// intended to be used behind the `StreamCheckpointManager` mutex.
216pub struct StreamChangelogBuffer {
217    entries: Vec<StreamChangelogEntry>,
218    capacity: usize,
219    write_idx: usize,
220    read_idx: usize,
221    count: usize,
222    overflow_count: u64,
223    policy: OverflowPolicy,
224}
225
226impl StreamChangelogBuffer {
227    /// Creates a new changelog buffer with the given capacity.
228    #[must_use]
229    pub fn new(capacity: usize, policy: OverflowPolicy) -> Self {
230        let zeroed = StreamChangelogEntry {
231            source_id: 0,
232            op: 0,
233            _padding: 0,
234            _reserved: 0,
235            sequence: 0,
236            watermark: 0,
237        };
238        Self {
239            entries: vec![zeroed; capacity],
240            capacity,
241            write_idx: 0,
242            read_idx: 0,
243            count: 0,
244            overflow_count: 0,
245            policy,
246        }
247    }
248
249    /// Pushes an entry into the buffer.
250    ///
251    /// Returns `true` if the entry was stored, `false` if dropped due to
252    /// overflow policy.
253    pub fn push(&mut self, entry: StreamChangelogEntry) -> bool {
254        if self.count == self.capacity {
255            self.overflow_count += 1;
256            match self.policy {
257                OverflowPolicy::DropNew => return false,
258                OverflowPolicy::OverwriteOldest => {
259                    // Advance read pointer, discarding oldest
260                    self.read_idx = (self.read_idx + 1) % self.capacity;
261                    self.count -= 1;
262                }
263            }
264        }
265        self.entries[self.write_idx] = entry;
266        self.write_idx = (self.write_idx + 1) % self.capacity;
267        self.count += 1;
268        true
269    }
270
271    /// Pops the oldest entry from the buffer.
272    pub fn pop(&mut self) -> Option<StreamChangelogEntry> {
273        if self.count == 0 {
274            return None;
275        }
276        let entry = self.entries[self.read_idx];
277        self.read_idx = (self.read_idx + 1) % self.capacity;
278        self.count -= 1;
279        Some(entry)
280    }
281
282    /// Drains up to `max` entries into the provided vector.
283    pub fn drain(&mut self, max: usize, out: &mut Vec<StreamChangelogEntry>) {
284        let n = max.min(self.count);
285        for _ in 0..n {
286            if let Some(entry) = self.pop() {
287                out.push(entry);
288            }
289        }
290    }
291
292    /// Drains all entries into the provided vector.
293    pub fn drain_all(&mut self, out: &mut Vec<StreamChangelogEntry>) {
294        let n = self.count;
295        self.drain(n, out);
296    }
297
298    /// Returns the number of entries in the buffer.
299    #[must_use]
300    pub fn len(&self) -> usize {
301        self.count
302    }
303
304    /// Returns `true` if the buffer is empty.
305    #[must_use]
306    pub fn is_empty(&self) -> bool {
307        self.count == 0
308    }
309
310    /// Returns `true` if the buffer is full.
311    #[must_use]
312    pub fn is_full(&self) -> bool {
313        self.count == self.capacity
314    }
315
316    /// Returns the buffer capacity.
317    #[must_use]
318    pub fn capacity(&self) -> usize {
319        self.capacity
320    }
321
322    /// Returns the total number of overflows since creation.
323    #[must_use]
324    pub fn overflow_count(&self) -> u64 {
325        self.overflow_count
326    }
327}
328
329impl fmt::Debug for StreamChangelogBuffer {
330    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
331        f.debug_struct("StreamChangelogBuffer")
332            .field("capacity", &self.capacity)
333            .field("len", &self.count)
334            .field("overflow_count", &self.overflow_count)
335            .finish_non_exhaustive()
336    }
337}
338
339// Checkpoint snapshot
340
341/// A point-in-time snapshot of streaming pipeline state.
342#[derive(Debug, Clone)]
343pub struct StreamCheckpoint {
344    /// Unique checkpoint identifier (monotonically increasing).
345    pub id: u64,
346    /// Epoch number.
347    pub epoch: u64,
348    /// Source name -> sequence number at checkpoint time.
349    pub source_sequences: HashMap<String, u64>,
350    /// Sink name -> position at checkpoint time.
351    pub sink_positions: HashMap<String, u64>,
352    /// Source name -> watermark at checkpoint time.
353    pub watermarks: HashMap<String, i64>,
354    /// Operator name -> opaque state bytes.
355    pub operator_states: HashMap<String, Vec<u8>>,
356    /// Timestamp when this checkpoint was created (millis since epoch).
357    pub created_at: u64,
358}
359
360impl StreamCheckpoint {
361    /// Serializes the checkpoint to bytes.
362    ///
363    /// Format (v2):
364    /// ```text
365    /// [version: 2][id: 8][epoch: 8][created_at: 8]
366    /// [num_sources: 4][ [name_len:4][name][seq:8] ... ]
367    /// [num_sinks: 4][ [name_len:4][name][pos:8] ... ]
368    /// [num_watermarks: 4][ [name_len:4][name][wm:8] ... ]
369    /// [num_ops: 4][ [name_len:4][name][data_len:4][data] ... ]
370    /// [crc32c: 4]  (over all bytes after version)
371    /// ```
372    #[must_use]
373    #[allow(clippy::cast_possible_truncation)] // Wire format uses u32 for collection lengths
374    pub fn to_bytes(&self) -> Vec<u8> {
375        let mut buf = Vec::with_capacity(256);
376
377        // Version
378        buf.push(2u8);
379
380        // Header
381        buf.extend_from_slice(&self.id.to_le_bytes());
382        buf.extend_from_slice(&self.epoch.to_le_bytes());
383        buf.extend_from_slice(&self.created_at.to_le_bytes());
384
385        // Source sequences
386        buf.extend_from_slice(&(self.source_sequences.len() as u32).to_le_bytes());
387        for (name, seq) in &self.source_sequences {
388            buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
389            buf.extend_from_slice(name.as_bytes());
390            buf.extend_from_slice(&seq.to_le_bytes());
391        }
392
393        // Sink positions
394        buf.extend_from_slice(&(self.sink_positions.len() as u32).to_le_bytes());
395        for (name, pos) in &self.sink_positions {
396            buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
397            buf.extend_from_slice(name.as_bytes());
398            buf.extend_from_slice(&pos.to_le_bytes());
399        }
400
401        // Watermarks
402        buf.extend_from_slice(&(self.watermarks.len() as u32).to_le_bytes());
403        for (name, wm) in &self.watermarks {
404            buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
405            buf.extend_from_slice(name.as_bytes());
406            buf.extend_from_slice(&wm.to_le_bytes());
407        }
408
409        // Operator states
410        buf.extend_from_slice(&(self.operator_states.len() as u32).to_le_bytes());
411        for (name, data) in &self.operator_states {
412            buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
413            buf.extend_from_slice(name.as_bytes());
414            buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
415            buf.extend_from_slice(data);
416        }
417
418        // CRC32C over payload (everything after the version byte)
419        let crc = crc32c::crc32c(&buf[1..]);
420        buf.extend_from_slice(&crc.to_le_bytes());
421
422        buf
423    }
424
425    /// Deserializes a checkpoint from bytes.
426    ///
427    /// # Errors
428    ///
429    /// Returns `CheckpointError::IoError` if the data is truncated, corrupted,
430    /// or uses an unsupported version.
431    #[allow(clippy::similar_names, clippy::too_many_lines)]
432    pub fn from_bytes(data: &[u8]) -> Result<Self, CheckpointError> {
433        let mut pos = 0;
434
435        let read_u32 = |p: &mut usize| -> Result<u32, CheckpointError> {
436            if *p + 4 > data.len() {
437                return Err(CheckpointError::IoError("truncated u32".into()));
438            }
439            let val = u32::from_le_bytes(
440                data[*p..*p + 4]
441                    .try_into()
442                    .map_err(|_| CheckpointError::IoError("bad u32".into()))?,
443            );
444            *p += 4;
445            Ok(val)
446        };
447
448        let read_u64_val = |p: &mut usize| -> Result<u64, CheckpointError> {
449            if *p + 8 > data.len() {
450                return Err(CheckpointError::IoError("truncated u64".into()));
451            }
452            let val = u64::from_le_bytes(
453                data[*p..*p + 8]
454                    .try_into()
455                    .map_err(|_| CheckpointError::IoError("bad u64".into()))?,
456            );
457            *p += 8;
458            Ok(val)
459        };
460
461        let read_i64_val = |p: &mut usize| -> Result<i64, CheckpointError> {
462            if *p + 8 > data.len() {
463                return Err(CheckpointError::IoError("truncated i64".into()));
464            }
465            let val = i64::from_le_bytes(
466                data[*p..*p + 8]
467                    .try_into()
468                    .map_err(|_| CheckpointError::IoError("bad i64".into()))?,
469            );
470            *p += 8;
471            Ok(val)
472        };
473
474        let read_string = |p: &mut usize| -> Result<String, CheckpointError> {
475            let slen = read_u32(p)? as usize;
476            if *p + slen > data.len() {
477                return Err(CheckpointError::IoError("truncated string".into()));
478            }
479            let s = std::str::from_utf8(&data[*p..*p + slen])
480                .map_err(|_| CheckpointError::IoError("invalid utf8".into()))?
481                .to_string();
482            *p += slen;
483            Ok(s)
484        };
485
486        // Version
487        if pos >= data.len() {
488            return Err(CheckpointError::IoError("empty checkpoint data".into()));
489        }
490        let version = data[pos];
491        pos += 1;
492        if version != 2 {
493            return Err(CheckpointError::IoError(format!(
494                "unsupported checkpoint version: {version} (expected 2)"
495            )));
496        }
497
498        // Verify CRC32C: last 4 bytes are the checksum over bytes[1..len-4]
499        if data.len() < 5 {
500            return Err(CheckpointError::IoError(
501                "checkpoint too short for CRC".into(),
502            ));
503        }
504        let crc_start = data.len() - 4;
505        let stored_crc = u32::from_le_bytes(
506            data[crc_start..]
507                .try_into()
508                .map_err(|_| CheckpointError::IoError("bad CRC bytes".into()))?,
509        );
510        let computed_crc = crc32c::crc32c(&data[1..crc_start]);
511        if stored_crc != computed_crc {
512            return Err(CheckpointError::IoError(format!(
513                "CRC mismatch: stored={stored_crc:#010x} computed={computed_crc:#010x}"
514            )));
515        }
516
517        // Limit deserialization to the payload (exclude trailing CRC)
518        let data = &data[..crc_start];
519
520        // Header
521        let id = read_u64_val(&mut pos)?;
522        let epoch = read_u64_val(&mut pos)?;
523        let created_at = read_u64_val(&mut pos)?;
524
525        // Source sequences
526        let num_sources = read_u32(&mut pos)? as usize;
527        let mut source_sequences = HashMap::with_capacity(num_sources);
528        for _ in 0..num_sources {
529            let name = read_string(&mut pos)?;
530            let seq = read_u64_val(&mut pos)?;
531            source_sequences.insert(name, seq);
532        }
533
534        // Sink positions
535        let num_sinks = read_u32(&mut pos)? as usize;
536        let mut sink_positions = HashMap::with_capacity(num_sinks);
537        for _ in 0..num_sinks {
538            let name = read_string(&mut pos)?;
539            let sink_pos = read_u64_val(&mut pos)?;
540            sink_positions.insert(name, sink_pos);
541        }
542
543        // Watermarks
544        let num_watermarks = read_u32(&mut pos)? as usize;
545        let mut watermarks = HashMap::with_capacity(num_watermarks);
546        for _ in 0..num_watermarks {
547            let name = read_string(&mut pos)?;
548            let wm = read_i64_val(&mut pos)?;
549            watermarks.insert(name, wm);
550        }
551
552        // Operator states
553        let num_ops = read_u32(&mut pos)? as usize;
554        let mut operator_states = HashMap::with_capacity(num_ops);
555        for _ in 0..num_ops {
556            let name = read_string(&mut pos)?;
557            let data_len = read_u32(&mut pos)? as usize;
558            if pos + data_len > data.len() {
559                return Err(CheckpointError::IoError("truncated operator state".into()));
560            }
561            let state_data = data[pos..pos + data_len].to_vec();
562            pos += data_len;
563            operator_states.insert(name, state_data);
564        }
565
566        Ok(Self {
567            id,
568            epoch,
569            source_sequences,
570            sink_positions,
571            watermarks,
572            operator_states,
573            created_at,
574        })
575    }
576}
577
578// Registered source info (held by the manager)
579
/// Handles to one source's live counters, as held by the checkpoint manager.
///
/// Both atomics are shared through `Arc` with the source that updates them,
/// so the manager can read them without taking a lock.
struct RegisteredSource {
    /// Sequence counter the source increments on push.
    sequence: Arc<AtomicU64>,
    /// Watermark value the source updates.
    watermark: Arc<AtomicI64>,
}
587
588// Checkpoint manager
589
590/// Coordinates checkpoint lifecycle for streaming sources and sinks.
591///
592/// Disabled by default. When enabled via [`StreamCheckpointConfig`], the
593/// manager captures atomic counters from registered sources to produce
594/// consistent [`StreamCheckpoint`] snapshots.
595pub struct StreamCheckpointManager {
596    config: StreamCheckpointConfig,
597    enabled: bool,
598    sources: HashMap<String, RegisteredSource>,
599    sinks: HashMap<String, u64>,
600    checkpoints: Vec<StreamCheckpoint>,
601    next_id: u64,
602    epoch: u64,
603    changelog: Option<StreamChangelogBuffer>,
604}
605
606impl StreamCheckpointManager {
607    /// Creates a new checkpoint manager.
608    ///
609    /// If `config` validation fails, the manager is created in disabled state.
610    #[must_use]
611    pub fn new(config: StreamCheckpointConfig) -> Self {
612        let enabled = config.validate().is_ok();
613        let changelog = config
614            .changelog_capacity
615            .filter(|_| enabled)
616            .map(|cap| StreamChangelogBuffer::new(cap, config.overflow_policy));
617        Self {
618            config,
619            enabled,
620            sources: HashMap::new(),
621            sinks: HashMap::new(),
622            checkpoints: Vec::new(),
623            next_id: 1,
624            epoch: 0,
625            changelog,
626        }
627    }
628
629    /// Creates a disabled (no-op) manager.
630    #[must_use]
631    pub fn disabled() -> Self {
632        Self {
633            config: StreamCheckpointConfig::default(),
634            enabled: false,
635            sources: HashMap::new(),
636            sinks: HashMap::new(),
637            checkpoints: Vec::new(),
638            next_id: 1,
639            epoch: 0,
640            changelog: None,
641        }
642    }
643
644    /// Returns whether checkpointing is enabled.
645    #[must_use]
646    pub fn is_enabled(&self) -> bool {
647        self.enabled
648    }
649
650    /// Registers a source for checkpoint tracking.
651    ///
652    /// The `sequence` and `watermark` atomics are shared with the live
653    /// [`Source`](super::Source) — reading them is lock-free.
654    pub fn register_source(
655        &mut self,
656        name: &str,
657        sequence: Arc<AtomicU64>,
658        watermark: Arc<AtomicI64>,
659    ) {
660        self.sources.insert(
661            name.to_string(),
662            RegisteredSource {
663                sequence,
664                watermark,
665            },
666        );
667    }
668
669    /// Registers a sink for checkpoint tracking.
670    pub fn register_sink(&mut self, name: &str, position: u64) {
671        self.sinks.insert(name.to_string(), position);
672    }
673
674    /// Triggers a checkpoint, capturing current source/sink state.
675    ///
676    /// Returns the checkpoint ID, or `None` if checkpointing is disabled.
677    #[allow(clippy::cast_possible_truncation)] // Timestamp ms fits i64 for ~292 years from epoch
678    pub fn trigger(&mut self) -> Option<u64> {
679        if !self.enabled {
680            return None;
681        }
682
683        self.epoch += 1;
684        let id = self.next_id;
685        self.next_id += 1;
686
687        // Capture source sequences and watermarks atomically
688        let mut source_sequences = HashMap::with_capacity(self.sources.len());
689        let mut watermarks = HashMap::with_capacity(self.sources.len());
690        for (name, src) in &self.sources {
691            source_sequences.insert(name.clone(), src.sequence.load(Ordering::Acquire));
692            watermarks.insert(name.clone(), src.watermark.load(Ordering::Acquire));
693        }
694
695        // Capture sink positions
696        let sink_positions = self.sinks.clone();
697
698        let now = std::time::SystemTime::now()
699            .duration_since(std::time::UNIX_EPOCH)
700            .map(|d| d.as_millis() as u64)
701            .unwrap_or(0);
702
703        let checkpoint = StreamCheckpoint {
704            id,
705            epoch: self.epoch,
706            source_sequences,
707            sink_positions,
708            watermarks,
709            operator_states: HashMap::new(),
710            created_at: now,
711        };
712
713        self.checkpoints.push(checkpoint);
714
715        // Prune old checkpoints if max_retained is set
716        if let Some(max) = self.config.max_retained {
717            while self.checkpoints.len() > max {
718                self.checkpoints.remove(0);
719            }
720        }
721
722        Some(id)
723    }
724
725    /// Creates a checkpoint and returns the checkpoint ID.
726    ///
727    /// # Errors
728    ///
729    /// Returns `CheckpointError::Disabled` if checkpointing is not enabled.
730    pub fn checkpoint(&mut self) -> Result<Option<u64>, CheckpointError> {
731        if !self.enabled {
732            return Err(CheckpointError::Disabled);
733        }
734        Ok(self.trigger())
735    }
736
737    /// Returns the most recent checkpoint for restore.
738    ///
739    /// # Errors
740    ///
741    /// Returns `CheckpointError::Disabled` if checkpointing is not enabled,
742    /// or `CheckpointError::NoCheckpoint` if no checkpoint exists.
743    pub fn restore(&self) -> Result<&StreamCheckpoint, CheckpointError> {
744        if !self.enabled {
745            return Err(CheckpointError::Disabled);
746        }
747        self.checkpoints.last().ok_or(CheckpointError::NoCheckpoint)
748    }
749
750    /// Returns a checkpoint by ID.
751    #[must_use]
752    pub fn get_checkpoint(&self, id: u64) -> Option<&StreamCheckpoint> {
753        self.checkpoints.iter().find(|cp| cp.id == id)
754    }
755
756    /// Returns the ID of the most recent checkpoint.
757    #[must_use]
758    pub fn last_checkpoint_id(&self) -> Option<u64> {
759        self.checkpoints.last().map(|cp| cp.id)
760    }
761
762    /// Returns a reference to the changelog buffer, if configured.
763    #[must_use]
764    pub fn changelog(&self) -> Option<&StreamChangelogBuffer> {
765        self.changelog.as_ref()
766    }
767
768    /// Returns a mutable reference to the changelog buffer.
769    pub fn changelog_mut(&mut self) -> Option<&mut StreamChangelogBuffer> {
770        self.changelog.as_mut()
771    }
772
773    /// Returns the current epoch.
774    #[must_use]
775    pub fn epoch(&self) -> u64 {
776        self.epoch
777    }
778}
779
780impl fmt::Debug for StreamCheckpointManager {
781    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
782        f.debug_struct("StreamCheckpointManager")
783            .field("enabled", &self.enabled)
784            .field("sources", &self.sources.len())
785            .field("sinks", &self.sinks.len())
786            .field("checkpoints", &self.checkpoints.len())
787            .field("epoch", &self.epoch)
788            .finish_non_exhaustive()
789    }
790}
791
792// Tests
793
794#[cfg(test)]
795mod tests {
796    use super::*;
797
798    fn enabled_config() -> StreamCheckpointConfig {
799        StreamCheckpointConfig {
800            interval_ms: Some(1000),
801            ..Default::default()
802        }
803    }
804
805    // -- Config / disabled tests --
806
807    #[test]
808    fn test_checkpoint_disabled_by_default() {
809        let config = StreamCheckpointConfig::default();
810        let mgr = StreamCheckpointManager::new(config);
811        // Default config is valid but has no interval — still "enabled"
812        // because validate() passes. Disabled means validate() fails.
813        assert!(mgr.is_enabled());
814
815        // A truly disabled manager:
816        let mgr2 = StreamCheckpointManager::disabled();
817        assert!(!mgr2.is_enabled());
818    }
819
820    #[test]
821    fn test_checkpoint_no_op_when_disabled() {
822        let mgr = StreamCheckpointManager::disabled();
823        assert!(!mgr.is_enabled());
824        assert_eq!(mgr.last_checkpoint_id(), None);
825    }
826
827    #[test]
828    fn test_checkpoint_config_requires_data_dir() {
829        let config = StreamCheckpointConfig {
830            wal_mode: Some(WalMode::Sync),
831            data_dir: None,
832            ..Default::default()
833        };
834        assert!(config.validate().is_err());
835
836        // With data_dir set, validation passes
837        let config2 = StreamCheckpointConfig {
838            wal_mode: Some(WalMode::Sync),
839            data_dir: Some(std::path::PathBuf::from("/tmp/test")),
840            ..Default::default()
841        };
842        assert!(config2.validate().is_ok());
843    }
844
845    #[test]
846    fn test_wal_requires_checkpoint() {
847        let config = StreamCheckpointConfig {
848            wal_mode: Some(WalMode::Async),
849            data_dir: None, // missing
850            ..Default::default()
851        };
852        let result = config.validate();
853        assert!(matches!(result, Err(CheckpointError::InvalidConfig(_))));
854    }
855
856    // -- Source registration --
857
858    #[test]
859    fn test_register_source() {
860        let mut mgr = StreamCheckpointManager::new(enabled_config());
861
862        let seq = Arc::new(AtomicU64::new(0));
863        let wm = Arc::new(AtomicI64::new(i64::MIN));
864
865        mgr.register_source("trades", Arc::clone(&seq), Arc::clone(&wm));
866        assert!(mgr.is_enabled());
867    }
868
869    // -- Trigger / capture --
870
871    #[test]
872    fn test_trigger_checkpoint() {
873        let mut mgr = StreamCheckpointManager::new(enabled_config());
874        let id = mgr.trigger();
875        assert_eq!(id, Some(1));
876
877        let id2 = mgr.trigger();
878        assert_eq!(id2, Some(2));
879    }
880
881    #[test]
882    fn test_checkpoint_captures_sequences() {
883        let mut mgr = StreamCheckpointManager::new(enabled_config());
884
885        let seq = Arc::new(AtomicU64::new(0));
886        let wm = Arc::new(AtomicI64::new(i64::MIN));
887        mgr.register_source("src1", Arc::clone(&seq), Arc::clone(&wm));
888
889        // Simulate pushes
890        seq.store(42, Ordering::Release);
891
892        let id = mgr.trigger().unwrap();
893        let cp = mgr.get_checkpoint(id).unwrap();
894        assert_eq!(cp.source_sequences.get("src1"), Some(&42));
895    }
896
897    #[test]
898    fn test_checkpoint_captures_watermarks() {
899        let mut mgr = StreamCheckpointManager::new(enabled_config());
900
901        let seq = Arc::new(AtomicU64::new(0));
902        let wm = Arc::new(AtomicI64::new(5000));
903        mgr.register_source("src1", Arc::clone(&seq), Arc::clone(&wm));
904
905        let id = mgr.trigger().unwrap();
906        let cp = mgr.get_checkpoint(id).unwrap();
907        assert_eq!(cp.watermarks.get("src1"), Some(&5000));
908    }
909
910    #[test]
911    fn test_restore_from_checkpoint() {
912        let mut mgr = StreamCheckpointManager::new(enabled_config());
913
914        let seq = Arc::new(AtomicU64::new(10));
915        let wm = Arc::new(AtomicI64::new(1000));
916        mgr.register_source("src1", Arc::clone(&seq), Arc::clone(&wm));
917
918        mgr.trigger();
919        let restored = mgr.restore().unwrap();
920        assert_eq!(restored.source_sequences.get("src1"), Some(&10));
921        assert_eq!(restored.watermarks.get("src1"), Some(&1000));
922    }
923
924    // -- Changelog buffer --
925
926    #[test]
927    fn test_changelog_buffer_push_pop() {
928        let mut buf = StreamChangelogBuffer::new(4, OverflowPolicy::DropNew);
929        assert!(buf.is_empty());
930
931        let entry = StreamChangelogEntry::new(0, StreamChangeOp::Push, 1, i64::MIN);
932        assert!(buf.push(entry));
933        assert_eq!(buf.len(), 1);
934
935        let popped = buf.pop().unwrap();
936        assert_eq!(popped.sequence, 1);
937        assert!(buf.is_empty());
938    }
939
940    #[test]
941    fn test_changelog_buffer_overflow() {
942        // DropNew policy
943        let mut buf = StreamChangelogBuffer::new(2, OverflowPolicy::DropNew);
944        let e1 = StreamChangelogEntry::new(0, StreamChangeOp::Push, 1, i64::MIN);
945        let e2 = StreamChangelogEntry::new(0, StreamChangeOp::Push, 2, i64::MIN);
946        let e3 = StreamChangelogEntry::new(0, StreamChangeOp::Push, 3, i64::MIN);
947
948        assert!(buf.push(e1));
949        assert!(buf.push(e2));
950        assert!(buf.is_full());
951        assert!(!buf.push(e3)); // dropped
952        assert_eq!(buf.overflow_count(), 1);
953
954        // Verify oldest is still there
955        assert_eq!(buf.pop().unwrap().sequence, 1);
956
957        // OverwriteOldest policy
958        let mut buf2 = StreamChangelogBuffer::new(2, OverflowPolicy::OverwriteOldest);
959        assert!(buf2.push(e1));
960        assert!(buf2.push(e2));
961        assert!(buf2.push(e3)); // overwrites e1
962        assert_eq!(buf2.overflow_count(), 1);
963        assert_eq!(buf2.pop().unwrap().sequence, 2); // e1 was overwritten
964    }
965
966    // -- Prune --
967
968    #[test]
969    fn test_checkpoint_prune_old() {
970        let config = StreamCheckpointConfig {
971            interval_ms: Some(1000),
972            max_retained: Some(2),
973            ..Default::default()
974        };
975        let mut mgr = StreamCheckpointManager::new(config);
976
977        mgr.trigger(); // id=1
978        mgr.trigger(); // id=2
979        mgr.trigger(); // id=3 — should prune id=1
980
981        assert_eq!(mgr.checkpoints.len(), 2);
982        assert!(mgr.get_checkpoint(1).is_none());
983        assert!(mgr.get_checkpoint(2).is_some());
984        assert!(mgr.get_checkpoint(3).is_some());
985    }
986
987    // -- Serialization --
988
989    #[test]
990    fn test_checkpoint_serialization() {
991        let mut source_sequences = HashMap::new();
992        source_sequences.insert("src1".to_string(), 100u64);
993        source_sequences.insert("src2".to_string(), 200u64);
994
995        let mut sink_positions = HashMap::new();
996        sink_positions.insert("sink1".to_string(), 50u64);
997
998        let mut watermarks = HashMap::new();
999        watermarks.insert("src1".to_string(), 5000i64);
1000        watermarks.insert("src2".to_string(), 6000i64);
1001
1002        let mut operator_states = HashMap::new();
1003        operator_states.insert("op1".to_string(), vec![1, 2, 3, 4]);
1004
1005        let cp = StreamCheckpoint {
1006            id: 42,
1007            epoch: 7,
1008            source_sequences,
1009            sink_positions,
1010            watermarks,
1011            operator_states,
1012            created_at: 1_706_400_000_000,
1013        };
1014
1015        let bytes = cp.to_bytes();
1016        let restored = StreamCheckpoint::from_bytes(&bytes).unwrap();
1017
1018        assert_eq!(restored.id, 42);
1019        assert_eq!(restored.epoch, 7);
1020        assert_eq!(restored.created_at, 1_706_400_000_000);
1021        assert_eq!(restored.source_sequences.get("src1"), Some(&100));
1022        assert_eq!(restored.source_sequences.get("src2"), Some(&200));
1023        assert_eq!(restored.sink_positions.get("sink1"), Some(&50));
1024        assert_eq!(restored.watermarks.get("src1"), Some(&5000));
1025        assert_eq!(restored.watermarks.get("src2"), Some(&6000));
1026        assert_eq!(restored.operator_states.get("op1"), Some(&vec![1, 2, 3, 4]));
1027    }
1028
1029    #[test]
1030    fn test_changelog_entry_size() {
1031        assert_eq!(
1032            std::mem::size_of::<StreamChangelogEntry>(),
1033            24,
1034            "StreamChangelogEntry must be exactly 24 bytes"
1035        );
1036    }
1037
1038    // -- Source sequence counter tests --
1039
1040    #[test]
1041    fn test_source_sequence_counter() {
1042        let seq = Arc::new(AtomicU64::new(0));
1043        assert_eq!(seq.load(Ordering::Acquire), 0);
1044
1045        seq.fetch_add(1, Ordering::Relaxed);
1046        seq.fetch_add(1, Ordering::Relaxed);
1047        seq.fetch_add(1, Ordering::Relaxed);
1048        assert_eq!(seq.load(Ordering::Acquire), 3);
1049    }
1050
1051    #[test]
1052    fn test_source_clone_shares_sequence() {
1053        let seq = Arc::new(AtomicU64::new(0));
1054        let seq2 = Arc::clone(&seq);
1055
1056        seq.fetch_add(1, Ordering::Relaxed);
1057        assert_eq!(seq2.load(Ordering::Acquire), 1);
1058
1059        seq2.fetch_add(5, Ordering::Relaxed);
1060        assert_eq!(seq.load(Ordering::Acquire), 6);
1061    }
1062
1063    #[test]
1064    fn test_stream_checkpoint_crc_roundtrip() {
1065        let mut cp = StreamCheckpoint {
1066            id: 42,
1067            epoch: 10,
1068            source_sequences: HashMap::new(),
1069            sink_positions: HashMap::new(),
1070            watermarks: HashMap::new(),
1071            operator_states: HashMap::new(),
1072            created_at: 0,
1073        };
1074        cp.source_sequences.insert("kafka".into(), 1000);
1075        cp.watermarks.insert("src".into(), 500);
1076        cp.operator_states.insert("agg".into(), vec![1, 2, 3]);
1077
1078        let bytes = cp.to_bytes();
1079        let restored = StreamCheckpoint::from_bytes(&bytes).unwrap();
1080        assert_eq!(restored.id, 42);
1081        assert_eq!(restored.epoch, 10);
1082        assert_eq!(restored.source_sequences.get("kafka"), Some(&1000));
1083        assert_eq!(restored.watermarks.get("src"), Some(&500));
1084        assert_eq!(restored.operator_states.get("agg").unwrap(), &[1, 2, 3]);
1085    }
1086
1087    #[test]
1088    fn test_stream_checkpoint_crc_corruption_detected() {
1089        let cp = StreamCheckpoint {
1090            id: 1,
1091            epoch: 1,
1092            source_sequences: HashMap::new(),
1093            sink_positions: HashMap::new(),
1094            watermarks: HashMap::new(),
1095            operator_states: HashMap::new(),
1096            created_at: 0,
1097        };
1098        let mut bytes = cp.to_bytes();
1099        // Flip a byte in the payload (not the version or CRC)
1100        bytes[5] ^= 0xFF;
1101
1102        let result = StreamCheckpoint::from_bytes(&bytes);
1103        assert!(result.is_err());
1104        let err = result.unwrap_err().to_string();
1105        assert!(err.contains("CRC mismatch"), "got: {err}");
1106    }
1107
1108    #[test]
1109    fn test_stream_checkpoint_v1_rejected() {
1110        let mut bytes = vec![1u8];
1111        bytes.extend_from_slice(&[0u8; 40]);
1112        let result = StreamCheckpoint::from_bytes(&bytes);
1113        assert!(result.is_err());
1114        let err = result.unwrap_err().to_string();
1115        assert!(
1116            err.contains("unsupported checkpoint version: 1"),
1117            "got: {err}"
1118        );
1119    }
1120}