// laminar_storage/checkpoint_batcher.rs
1//! Checkpoint batching for S3 cost optimization.
2//!
3//! Accumulates state blobs from multiple partitions/operators into a single
4//! compressed object before flushing to object storage. Targets 4-32MB object
5//! sizes to minimize PUT costs ($0.005/1000 PUTs on S3).
6//!
7//! ## Batch Format
8//!
//! ```text
//! [magic: 4 bytes "LCB1"]
//! [version: u32 LE = 2]
//! [entry_count: u32 LE]
//! [uncompressed_size: u32 LE]
//! [crc32c: u32 LE]           ← over compressed_body (v2+; v1 headers omit this field)
//! [compressed_body: LZ4 block]
//!   → decoded body is a sequence of frames:
//!     [key_len: u32 LE][key: UTF-8][data_len: u32 LE][data: bytes]
//! ```
19
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::sync::Arc;
22
23use object_store::{ObjectStore, PutOptions, PutPayload};
24
25use crate::checkpoint_store::CheckpointStoreError;
26
/// Magic bytes identifying a checkpoint batch file.
const BATCH_MAGIC: &[u8; 4] = b"LCB1";

/// Current batch format version (v2 adds CRC32C integrity check).
const BATCH_VERSION: u32 = 2;

/// Default flush threshold: 8 MB.
/// Keeps flushed objects inside the 4-32MB range targeted by this module.
const DEFAULT_FLUSH_THRESHOLD: usize = 8 * 1024 * 1024;

