1use std::sync::Arc;
14
15use async_trait::async_trait;
16use bytes::Bytes;
17use futures::StreamExt;
18use object_store::path::Path;
19use object_store::{GetOptions, ObjectStore, PutOptions, PutPayload};
20use sha2::{Digest, Sha256};
21use tokio::task::JoinSet;
22
23use super::layout::{CheckpointId, CheckpointManifestV2, CheckpointPaths, PartitionSnapshotEntry};
24
/// Errors returned by checkpoint storage operations and by [`verify_integrity`].
#[derive(Debug, thiserror::Error)]
pub enum CheckpointerError {
    /// Failure reported by the underlying object store. This file also uses this
    /// variant (via `object_store::Error::Generic`) for an unreadable
    /// latest-checkpoint pointer.
    #[error("object store error: {0}")]
    ObjectStore(#[from] object_store::Error),

    /// A manifest could not be serialized to, or deserialized from, JSON.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// Downloaded bytes did not hash to the digest recorded for them.
    #[error("integrity check failed for {path}: expected {expected}, got {actual}")]
    IntegrityMismatch {
        /// Object path (or caller-supplied label) of the artifact that failed.
        path: String,
        /// Digest the caller expected, as passed to [`verify_integrity`].
        expected: String,
        /// Lowercase hex SHA-256 actually computed over the data.
        actual: String,
    },

    /// A spawned upload task panicked or was cancelled before finishing.
    #[error("upload task failed: {0}")]
    JoinError(#[from] tokio::task::JoinError),
}
51
/// Storage backend for checkpoint manifests, per-partition state artifacts, and
/// the pointer to the most recently published checkpoint.
#[async_trait]
pub trait Checkpointer: Send + Sync {
    /// Persists `manifest` at the manifest location for its `checkpoint_id`.
    async fn save_manifest(&self, manifest: &CheckpointManifestV2)
        -> Result<(), CheckpointerError>;

    /// Reads and deserializes the manifest stored for checkpoint `id`.
    async fn load_manifest(
        &self,
        id: &CheckpointId,
    ) -> Result<CheckpointManifestV2, CheckpointerError>;

    /// Writes a full state snapshot for `operator`/`partition` under checkpoint `id`.
    ///
    /// Returns a digest of `data`. The object-store implementation returns the
    /// lowercase hex SHA-256, suitable for [`verify_integrity`].
    async fn save_snapshot(
        &self,
        id: &CheckpointId,
        operator: &str,
        partition: u32,
        data: Bytes,
    ) -> Result<String, CheckpointerError>;

    /// Writes an incremental (delta) artifact for `operator`/`partition` under
    /// checkpoint `id`. Returns a digest of `data`, as for [`Checkpointer::save_snapshot`].
    async fn save_delta(
        &self,
        id: &CheckpointId,
        operator: &str,
        partition: u32,
        data: Bytes,
    ) -> Result<String, CheckpointerError>;

    /// Fetches the raw bytes stored at `path`.
    ///
    /// No digest check is performed here. Callers should use [`verify_integrity`]
    /// when they hold an expected digest.
    async fn load_artifact(&self, path: &str) -> Result<Bytes, CheckpointerError>;

    /// Points the "latest checkpoint" marker at `id`.
    async fn update_latest(&self, id: &CheckpointId) -> Result<(), CheckpointerError>;

    /// Returns the checkpoint the "latest" marker points at, or `None` if no
    /// marker has been written.
    async fn read_latest(&self) -> Result<Option<CheckpointId>, CheckpointerError>;

    /// Enumerates checkpoints that have a stored manifest. The object-store
    /// implementation returns them sorted ascending.
    async fn list_checkpoints(&self) -> Result<Vec<CheckpointId>, CheckpointerError>;

    /// Removes every stored object belonging to checkpoint `id`.
    async fn delete_checkpoint(&self, id: &CheckpointId) -> Result<(), CheckpointerError>;
}
106
/// [`Checkpointer`] implementation that stores every artifact in an [`ObjectStore`].
pub struct ObjectStoreCheckpointer {
    /// Backing store for manifests, snapshots, deltas and the latest pointer.
    store: Arc<dyn ObjectStore>,
    /// Layout that maps checkpoint ids / operators / partitions to object paths.
    paths: CheckpointPaths,
    /// Maximum number of uploads `save_partitions_concurrent` keeps in flight.
    max_concurrent_uploads: usize,
}
118
119impl ObjectStoreCheckpointer {
120 #[must_use]
122 pub fn new(
123 store: Arc<dyn ObjectStore>,
124 paths: CheckpointPaths,
125 max_concurrent_uploads: usize,
126 ) -> Self {
127 Self {
128 store,
129 paths,
130 max_concurrent_uploads,
131 }
132 }
133
134 async fn put_with_retry(
136 store: &dyn ObjectStore,
137 path: &Path,
138 data: PutPayload,
139 ) -> Result<(), CheckpointerError> {
140 let op = || async {
141 store
142 .put_opts(path, data.clone(), PutOptions::default())
143 .await
144 .map_err(|e| match &e {
145 object_store::Error::Generic { .. } => {
146 backoff::Error::transient(CheckpointerError::ObjectStore(e))
147 }
148 _ => backoff::Error::permanent(CheckpointerError::ObjectStore(e)),
149 })?;
150 Ok(())
151 };
152
153 let backoff = backoff::ExponentialBackoffBuilder::new()
154 .with_max_elapsed_time(Some(std::time::Duration::from_secs(30)))
155 .build();
156
157 backoff::future::retry(backoff, op).await
158 }
159
160 fn sha256_hex(data: &[u8]) -> String {
162 let mut hasher = Sha256::new();
163 hasher.update(data);
164 format!("{:x}", hasher.finalize())
165 }
166
167 async fn write_with_digest(
169 &self,
170 path_str: &str,
171 data: Bytes,
172 ) -> Result<String, CheckpointerError> {
173 let digest = Self::sha256_hex(&data);
174 let path = Path::from(path_str);
175 let payload = PutPayload::from_bytes(data);
176 Self::put_with_retry(self.store.as_ref(), &path, payload).await?;
177 Ok(digest)
178 }
179
180 pub async fn save_partitions_concurrent(
188 &self,
189 id: &CheckpointId,
190 snapshots: Vec<(String, u32, bool, Bytes)>,
191 ) -> Result<Vec<(String, PartitionSnapshotEntry)>, CheckpointerError> {
192 let mut join_set = JoinSet::new();
193 let store = Arc::clone(&self.store);
194
195 for (operator, partition, is_delta, data) in snapshots {
196 let path_str = if is_delta {
197 self.paths.delta(id, &operator, partition)
198 } else {
199 self.paths.snapshot(id, &operator, partition)
200 };
201
202 let store = Arc::clone(&store);
203 let data_len = data.len() as u64;
204
205 if join_set.len() >= self.max_concurrent_uploads {
207 if let Some(result) = join_set.join_next().await {
208 result??;
209 }
210 }
211
212 join_set.spawn(async move {
213 let digest = {
214 let mut hasher = Sha256::new();
215 hasher.update(&data);
216 format!("{:x}", hasher.finalize())
217 };
218 let path = Path::from(path_str.as_str());
219 let payload = PutPayload::from_bytes(data);
220 Self::put_with_retry(store.as_ref(), &path, payload).await?;
221
222 Ok::<_, CheckpointerError>((
223 operator,
224 PartitionSnapshotEntry {
225 partition_id: partition,
226 is_delta,
227 path: path_str,
228 size_bytes: data_len,
229 sha256: Some(digest),
230 },
231 ))
232 });
233 }
234
235 let mut entries = Vec::new();
237 while let Some(result) = join_set.join_next().await {
238 entries.push(result??);
239 }
240
241 Ok(entries)
242 }
243}
244
245#[async_trait]
246impl Checkpointer for ObjectStoreCheckpointer {
247 async fn save_manifest(
248 &self,
249 manifest: &CheckpointManifestV2,
250 ) -> Result<(), CheckpointerError> {
251 let json = serde_json::to_vec_pretty(manifest)?;
252 let path_str = self.paths.manifest(&manifest.checkpoint_id);
253 let path = Path::from(path_str.as_str());
254 let payload = PutPayload::from_bytes(Bytes::from(json));
255 Self::put_with_retry(self.store.as_ref(), &path, payload).await
256 }
257
258 async fn load_manifest(
259 &self,
260 id: &CheckpointId,
261 ) -> Result<CheckpointManifestV2, CheckpointerError> {
262 let path_str = self.paths.manifest(id);
263 let path = Path::from(path_str.as_str());
264 let result = self.store.get_opts(&path, GetOptions::default()).await?;
265 let data = result.bytes().await?;
266 let manifest: CheckpointManifestV2 = serde_json::from_slice(&data)?;
267 Ok(manifest)
268 }
269
270 async fn save_snapshot(
271 &self,
272 id: &CheckpointId,
273 operator: &str,
274 partition: u32,
275 data: Bytes,
276 ) -> Result<String, CheckpointerError> {
277 let path_str = self.paths.snapshot(id, operator, partition);
278 self.write_with_digest(&path_str, data).await
279 }
280
281 async fn save_delta(
282 &self,
283 id: &CheckpointId,
284 operator: &str,
285 partition: u32,
286 data: Bytes,
287 ) -> Result<String, CheckpointerError> {
288 let path_str = self.paths.delta(id, operator, partition);
289 self.write_with_digest(&path_str, data).await
290 }
291
292 async fn load_artifact(&self, path_str: &str) -> Result<Bytes, CheckpointerError> {
293 let path = Path::from(path_str);
294 let result = self.store.get_opts(&path, GetOptions::default()).await?;
295 let data = result.bytes().await?;
296 Ok(data)
297 }
298
299 async fn update_latest(&self, id: &CheckpointId) -> Result<(), CheckpointerError> {
300 let path_str = self.paths.latest_pointer();
301 let path = Path::from(path_str.as_str());
302 let payload = PutPayload::from_bytes(Bytes::from(id.to_string_id()));
303 Self::put_with_retry(self.store.as_ref(), &path, payload).await
304 }
305
306 async fn read_latest(&self) -> Result<Option<CheckpointId>, CheckpointerError> {
307 let path_str = self.paths.latest_pointer();
308 let path = Path::from(path_str.as_str());
309 match self.store.get_opts(&path, GetOptions::default()).await {
310 Ok(result) => {
311 let data = result.bytes().await?;
312 let id_str =
313 std::str::from_utf8(&data).map_err(|e| object_store::Error::Generic {
314 store: "checkpointer",
315 source: Box::new(e),
316 })?;
317 let uuid =
318 uuid::Uuid::parse_str(id_str).map_err(|e| object_store::Error::Generic {
319 store: "checkpointer",
320 source: Box::new(e),
321 })?;
322 Ok(Some(CheckpointId::from_uuid(uuid)))
323 }
324 Err(object_store::Error::NotFound { .. }) => Ok(None),
325 Err(e) => Err(CheckpointerError::ObjectStore(e)),
326 }
327 }
328
329 async fn list_checkpoints(&self) -> Result<Vec<CheckpointId>, CheckpointerError> {
330 let prefix = Path::from(self.paths.latest_pointer().trim_end_matches("_latest"));
336 let base_prefix = &self.paths.base_prefix;
337 let mut ids = Vec::new();
338
339 let mut stream = self.store.list(Some(&prefix));
340 while let Some(meta) = stream.next().await {
341 let meta = meta?;
342 let path_str = meta.location.to_string();
343 if path_str.ends_with("manifest.json") {
344 if let Some(remainder) = path_str.strip_suffix("manifest.json") {
349 let id_str = remainder
350 .strip_prefix(base_prefix.as_str())
351 .unwrap_or(remainder);
352 let id_str = id_str.trim_end_matches('/');
353 if !id_str.is_empty() {
354 if let Ok(uuid) = uuid::Uuid::parse_str(id_str) {
355 ids.push(CheckpointId::from_uuid(uuid));
356 }
357 }
358 }
359 }
360 }
361
362 ids.sort();
363 Ok(ids)
364 }
365
366 async fn delete_checkpoint(&self, id: &CheckpointId) -> Result<(), CheckpointerError> {
367 let dir = self.paths.checkpoint_dir(id);
369 let prefix = Path::from(dir.as_str());
370
371 let mut paths_to_delete = Vec::new();
372 let mut stream = self.store.list(Some(&prefix));
373 while let Some(meta) = stream.next().await {
374 let meta = meta?;
375 paths_to_delete.push(meta.location);
376 }
377
378 let locations_stream = futures::stream::iter(paths_to_delete.into_iter().map(Ok)).boxed();
380 let mut results = self.store.delete_stream(locations_stream);
381 while let Some(result) = results.next().await {
382 result?;
383 }
384
385 Ok(())
386 }
387}
388
389pub fn verify_integrity(
395 path: &str,
396 data: &[u8],
397 expected_sha256: &str,
398) -> Result<(), CheckpointerError> {
399 let actual = ObjectStoreCheckpointer::sha256_hex(data);
400 if actual != expected_sha256 {
401 return Err(CheckpointerError::IntegrityMismatch {
402 path: path.to_string(),
403 expected: expected_sha256.to_string(),
404 actual,
405 });
406 }
407 Ok(())
408}
409
#[cfg(test)]
mod tests {
    use super::super::layout::OperatorSnapshotEntry;
    use super::*;
    use object_store::memory::InMemory;

    /// Fresh in-memory checkpointer with default paths and four upload slots.
    fn in_memory_checkpointer() -> ObjectStoreCheckpointer {
        ObjectStoreCheckpointer::new(Arc::new(InMemory::new()), CheckpointPaths::default(), 4)
    }

    #[tokio::test]
    async fn test_save_and_load_manifest() {
        let checkpointer = in_memory_checkpointer();
        let id = CheckpointId::now();

        checkpointer
            .save_manifest(&CheckpointManifestV2::new(id, 1))
            .await
            .unwrap();

        let loaded = checkpointer.load_manifest(&id).await.unwrap();
        assert_eq!(loaded.checkpoint_id, id);
        assert_eq!(loaded.epoch, 1);
    }

    #[tokio::test]
    async fn test_save_snapshot_with_digest() {
        let checkpointer = in_memory_checkpointer();
        let id = CheckpointId::now();
        let payload = Bytes::from_static(b"hello world");

        let digest = checkpointer
            .save_snapshot(&id, "window-agg", 0, payload.clone())
            .await
            .unwrap();

        assert!(!digest.is_empty());
        // The returned digest must match the bytes that were uploaded.
        verify_integrity("test", &payload, &digest).unwrap();
    }

    #[tokio::test]
    async fn test_load_artifact() {
        let checkpointer = in_memory_checkpointer();
        let id = CheckpointId::now();
        let payload = Bytes::from_static(b"partition state");

        checkpointer
            .save_snapshot(&id, "op1", 0, payload.clone())
            .await
            .unwrap();

        let snapshot_path = checkpointer.paths.snapshot(&id, "op1", 0);
        assert_eq!(
            checkpointer.load_artifact(&snapshot_path).await.unwrap(),
            payload
        );
    }

    #[tokio::test]
    async fn test_latest_pointer() {
        let checkpointer = in_memory_checkpointer();

        // Nothing has been published yet.
        assert!(checkpointer.read_latest().await.unwrap().is_none());

        let id = CheckpointId::now();
        checkpointer.update_latest(&id).await.unwrap();

        assert_eq!(checkpointer.read_latest().await.unwrap(), Some(id));
    }

    #[tokio::test]
    async fn test_concurrent_partition_uploads() {
        let checkpointer = in_memory_checkpointer();
        let id = CheckpointId::now();

        let inputs = vec![
            ("op1".into(), 0, false, Bytes::from_static(b"part0")),
            ("op1".into(), 1, false, Bytes::from_static(b"part1")),
            ("op1".into(), 2, true, Bytes::from_static(b"delta2")),
        ];

        let entries = checkpointer
            .save_partitions_concurrent(&id, inputs)
            .await
            .unwrap();

        assert_eq!(entries.len(), 3);
        assert!(entries.iter().all(|(_, entry)| entry.sha256.is_some()));
    }

    #[test]
    fn test_integrity_mismatch() {
        let outcome = verify_integrity("test.snap", b"data", "wrong_hash");
        assert!(matches!(
            outcome,
            Err(CheckpointerError::IntegrityMismatch { .. })
        ));
    }

    #[tokio::test]
    async fn test_save_and_build_manifest() {
        let checkpointer = in_memory_checkpointer();
        let id = CheckpointId::now();

        // Upload two partitions, then assemble and persist a manifest for them.
        let inputs = vec![
            ("op1".into(), 0, false, Bytes::from_static(b"state0")),
            ("op1".into(), 1, false, Bytes::from_static(b"state1")),
        ];
        let entries = checkpointer
            .save_partitions_concurrent(&id, inputs)
            .await
            .unwrap();

        let operator_entry = entries.into_iter().fold(
            OperatorSnapshotEntry {
                partitions: Vec::new(),
                total_bytes: 0,
            },
            |mut acc, (_, part)| {
                acc.total_bytes += part.size_bytes;
                acc.partitions.push(part);
                acc
            },
        );

        let mut manifest = CheckpointManifestV2::new(id, 5);
        manifest.operators.insert("op1".into(), operator_entry);

        checkpointer.save_manifest(&manifest).await.unwrap();
        checkpointer.update_latest(&id).await.unwrap();

        let reloaded = checkpointer.load_manifest(&id).await.unwrap();
        assert_eq!(reloaded.operators.len(), 1);
        assert_eq!(reloaded.operators["op1"].partitions.len(), 2);
    }
}