Skip to main content

laminar_storage/checkpoint/
checkpointer.rs

//! Async checkpoint persistence via object stores.
//!
//! The `Checkpointer` trait abstracts checkpoint I/O so that the
//! checkpoint coordinator doesn't need to know whether state is persisted
//! to local disk, S3, GCS, or Azure Blob.
//!
//! `ObjectStoreCheckpointer` is the production implementation that
//! writes to any `object_store::ObjectStore` backend with:
//! - Concurrent partition uploads via `JoinSet`
//! - SHA-256 integrity digests
//! - Exponential backoff retry on transient errors
12
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use bytes::Bytes;
17use futures::StreamExt;
18use object_store::path::Path;
19use object_store::{GetOptions, ObjectStore, PutOptions, PutPayload};
20use sha2::{Digest, Sha256};
21use tokio::task::JoinSet;
22
23use super::layout::{CheckpointId, CheckpointManifestV2, CheckpointPaths, PartitionSnapshotEntry};
24
25/// Errors from checkpoint persistence operations.
26#[derive(Debug, thiserror::Error)]
27pub enum CheckpointerError {
28    /// Object store I/O error.
29    #[error("object store error: {0}")]
30    ObjectStore(#[from] object_store::Error),
31
32    /// JSON serialization/deserialization error.
33    #[error("serialization error: {0}")]
34    Serialization(#[from] serde_json::Error),
35
36    /// SHA-256 integrity check failed.
37    #[error("integrity check failed for {path}: expected {expected}, got {actual}")]
38    IntegrityMismatch {
39        /// Object path.
40        path: String,
41        /// Expected SHA-256 hex digest.
42        expected: String,
43        /// Actual SHA-256 hex digest.
44        actual: String,
45    },
46
47    /// A concurrent upload task failed.
48    #[error("upload task failed: {0}")]
49    JoinError(#[from] tokio::task::JoinError),
50}
51
52/// Async trait for checkpoint persistence operations.
53///
54/// Implementations handle writing/reading checkpoint artifacts to
55/// durable storage. The checkpoint coordinator calls these methods
56/// during the checkpoint commit protocol.
57#[async_trait]
58pub trait Checkpointer: Send + Sync {
59    /// Write a manifest to the checkpoint store.
60    async fn save_manifest(&self, manifest: &CheckpointManifestV2)
61        -> Result<(), CheckpointerError>;
62
63    /// Load a manifest by checkpoint ID.
64    async fn load_manifest(
65        &self,
66        id: &CheckpointId,
67    ) -> Result<CheckpointManifestV2, CheckpointerError>;
68
69    /// Write a state snapshot for a single operator partition.
70    ///
71    /// Returns the SHA-256 hex digest of the written data.
72    async fn save_snapshot(
73        &self,
74        id: &CheckpointId,
75        operator: &str,
76        partition: u32,
77        data: Bytes,
78    ) -> Result<String, CheckpointerError>;
79
80    /// Write an incremental delta for a single operator partition.
81    ///
82    /// Returns the SHA-256 hex digest of the written data.
83    async fn save_delta(
84        &self,
85        id: &CheckpointId,
86        operator: &str,
87        partition: u32,
88        data: Bytes,
89    ) -> Result<String, CheckpointerError>;
90
91    /// Load a snapshot or delta by path.
92    async fn load_artifact(&self, path: &str) -> Result<Bytes, CheckpointerError>;
93
94    /// Update the `_latest` pointer to the given checkpoint.
95    async fn update_latest(&self, id: &CheckpointId) -> Result<(), CheckpointerError>;
96
97    /// Read the `_latest` pointer to find the most recent checkpoint.
98    async fn read_latest(&self) -> Result<Option<CheckpointId>, CheckpointerError>;
99
100    /// List all checkpoint IDs (sorted chronologically, oldest first).
101    async fn list_checkpoints(&self) -> Result<Vec<CheckpointId>, CheckpointerError>;
102
103    /// Delete a checkpoint and all its artifacts.
104    async fn delete_checkpoint(&self, id: &CheckpointId) -> Result<(), CheckpointerError>;
105}
106
107/// Production [`Checkpointer`] backed by an [`ObjectStore`].
108///
109/// Supports concurrent partition uploads and SHA-256 integrity verification.
110pub struct ObjectStoreCheckpointer {
111    /// The underlying object store.
112    store: Arc<dyn ObjectStore>,
113    /// Path generator for checkpoint artifacts.
114    paths: CheckpointPaths,
115    /// Maximum number of concurrent uploads.
116    max_concurrent_uploads: usize,
117}
118
119impl ObjectStoreCheckpointer {
120    /// Create a new checkpointer.
121    #[must_use]
122    pub fn new(
123        store: Arc<dyn ObjectStore>,
124        paths: CheckpointPaths,
125        max_concurrent_uploads: usize,
126    ) -> Self {
127        Self {
128            store,
129            paths,
130            max_concurrent_uploads,
131        }
132    }
133
134    /// Write data to a path with exponential backoff retry.
135    async fn put_with_retry(
136        store: &dyn ObjectStore,
137        path: &Path,
138        data: PutPayload,
139    ) -> Result<(), CheckpointerError> {
140        let op = || async {
141            store
142                .put_opts(path, data.clone(), PutOptions::default())
143                .await
144                .map_err(|e| match &e {
145                    object_store::Error::Generic { .. } => {
146                        backoff::Error::transient(CheckpointerError::ObjectStore(e))
147                    }
148                    _ => backoff::Error::permanent(CheckpointerError::ObjectStore(e)),
149                })?;
150            Ok(())
151        };
152
153        let backoff = backoff::ExponentialBackoffBuilder::new()
154            .with_max_elapsed_time(Some(std::time::Duration::from_secs(30)))
155            .build();
156
157        backoff::future::retry(backoff, op).await
158    }
159
160    /// Compute SHA-256 hex digest of data.
161    fn sha256_hex(data: &[u8]) -> String {
162        let mut hasher = Sha256::new();
163        hasher.update(data);
164        format!("{:x}", hasher.finalize())
165    }
166
167    /// Write data to a path and return its SHA-256 digest.
168    async fn write_with_digest(
169        &self,
170        path_str: &str,
171        data: Bytes,
172    ) -> Result<String, CheckpointerError> {
173        let digest = Self::sha256_hex(&data);
174        let path = Path::from(path_str);
175        let payload = PutPayload::from_bytes(data);
176        Self::put_with_retry(self.store.as_ref(), &path, payload).await?;
177        Ok(digest)
178    }
179
180    /// Save multiple operator partition snapshots concurrently.
181    ///
182    /// Returns a map of `(operator, partition) -> PartitionSnapshotEntry`.
183    ///
184    /// # Errors
185    ///
186    /// Returns [`CheckpointerError`] if any upload or join fails.
187    pub async fn save_partitions_concurrent(
188        &self,
189        id: &CheckpointId,
190        snapshots: Vec<(String, u32, bool, Bytes)>,
191    ) -> Result<Vec<(String, PartitionSnapshotEntry)>, CheckpointerError> {
192        let mut join_set = JoinSet::new();
193        let store = Arc::clone(&self.store);
194
195        for (operator, partition, is_delta, data) in snapshots {
196            let path_str = if is_delta {
197                self.paths.delta(id, &operator, partition)
198            } else {
199                self.paths.snapshot(id, &operator, partition)
200            };
201
202            let store = Arc::clone(&store);
203            let data_len = data.len() as u64;
204
205            // Limit concurrency by awaiting when at max
206            if join_set.len() >= self.max_concurrent_uploads {
207                if let Some(result) = join_set.join_next().await {
208                    result??;
209                }
210            }
211
212            join_set.spawn(async move {
213                let digest = {
214                    let mut hasher = Sha256::new();
215                    hasher.update(&data);
216                    format!("{:x}", hasher.finalize())
217                };
218                let path = Path::from(path_str.as_str());
219                let payload = PutPayload::from_bytes(data);
220                Self::put_with_retry(store.as_ref(), &path, payload).await?;
221
222                Ok::<_, CheckpointerError>((
223                    operator,
224                    PartitionSnapshotEntry {
225                        partition_id: partition,
226                        is_delta,
227                        path: path_str,
228                        size_bytes: data_len,
229                        sha256: Some(digest),
230                    },
231                ))
232            });
233        }
234
235        // Collect remaining results
236        let mut entries = Vec::new();
237        while let Some(result) = join_set.join_next().await {
238            entries.push(result??);
239        }
240
241        Ok(entries)
242    }
243}
244
245#[async_trait]
246impl Checkpointer for ObjectStoreCheckpointer {
247    async fn save_manifest(
248        &self,
249        manifest: &CheckpointManifestV2,
250    ) -> Result<(), CheckpointerError> {
251        let json = serde_json::to_vec_pretty(manifest)?;
252        let path_str = self.paths.manifest(&manifest.checkpoint_id);
253        let path = Path::from(path_str.as_str());
254        let payload = PutPayload::from_bytes(Bytes::from(json));
255        Self::put_with_retry(self.store.as_ref(), &path, payload).await
256    }
257
258    async fn load_manifest(
259        &self,
260        id: &CheckpointId,
261    ) -> Result<CheckpointManifestV2, CheckpointerError> {
262        let path_str = self.paths.manifest(id);
263        let path = Path::from(path_str.as_str());
264        let result = self.store.get_opts(&path, GetOptions::default()).await?;
265        let data = result.bytes().await?;
266        let manifest: CheckpointManifestV2 = serde_json::from_slice(&data)?;
267        Ok(manifest)
268    }
269
270    async fn save_snapshot(
271        &self,
272        id: &CheckpointId,
273        operator: &str,
274        partition: u32,
275        data: Bytes,
276    ) -> Result<String, CheckpointerError> {
277        let path_str = self.paths.snapshot(id, operator, partition);
278        self.write_with_digest(&path_str, data).await
279    }
280
281    async fn save_delta(
282        &self,
283        id: &CheckpointId,
284        operator: &str,
285        partition: u32,
286        data: Bytes,
287    ) -> Result<String, CheckpointerError> {
288        let path_str = self.paths.delta(id, operator, partition);
289        self.write_with_digest(&path_str, data).await
290    }
291
292    async fn load_artifact(&self, path_str: &str) -> Result<Bytes, CheckpointerError> {
293        let path = Path::from(path_str);
294        let result = self.store.get_opts(&path, GetOptions::default()).await?;
295        let data = result.bytes().await?;
296        Ok(data)
297    }
298
299    async fn update_latest(&self, id: &CheckpointId) -> Result<(), CheckpointerError> {
300        let path_str = self.paths.latest_pointer();
301        let path = Path::from(path_str.as_str());
302        let payload = PutPayload::from_bytes(Bytes::from(id.to_string_id()));
303        Self::put_with_retry(self.store.as_ref(), &path, payload).await
304    }
305
306    async fn read_latest(&self) -> Result<Option<CheckpointId>, CheckpointerError> {
307        let path_str = self.paths.latest_pointer();
308        let path = Path::from(path_str.as_str());
309        match self.store.get_opts(&path, GetOptions::default()).await {
310            Ok(result) => {
311                let data = result.bytes().await?;
312                let id_str =
313                    std::str::from_utf8(&data).map_err(|e| object_store::Error::Generic {
314                        store: "checkpointer",
315                        source: Box::new(e),
316                    })?;
317                let uuid =
318                    uuid::Uuid::parse_str(id_str).map_err(|e| object_store::Error::Generic {
319                        store: "checkpointer",
320                        source: Box::new(e),
321                    })?;
322                Ok(Some(CheckpointId::from_uuid(uuid)))
323            }
324            Err(object_store::Error::NotFound { .. }) => Ok(None),
325            Err(e) => Err(CheckpointerError::ObjectStore(e)),
326        }
327    }
328
329    async fn list_checkpoints(&self) -> Result<Vec<CheckpointId>, CheckpointerError> {
330        // List objects at the base prefix — each checkpoint is a directory.
331        // We look for manifest.json files to identify checkpoints.
332        //
333        // Manifest path format: {base_prefix}{UUID}manifest.json
334        // (no slash between UUID and "manifest.json", see layout.rs:114)
335        let prefix = Path::from(self.paths.latest_pointer().trim_end_matches("_latest"));
336        let base_prefix = &self.paths.base_prefix;
337        let mut ids = Vec::new();
338
339        let mut stream = self.store.list(Some(&prefix));
340        while let Some(meta) = stream.next().await {
341            let meta = meta?;
342            let path_str = meta.location.to_string();
343            if path_str.ends_with("manifest.json") {
344                // Extract checkpoint ID: strip suffix then strip base prefix.
345                // Path is "{base_prefix}{UUID}manifest.json".
346                // After stripping suffix: "{base_prefix}{UUID}".
347                // After stripping prefix: "{UUID}".
348                if let Some(remainder) = path_str.strip_suffix("manifest.json") {
349                    let id_str = remainder
350                        .strip_prefix(base_prefix.as_str())
351                        .unwrap_or(remainder);
352                    let id_str = id_str.trim_end_matches('/');
353                    if !id_str.is_empty() {
354                        if let Ok(uuid) = uuid::Uuid::parse_str(id_str) {
355                            ids.push(CheckpointId::from_uuid(uuid));
356                        }
357                    }
358                }
359            }
360        }
361
362        ids.sort();
363        Ok(ids)
364    }
365
366    async fn delete_checkpoint(&self, id: &CheckpointId) -> Result<(), CheckpointerError> {
367        // List all objects under this checkpoint's directory and delete them
368        let dir = self.paths.checkpoint_dir(id);
369        let prefix = Path::from(dir.as_str());
370
371        let mut paths_to_delete = Vec::new();
372        let mut stream = self.store.list(Some(&prefix));
373        while let Some(meta) = stream.next().await {
374            let meta = meta?;
375            paths_to_delete.push(meta.location);
376        }
377
378        // Delete using delete_stream (object-safe, no ObjectStoreExt needed)
379        let locations_stream = futures::stream::iter(paths_to_delete.into_iter().map(Ok)).boxed();
380        let mut results = self.store.delete_stream(locations_stream);
381        while let Some(result) = results.next().await {
382            result?;
383        }
384
385        Ok(())
386    }
387}
388
389/// Verify a loaded artifact against its expected SHA-256 digest.
390///
391/// # Errors
392///
393/// Returns [`CheckpointerError::IntegrityMismatch`] if the digest doesn't match.
394pub fn verify_integrity(
395    path: &str,
396    data: &[u8],
397    expected_sha256: &str,
398) -> Result<(), CheckpointerError> {
399    let actual = ObjectStoreCheckpointer::sha256_hex(data);
400    if actual != expected_sha256 {
401        return Err(CheckpointerError::IntegrityMismatch {
402            path: path.to_string(),
403            expected: expected_sha256.to_string(),
404            actual,
405        });
406    }
407    Ok(())
408}
409
#[cfg(test)]
mod tests {
    use super::super::layout::OperatorSnapshotEntry;
    use super::*;
    use object_store::memory::InMemory;

    /// Builds a checkpointer over a fresh in-memory store, using default
    /// paths and allowing four concurrent uploads.
    fn make_checkpointer() -> ObjectStoreCheckpointer {
        let backend = Arc::new(InMemory::new());
        ObjectStoreCheckpointer::new(backend, CheckpointPaths::default(), 4)
    }

    /// A saved manifest can be loaded back with the same ID and epoch.
    #[tokio::test]
    async fn test_save_and_load_manifest() {
        let checkpointer = make_checkpointer();
        let id = CheckpointId::now();

        checkpointer
            .save_manifest(&CheckpointManifestV2::new(id, 1))
            .await
            .unwrap();

        let round_tripped = checkpointer.load_manifest(&id).await.unwrap();
        assert_eq!(round_tripped.checkpoint_id, id);
        assert_eq!(round_tripped.epoch, 1);
    }

    /// The digest returned by `save_snapshot` verifies against the payload.
    #[tokio::test]
    async fn test_save_snapshot_with_digest() {
        let checkpointer = make_checkpointer();
        let id = CheckpointId::now();
        let payload = Bytes::from_static(b"hello world");

        let digest = checkpointer
            .save_snapshot(&id, "window-agg", 0, payload.clone())
            .await
            .unwrap();

        assert!(!digest.is_empty());
        verify_integrity("test", &payload, &digest).unwrap();
    }

    /// A saved snapshot is readable through `load_artifact` at its path.
    #[tokio::test]
    async fn test_load_artifact() {
        let checkpointer = make_checkpointer();
        let id = CheckpointId::now();
        let payload = Bytes::from_static(b"partition state");

        checkpointer
            .save_snapshot(&id, "op1", 0, payload.clone())
            .await
            .unwrap();

        let location = checkpointer.paths.snapshot(&id, "op1", 0);
        let fetched = checkpointer.load_artifact(&location).await.unwrap();
        assert_eq!(fetched, payload);
    }

    /// The latest pointer is absent at first and round-trips once written.
    #[tokio::test]
    async fn test_latest_pointer() {
        let checkpointer = make_checkpointer();

        // Nothing has been written yet.
        assert!(checkpointer.read_latest().await.unwrap().is_none());

        let id = CheckpointId::now();
        checkpointer.update_latest(&id).await.unwrap();

        let resolved = checkpointer.read_latest().await.unwrap().unwrap();
        assert_eq!(resolved, id);
    }

    /// Three partitions (two snapshots, one delta) each yield an entry
    /// that carries a digest.
    #[tokio::test]
    async fn test_concurrent_partition_uploads() {
        let checkpointer = make_checkpointer();
        let id = CheckpointId::now();

        let batch = vec![
            ("op1".into(), 0, false, Bytes::from_static(b"part0")),
            ("op1".into(), 1, false, Bytes::from_static(b"part1")),
            ("op1".into(), 2, true, Bytes::from_static(b"delta2")),
        ];

        let uploaded = checkpointer
            .save_partitions_concurrent(&id, batch)
            .await
            .unwrap();

        assert_eq!(uploaded.len(), 3);
        assert!(uploaded.iter().all(|(_, entry)| entry.sha256.is_some()));
    }

    /// A digest that cannot match is reported as `IntegrityMismatch`.
    #[tokio::test]
    async fn test_integrity_mismatch() {
        let outcome = verify_integrity("test.snap", b"data", "wrong_hash");
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            CheckpointerError::IntegrityMismatch { .. }
        ));
    }

    /// Uploaded partition entries can be assembled into a manifest that
    /// survives a save/load round trip.
    #[tokio::test]
    async fn test_save_and_build_manifest() {
        let checkpointer = make_checkpointer();
        let id = CheckpointId::now();

        // Upload two partitions of the same operator.
        let batch = vec![
            ("op1".into(), 0, false, Bytes::from_static(b"state0")),
            ("op1".into(), 1, false, Bytes::from_static(b"state1")),
        ];
        let uploaded = checkpointer
            .save_partitions_concurrent(&id, batch)
            .await
            .unwrap();

        // Assemble the operator entry from the uploaded partitions.
        let partitions: Vec<_> = uploaded.into_iter().map(|(_, part)| part).collect();
        let total_bytes = partitions.iter().map(|part| part.size_bytes).sum();
        let mut manifest = CheckpointManifestV2::new(id, 5);
        manifest.operators.insert(
            "op1".into(),
            OperatorSnapshotEntry {
                partitions,
                total_bytes,
            },
        );

        checkpointer.save_manifest(&manifest).await.unwrap();
        checkpointer.update_latest(&id).await.unwrap();

        // The stored manifest describes one operator with two partitions.
        let round_tripped = checkpointer.load_manifest(&id).await.unwrap();
        assert_eq!(round_tripped.operators.len(), 1);
        assert_eq!(round_tripped.operators["op1"].partitions.len(), 2);
    }
}