/// v2 batch header size in bytes (magic + version + count + size + crc32c).
/// NOTE: v1 headers are 16 bytes (no CRC field); `decode_batch` handles both.
const HEADER_SIZE: usize = 20;
38
/// A single entry in the batch buffer.
struct BatchEntry {
    // Partition/operator identifier (e.g., "partition-0/agg").
    key: String,
    // Raw (uncompressed) state blob for this key.
    data: Vec<u8>,
}
44
/// Metrics for checkpoint batching operations.
///
/// All counters are monotonically increasing and updated with
/// `Ordering::Relaxed` — they are statistics, not synchronization points.
#[derive(Debug)]
pub struct BatchMetrics {
    /// Total batches flushed to object storage.
    pub batches_flushed: AtomicU64,
    /// Total entries written across all batches.
    pub entries_flushed: AtomicU64,
    /// Total bytes before LZ4 compression.
    pub bytes_before_compression: AtomicU64,
    /// Total bytes after LZ4 compression.
    pub bytes_after_compression: AtomicU64,
    /// Total PUT requests issued.
    pub put_count: AtomicU64,
}
59
60impl BatchMetrics {
61    /// Returns zeroed counters.
62    #[must_use]
63    pub fn new() -> Self {
64        Self {
65            batches_flushed: AtomicU64::new(0),
66            entries_flushed: AtomicU64::new(0),
67            bytes_before_compression: AtomicU64::new(0),
68            bytes_after_compression: AtomicU64::new(0),
69            put_count: AtomicU64::new(0),
70        }
71    }
72
73    /// Record a completed flush.
74    fn record_flush(&self, entries: u64, raw_bytes: u64, compressed_bytes: u64) {
75        self.batches_flushed.fetch_add(1, Ordering::Relaxed);
76        self.entries_flushed.fetch_add(entries, Ordering::Relaxed);
77        self.bytes_before_compression
78            .fetch_add(raw_bytes, Ordering::Relaxed);
79        self.bytes_after_compression
80            .fetch_add(compressed_bytes, Ordering::Relaxed);
81        self.put_count.fetch_add(1, Ordering::Relaxed);
82    }
83
84    /// Point-in-time snapshot of metrics.
85    #[must_use]
86    pub fn snapshot(&self) -> BatchMetricsSnapshot {
87        BatchMetricsSnapshot {
88            batches_flushed: self.batches_flushed.load(Ordering::Relaxed),
89            entries_flushed: self.entries_flushed.load(Ordering::Relaxed),
90            bytes_before_compression: self.bytes_before_compression.load(Ordering::Relaxed),
91            bytes_after_compression: self.bytes_after_compression.load(Ordering::Relaxed),
92            put_count: self.put_count.load(Ordering::Relaxed),
93        }
94    }
95}
96
impl Default for BatchMetrics {
    /// Equivalent to [`BatchMetrics::new`]: all counters start at zero.
    fn default() -> Self {
        Self::new()
    }
}
102
/// Immutable snapshot of [`BatchMetrics`].
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly in tests
/// and monitoring code (all fields are plain `u64` counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BatchMetricsSnapshot {
    /// Total batches flushed to object storage.
    pub batches_flushed: u64,
    /// Total entries written across all batches.
    pub entries_flushed: u64,
    /// Total bytes before LZ4 compression.
    pub bytes_before_compression: u64,
    /// Total bytes after LZ4 compression.
    pub bytes_after_compression: u64,
    /// Total PUT requests issued.
    pub put_count: u64,
}
117
/// Accumulates state blobs and flushes them as a single compressed object.
///
/// Call [`add`](Self::add) for each partition/operator state blob during a
/// checkpoint cycle, then [`flush`](Self::flush) at the end.
/// [`should_flush`](Self::should_flush) returns `true` when the buffer exceeds
/// the configured threshold.
pub struct CheckpointBatcher {
    // Entries accumulated since the last flush.
    buffer: Vec<BatchEntry>,
    // Uncompressed size of `buffer` in bytes, including the 8 bytes of
    // per-entry length prefixes counted in `add`.
    buffer_size: usize,
    // Uncompressed size at which `should_flush` returns true.
    flush_threshold: usize,
    // Destination object store.
    store: Arc<dyn ObjectStore>,
    // Prepended to all object paths (e.g., "nodes/abc123/").
    prefix: String,
    // Dedicated current-thread runtime used to drive the async store API
    // from this synchronous interface.
    rt: tokio::runtime::Runtime,
    // Shared counters describing flush activity.
    metrics: Arc<BatchMetrics>,
}
133
134impl CheckpointBatcher {
135    /// Create a new batcher.
136    ///
137    /// `prefix` is prepended to all object paths (e.g., `"nodes/abc123/"`).
138    /// `flush_threshold` is the uncompressed buffer size that triggers a flush
139    /// (default: 8 MB).
140    ///
141    /// # Errors
142    ///
143    /// Returns `std::io::Error` if the internal Tokio runtime cannot be created.
144    pub fn new(
145        store: Arc<dyn ObjectStore>,
146        prefix: String,
147        flush_threshold: Option<usize>,
148    ) -> std::io::Result<Self> {
149        let rt = tokio::runtime::Builder::new_current_thread()
150            .enable_all()
151            .build()?;
152        Ok(Self {
153            buffer: Vec::new(),
154            buffer_size: 0,
155            flush_threshold: flush_threshold.unwrap_or(DEFAULT_FLUSH_THRESHOLD),
156            store,
157            prefix,
158            rt,
159            metrics: Arc::new(BatchMetrics::new()),
160        })
161    }
162
163    /// Add a state blob to the buffer.
164    ///
165    /// The `key` identifies the partition/operator (e.g., `"partition-0/agg"`).
166    /// Call [`should_flush`](Self::should_flush) after adding to check whether
167    /// the buffer exceeds the threshold.
168    pub fn add(&mut self, key: String, data: Vec<u8>) {
169        self.buffer_size += key.len() + data.len() + 8; // +8 for two u32 length prefixes
170        self.buffer.push(BatchEntry { key, data });
171    }
172
173    /// Returns `true` if the buffer exceeds the flush threshold.
174    #[must_use]
175    pub fn should_flush(&self) -> bool {
176        self.buffer_size >= self.flush_threshold
177    }
178
179    /// Returns the number of entries currently buffered.
180    #[must_use]
181    pub fn len(&self) -> usize {
182        self.buffer.len()
183    }
184
185    /// Returns `true` if the buffer is empty.
186    #[must_use]
187    pub fn is_empty(&self) -> bool {
188        self.buffer.is_empty()
189    }
190
191    /// Returns the current uncompressed buffer size in bytes.
192    #[must_use]
193    pub fn buffer_size(&self) -> usize {
194        self.buffer_size
195    }
196
197    /// Shared metrics handle.
198    #[must_use]
199    pub fn metrics(&self) -> &Arc<BatchMetrics> {
200        &self.metrics
201    }
202
203    /// Flush buffered entries as a single LZ4-compressed object.
204    ///
205    /// The object is written to `{prefix}checkpoints/batch-{epoch:06}.lz4`.
206    /// After a successful flush the buffer is cleared.
207    ///
208    /// Does nothing if the buffer is empty.
209    ///
210    /// # Errors
211    ///
212    /// Returns [`CheckpointStoreError`] on serialization or object store failure.
213    pub fn flush(&mut self, epoch: u64) -> Result<(), CheckpointStoreError> {
214        if self.buffer.is_empty() {
215            return Ok(());
216        }
217
218        let (raw_size, payload) = encode_batch(&self.buffer);
219
220        let path = object_store::path::Path::from(format!(
221            "{}checkpoints/batch-{epoch:06}.lz4",
222            self.prefix
223        ));
224
225        let compressed_size = payload.content_length();
226
227        self.rt.block_on(async {
228            self.store
229                .put_opts(&path, payload, PutOptions::default())
230                .await
231        })?;
232
233        let entry_count = self.buffer.len() as u64;
234        self.metrics
235            .record_flush(entry_count, raw_size as u64, compressed_size as u64);
236
237        self.buffer.clear();
238        self.buffer_size = 0;
239
240        Ok(())
241    }
242}
243
244/// Encode buffered entries into an LZ4-compressed batch payload.
245///
246/// Returns `(uncompressed_body_size, payload)`.
247#[allow(clippy::cast_possible_truncation)] // Entry counts/sizes are bounded well below u32::MAX
248fn encode_batch(entries: &[BatchEntry]) -> (usize, PutPayload) {
249    // Serialize entries into an uncompressed body.
250    let mut body = Vec::new();
251    for entry in entries {
252        body.extend_from_slice(&(entry.key.len() as u32).to_le_bytes());
253        body.extend_from_slice(entry.key.as_bytes());
254        body.extend_from_slice(&(entry.data.len() as u32).to_le_bytes());
255        body.extend_from_slice(&entry.data);
256    }
257
258    let uncompressed_size = body.len();
259    let compressed = lz4_flex::compress_prepend_size(&body);
260
261    // CRC32C over the compressed body for integrity verification
262    let crc = crc32c::crc32c(&compressed);
263
264    // Build header + compressed body.
265    let mut out = Vec::with_capacity(HEADER_SIZE + compressed.len());
266    out.extend_from_slice(BATCH_MAGIC);
267    out.extend_from_slice(&BATCH_VERSION.to_le_bytes());
268    out.extend_from_slice(&(entries.len() as u32).to_le_bytes());
269    out.extend_from_slice(&(uncompressed_size as u32).to_le_bytes());
270    out.extend_from_slice(&crc.to_le_bytes());
271    out.extend_from_slice(&compressed);
272
273    (
274        uncompressed_size,
275        PutPayload::from_bytes(bytes::Bytes::from(out)),
276    )
277}
278
279/// Decode a batch payload into `(key, data)` pairs.
280///
281/// # Errors
282///
283/// Returns [`CheckpointStoreError::Io`] if the batch is malformed.
284#[allow(clippy::cast_possible_truncation)] // u32→usize is always safe (widens on 64-bit)
285pub fn decode_batch(raw: &[u8]) -> Result<Vec<(String, Vec<u8>)>, CheckpointStoreError> {
286    if raw.len() < HEADER_SIZE {
287        return Err(CheckpointStoreError::Io(std::io::Error::new(
288            std::io::ErrorKind::InvalidData,
289            "batch too short for header",
290        )));
291    }
292
293    if &raw[..4] != BATCH_MAGIC {
294        return Err(CheckpointStoreError::Io(std::io::Error::new(
295            std::io::ErrorKind::InvalidData,
296            "invalid batch magic",
297        )));
298    }
299
300    let version = u32::from_le_bytes([raw[4], raw[5], raw[6], raw[7]]);
301
302    // Determine header size and whether CRC is present based on version
303    let (header_size, has_crc) = match version {
304        1 => (16, false), // v1: no CRC field
305        2 => (20, true),  // v2: CRC32C after size field
306        _ => {
307            return Err(CheckpointStoreError::Io(std::io::Error::new(
308                std::io::ErrorKind::InvalidData,
309                format!("unsupported batch version {version}"),
310            )));
311        }
312    };
313
314    if raw.len() < header_size {
315        return Err(CheckpointStoreError::Io(std::io::Error::new(
316            std::io::ErrorKind::InvalidData,
317            format!("batch too short for v{version} header"),
318        )));
319    }
320
321    let entry_count = u32::from_le_bytes([raw[8], raw[9], raw[10], raw[11]]) as usize;
322    let _uncompressed_size = u32::from_le_bytes([raw[12], raw[13], raw[14], raw[15]]);
323
324    let compressed_body = &raw[header_size..];
325
326    // Verify CRC32C if present (v2+)
327    if has_crc {
328        let expected_crc = u32::from_le_bytes([raw[16], raw[17], raw[18], raw[19]]);
329        let actual_crc = crc32c::crc32c(compressed_body);
330        if actual_crc != expected_crc {
331            return Err(CheckpointStoreError::Io(std::io::Error::new(
332                std::io::ErrorKind::InvalidData,
333                format!(
334                    "batch CRC32C mismatch: expected {expected_crc:#010x}, \
335                     actual {actual_crc:#010x}"
336                ),
337            )));
338        }
339    }
340
341    let body = lz4_flex::decompress_size_prepended(compressed_body).map_err(|e| {
342        CheckpointStoreError::Io(std::io::Error::new(
343            std::io::ErrorKind::InvalidData,
344            format!("LZ4 decompression failed: {e}"),
345        ))
346    })?;
347
348    let mut entries = Vec::with_capacity(entry_count);
349    let mut cursor = 0;
350
351    for _ in 0..entry_count {
352        if cursor + 4 > body.len() {
353            return Err(CheckpointStoreError::Io(std::io::Error::new(
354                std::io::ErrorKind::UnexpectedEof,
355                "truncated batch entry (key length)",
356            )));
357        }
358        let key_len = u32::from_le_bytes([
359            body[cursor],
360            body[cursor + 1],
361            body[cursor + 2],
362            body[cursor + 3],
363        ]) as usize;
364        cursor += 4;
365
366        if cursor + key_len > body.len() {
367            return Err(CheckpointStoreError::Io(std::io::Error::new(
368                std::io::ErrorKind::UnexpectedEof,
369                "truncated batch entry (key data)",
370            )));
371        }
372        let key = String::from_utf8_lossy(&body[cursor..cursor + key_len]).into_owned();
373        cursor += key_len;
374
375        if cursor + 4 > body.len() {
376            return Err(CheckpointStoreError::Io(std::io::Error::new(
377                std::io::ErrorKind::UnexpectedEof,
378                "truncated batch entry (data length)",
379            )));
380        }
381        let data_len = u32::from_le_bytes([
382            body[cursor],
383            body[cursor + 1],
384            body[cursor + 2],
385            body[cursor + 3],
386        ]) as usize;
387        cursor += 4;
388
389        if cursor + data_len > body.len() {
390            return Err(CheckpointStoreError::Io(std::io::Error::new(
391                std::io::ErrorKind::UnexpectedEof,
392                "truncated batch entry (data)",
393            )));
394        }
395        let data = body[cursor..cursor + data_len].to_vec();
396        cursor += data_len;
397
398        entries.push((key, data));
399    }
400
401    Ok(entries)
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use object_store::memory::InMemory;

    // Build a batcher backed by an in-memory store with the given threshold.
    fn make_batcher(threshold: usize) -> (CheckpointBatcher, Arc<dyn ObjectStore>) {
        let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        let batcher =
            CheckpointBatcher::new(store.clone(), String::new(), Some(threshold)).unwrap();
        (batcher, store)
    }

    // `add` must count key bytes + data bytes + 8 bytes of length-prefix overhead.
    #[test]
    fn test_add_tracks_size() {
        let (mut batcher, _store) = make_batcher(1024);
        assert!(batcher.is_empty());
        assert_eq!(batcher.buffer_size(), 0);

        batcher.add("key1".into(), vec![0u8; 100]);
        assert_eq!(batcher.len(), 1);
        assert!(!batcher.is_empty());
        // key(4) + data(100) + 8 bytes overhead = 112
        assert_eq!(batcher.buffer_size(), 112);
    }

    // The threshold is inclusive: flush triggers at >= threshold, not strictly above.
    #[test]
    fn test_should_flush_at_threshold() {
        let (mut batcher, _store) = make_batcher(200);
        assert!(!batcher.should_flush());

        batcher.add("k".into(), vec![0u8; 100]);
        assert!(!batcher.should_flush());

        batcher.add("k".into(), vec![0u8; 100]);
        assert!(batcher.should_flush());
    }

    // Flushing with nothing buffered must not issue a PUT or touch metrics.
    #[test]
    fn test_flush_empty_is_noop() {
        let (mut batcher, _store) = make_batcher(1024);
        batcher.flush(1).unwrap();
        let snap = batcher.metrics().snapshot();
        assert_eq!(snap.batches_flushed, 0);
        assert_eq!(snap.put_count, 0);
    }

    // Flush writes exactly one object at {prefix}checkpoints/batch-{epoch:06}.lz4.
    #[test]
    fn test_flush_writes_object() {
        let (mut batcher, store) = make_batcher(1024 * 1024);

        batcher.add("partition-0/agg".into(), vec![42u8; 256]);
        batcher.add("partition-1/agg".into(), vec![99u8; 128]);
        batcher.flush(7).unwrap();

        assert!(batcher.is_empty());
        assert_eq!(batcher.buffer_size(), 0);

        // Verify object exists at expected path
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let result = rt.block_on(async {
            store
                .get_opts(
                    &object_store::path::Path::from("checkpoints/batch-000007.lz4"),
                    object_store::GetOptions::default(),
                )
                .await
        });
        assert!(result.is_ok());
    }

    // encode_batch → object store → decode_batch must round-trip keys and
    // payloads in insertion order.
    #[test]
    fn test_lz4_roundtrip() {
        let (mut batcher, store) = make_batcher(1024 * 1024);

        let entries = vec![
            ("partition-0/state".to_string(), vec![1u8; 500]),
            ("partition-1/state".to_string(), vec![2u8; 300]),
            ("partition-2/agg".to_string(), vec![3u8; 200]),
        ];

        for (k, v) in &entries {
            batcher.add(k.clone(), v.clone());
        }
        batcher.flush(42).unwrap();

        // Read back and decode
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let data = rt.block_on(async {
            let result = store
                .get_opts(
                    &object_store::path::Path::from("checkpoints/batch-000042.lz4"),
                    object_store::GetOptions::default(),
                )
                .await
                .unwrap();
            result.bytes().await.unwrap()
        });

        let decoded = decode_batch(&data).unwrap();
        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded[0].0, "partition-0/state");
        assert_eq!(decoded[0].1, vec![1u8; 500]);
        assert_eq!(decoded[1].0, "partition-1/state");
        assert_eq!(decoded[1].1, vec![2u8; 300]);
        assert_eq!(decoded[2].0, "partition-2/agg");
        assert_eq!(decoded[2].1, vec![3u8; 200]);
    }

    // One flush = one batch, one PUT, and nonzero before/after byte counters.
    #[test]
    fn test_metrics_recorded_on_flush() {
        let (mut batcher, _store) = make_batcher(1024 * 1024);

        batcher.add("k1".into(), vec![0u8; 100]);
        batcher.add("k2".into(), vec![0u8; 200]);
        batcher.flush(1).unwrap();

        let snap = batcher.metrics().snapshot();
        assert_eq!(snap.batches_flushed, 1);
        assert_eq!(snap.entries_flushed, 2);
        assert_eq!(snap.put_count, 1);
        assert!(snap.bytes_before_compression > 0);
        assert!(snap.bytes_after_compression > 0);
    }

    // Counters accumulate across flushes rather than resetting.
    #[test]
    fn test_metrics_accumulate_across_flushes() {
        let (mut batcher, _store) = make_batcher(1024 * 1024);

        batcher.add("k1".into(), vec![0u8; 100]);
        batcher.flush(1).unwrap();

        batcher.add("k2".into(), vec![0u8; 200]);
        batcher.add("k3".into(), vec![0u8; 50]);
        batcher.flush(2).unwrap();

        let snap = batcher.metrics().snapshot();
        assert_eq!(snap.batches_flushed, 2);
        assert_eq!(snap.entries_flushed, 3);
        assert_eq!(snap.put_count, 2);
    }

    // Sanity check that LZ4 actually shrinks a compressible payload.
    #[test]
    fn test_compression_reduces_size() {
        let (mut batcher, _store) = make_batcher(1024 * 1024);

        // Highly compressible data (all zeros)
        batcher.add("big".into(), vec![0u8; 10_000]);
        batcher.flush(1).unwrap();

        let snap = batcher.metrics().snapshot();
        assert!(
            snap.bytes_after_compression < snap.bytes_before_compression,
            "compressed ({}) should be smaller than raw ({})",
            snap.bytes_after_compression,
            snap.bytes_before_compression
        );
    }

    // A header-sized buffer with the wrong magic must be rejected.
    #[test]
    fn test_decode_invalid_magic() {
        let bad = b"XXXX\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
        let err = decode_batch(bad).unwrap_err();
        assert!(err.to_string().contains("invalid batch magic"));
    }

    // Anything shorter than the header must be rejected before field reads.
    #[test]
    fn test_decode_too_short() {
        let err = decode_batch(b"LCB").unwrap_err();
        assert!(err.to_string().contains("too short"));
    }

    // Unknown format versions must be rejected, not decoded on a guess.
    #[test]
    fn test_decode_bad_version() {
        let mut buf = Vec::new();
        buf.extend_from_slice(b"LCB1");
        buf.extend_from_slice(&99u32.to_le_bytes()); // bad version
        buf.extend_from_slice(&0u32.to_le_bytes()); // count
        buf.extend_from_slice(&0u32.to_le_bytes()); // size
        buf.extend_from_slice(&0u32.to_le_bytes()); // crc
        let err = decode_batch(&buf).unwrap_err();
        assert!(err.to_string().contains("unsupported batch version"));
    }

    // A successful flush resets both the entry list and the size accumulator.
    #[test]
    fn test_flush_clears_buffer() {
        let (mut batcher, _store) = make_batcher(64);

        batcher.add("a".into(), vec![0u8; 50]);
        batcher.add("b".into(), vec![0u8; 50]);
        assert!(batcher.should_flush());

        batcher.flush(1).unwrap();
        assert!(!batcher.should_flush());
        assert!(batcher.is_empty());
        assert_eq!(batcher.buffer_size(), 0);
    }
}