1use std::sync::atomic::{AtomicU64, Ordering};
21use std::sync::Arc;
22
23use object_store::{ObjectStore, PutOptions, PutPayload};
24
25use crate::checkpoint_store::CheckpointStoreError;
26
/// Magic bytes identifying an LZ4-compressed checkpoint batch object ("LCB1").
const BATCH_MAGIC: &[u8; 4] = b"LCB1";

/// Current on-disk format version written by `encode_batch`; v2 adds a
/// CRC32C of the compressed body (v1 had no checksum).
const BATCH_VERSION: u32 = 2;

/// Default buffered-bytes threshold (8 MiB) at which `should_flush` fires.
const DEFAULT_FLUSH_THRESHOLD: usize = 8 * 1024 * 1024;

/// Byte size of the v2 header: magic(4) + version(4) + entry count(4)
/// + uncompressed size(4) + CRC32C(4). v1 headers are 16 bytes (no CRC);
/// `decode_batch` handles that case explicitly.
const HEADER_SIZE: usize = 20;

/// A single key/value pair buffered in memory until the next batch flush.
struct BatchEntry {
    // Logical key for this entry within the batch (e.g. a partition path).
    key: String,
    // Serialized checkpoint payload stored under `key`.
    data: Vec<u8>,
}
44
/// Thread-safe counters describing batch-flush activity.
///
/// All counters are updated with `Ordering::Relaxed`: they are monotonic
/// statistics, not synchronization points.
#[derive(Debug)]
pub struct BatchMetrics {
    /// Number of batches successfully flushed to the object store.
    pub batches_flushed: AtomicU64,
    /// Total entries across all flushed batches.
    pub entries_flushed: AtomicU64,
    /// Sum of encoded body bytes before LZ4 compression.
    pub bytes_before_compression: AtomicU64,
    /// Sum of written payload bytes after compression (header included).
    pub bytes_after_compression: AtomicU64,
    /// Number of object-store `put` operations issued.
    pub put_count: AtomicU64,
}
59
60impl BatchMetrics {
61 #[must_use]
63 pub fn new() -> Self {
64 Self {
65 batches_flushed: AtomicU64::new(0),
66 entries_flushed: AtomicU64::new(0),
67 bytes_before_compression: AtomicU64::new(0),
68 bytes_after_compression: AtomicU64::new(0),
69 put_count: AtomicU64::new(0),
70 }
71 }
72
73 fn record_flush(&self, entries: u64, raw_bytes: u64, compressed_bytes: u64) {
75 self.batches_flushed.fetch_add(1, Ordering::Relaxed);
76 self.entries_flushed.fetch_add(entries, Ordering::Relaxed);
77 self.bytes_before_compression
78 .fetch_add(raw_bytes, Ordering::Relaxed);
79 self.bytes_after_compression
80 .fetch_add(compressed_bytes, Ordering::Relaxed);
81 self.put_count.fetch_add(1, Ordering::Relaxed);
82 }
83
84 #[must_use]
86 pub fn snapshot(&self) -> BatchMetricsSnapshot {
87 BatchMetricsSnapshot {
88 batches_flushed: self.batches_flushed.load(Ordering::Relaxed),
89 entries_flushed: self.entries_flushed.load(Ordering::Relaxed),
90 bytes_before_compression: self.bytes_before_compression.load(Ordering::Relaxed),
91 bytes_after_compression: self.bytes_after_compression.load(Ordering::Relaxed),
92 put_count: self.put_count.load(Ordering::Relaxed),
93 }
94 }
95}
96
impl Default for BatchMetrics {
    /// Equivalent to [`BatchMetrics::new`]: all counters start at zero.
    fn default() -> Self {
        Self::new()
    }
}
102
/// Plain-value copy of [`BatchMetrics`] taken at a single point in time.
#[derive(Debug, Clone, Copy)]
pub struct BatchMetricsSnapshot {
    /// Number of batches flushed.
    pub batches_flushed: u64,
    /// Total entries across all flushed batches.
    pub entries_flushed: u64,
    /// Encoded body bytes before LZ4 compression.
    pub bytes_before_compression: u64,
    /// Written payload bytes after compression (header included).
    pub bytes_after_compression: u64,
    /// Object-store `put` operations issued.
    pub put_count: u64,
}
117
/// Buffers checkpoint entries in memory and writes them to an object store
/// as one LZ4-compressed batch object per flush.
pub struct CheckpointBatcher {
    // Entries accumulated since the last flush.
    buffer: Vec<BatchEntry>,
    // Encoded size of `buffer` (keys + data + 8 bytes of prefixes per entry).
    buffer_size: usize,
    // Once `buffer_size` reaches this, `should_flush` returns true.
    flush_threshold: usize,
    // Destination store for batch objects.
    store: Arc<dyn ObjectStore>,
    // Prepended verbatim to "checkpoints/batch-*.lz4" when building object
    // paths; callers should include a trailing '/' if one is desired.
    prefix: String,
    // Dedicated current-thread runtime used to block on async store puts.
    rt: tokio::runtime::Runtime,
    // Flush statistics, shareable across threads via `metrics()`.
    metrics: Arc<BatchMetrics>,
}
133
134impl CheckpointBatcher {
135 pub fn new(
145 store: Arc<dyn ObjectStore>,
146 prefix: String,
147 flush_threshold: Option<usize>,
148 ) -> std::io::Result<Self> {
149 let rt = tokio::runtime::Builder::new_current_thread()
150 .enable_all()
151 .build()?;
152 Ok(Self {
153 buffer: Vec::new(),
154 buffer_size: 0,
155 flush_threshold: flush_threshold.unwrap_or(DEFAULT_FLUSH_THRESHOLD),
156 store,
157 prefix,
158 rt,
159 metrics: Arc::new(BatchMetrics::new()),
160 })
161 }
162
163 pub fn add(&mut self, key: String, data: Vec<u8>) {
169 self.buffer_size += key.len() + data.len() + 8; self.buffer.push(BatchEntry { key, data });
171 }
172
173 #[must_use]
175 pub fn should_flush(&self) -> bool {
176 self.buffer_size >= self.flush_threshold
177 }
178
179 #[must_use]
181 pub fn len(&self) -> usize {
182 self.buffer.len()
183 }
184
185 #[must_use]
187 pub fn is_empty(&self) -> bool {
188 self.buffer.is_empty()
189 }
190
191 #[must_use]
193 pub fn buffer_size(&self) -> usize {
194 self.buffer_size
195 }
196
197 #[must_use]
199 pub fn metrics(&self) -> &Arc<BatchMetrics> {
200 &self.metrics
201 }
202
203 pub fn flush(&mut self, epoch: u64) -> Result<(), CheckpointStoreError> {
214 if self.buffer.is_empty() {
215 return Ok(());
216 }
217
218 let (raw_size, payload) = encode_batch(&self.buffer);
219
220 let path = object_store::path::Path::from(format!(
221 "{}checkpoints/batch-{epoch:06}.lz4",
222 self.prefix
223 ));
224
225 let compressed_size = payload.content_length();
226
227 self.rt.block_on(async {
228 self.store
229 .put_opts(&path, payload, PutOptions::default())
230 .await
231 })?;
232
233 let entry_count = self.buffer.len() as u64;
234 self.metrics
235 .record_flush(entry_count, raw_size as u64, compressed_size as u64);
236
237 self.buffer.clear();
238 self.buffer_size = 0;
239
240 Ok(())
241 }
242}
243
244#[allow(clippy::cast_possible_truncation)] fn encode_batch(entries: &[BatchEntry]) -> (usize, PutPayload) {
249 let mut body = Vec::new();
251 for entry in entries {
252 body.extend_from_slice(&(entry.key.len() as u32).to_le_bytes());
253 body.extend_from_slice(entry.key.as_bytes());
254 body.extend_from_slice(&(entry.data.len() as u32).to_le_bytes());
255 body.extend_from_slice(&entry.data);
256 }
257
258 let uncompressed_size = body.len();
259 let compressed = lz4_flex::compress_prepend_size(&body);
260
261 let crc = crc32c::crc32c(&compressed);
263
264 let mut out = Vec::with_capacity(HEADER_SIZE + compressed.len());
266 out.extend_from_slice(BATCH_MAGIC);
267 out.extend_from_slice(&BATCH_VERSION.to_le_bytes());
268 out.extend_from_slice(&(entries.len() as u32).to_le_bytes());
269 out.extend_from_slice(&(uncompressed_size as u32).to_le_bytes());
270 out.extend_from_slice(&crc.to_le_bytes());
271 out.extend_from_slice(&compressed);
272
273 (
274 uncompressed_size,
275 PutPayload::from_bytes(bytes::Bytes::from(out)),
276 )
277}
278
279#[allow(clippy::cast_possible_truncation)] pub fn decode_batch(raw: &[u8]) -> Result<Vec<(String, Vec<u8>)>, CheckpointStoreError> {
286 if raw.len() < HEADER_SIZE {
287 return Err(CheckpointStoreError::Io(std::io::Error::new(
288 std::io::ErrorKind::InvalidData,
289 "batch too short for header",
290 )));
291 }
292
293 if &raw[..4] != BATCH_MAGIC {
294 return Err(CheckpointStoreError::Io(std::io::Error::new(
295 std::io::ErrorKind::InvalidData,
296 "invalid batch magic",
297 )));
298 }
299
300 let version = u32::from_le_bytes([raw[4], raw[5], raw[6], raw[7]]);
301
302 let (header_size, has_crc) = match version {
304 1 => (16, false), 2 => (20, true), _ => {
307 return Err(CheckpointStoreError::Io(std::io::Error::new(
308 std::io::ErrorKind::InvalidData,
309 format!("unsupported batch version {version}"),
310 )));
311 }
312 };
313
314 if raw.len() < header_size {
315 return Err(CheckpointStoreError::Io(std::io::Error::new(
316 std::io::ErrorKind::InvalidData,
317 format!("batch too short for v{version} header"),
318 )));
319 }
320
321 let entry_count = u32::from_le_bytes([raw[8], raw[9], raw[10], raw[11]]) as usize;
322 let _uncompressed_size = u32::from_le_bytes([raw[12], raw[13], raw[14], raw[15]]);
323
324 let compressed_body = &raw[header_size..];
325
326 if has_crc {
328 let expected_crc = u32::from_le_bytes([raw[16], raw[17], raw[18], raw[19]]);
329 let actual_crc = crc32c::crc32c(compressed_body);
330 if actual_crc != expected_crc {
331 return Err(CheckpointStoreError::Io(std::io::Error::new(
332 std::io::ErrorKind::InvalidData,
333 format!(
334 "batch CRC32C mismatch: expected {expected_crc:#010x}, \
335 actual {actual_crc:#010x}"
336 ),
337 )));
338 }
339 }
340
341 let body = lz4_flex::decompress_size_prepended(compressed_body).map_err(|e| {
342 CheckpointStoreError::Io(std::io::Error::new(
343 std::io::ErrorKind::InvalidData,
344 format!("LZ4 decompression failed: {e}"),
345 ))
346 })?;
347
348 let mut entries = Vec::with_capacity(entry_count);
349 let mut cursor = 0;
350
351 for _ in 0..entry_count {
352 if cursor + 4 > body.len() {
353 return Err(CheckpointStoreError::Io(std::io::Error::new(
354 std::io::ErrorKind::UnexpectedEof,
355 "truncated batch entry (key length)",
356 )));
357 }
358 let key_len = u32::from_le_bytes([
359 body[cursor],
360 body[cursor + 1],
361 body[cursor + 2],
362 body[cursor + 3],
363 ]) as usize;
364 cursor += 4;
365
366 if cursor + key_len > body.len() {
367 return Err(CheckpointStoreError::Io(std::io::Error::new(
368 std::io::ErrorKind::UnexpectedEof,
369 "truncated batch entry (key data)",
370 )));
371 }
372 let key = String::from_utf8_lossy(&body[cursor..cursor + key_len]).into_owned();
373 cursor += key_len;
374
375 if cursor + 4 > body.len() {
376 return Err(CheckpointStoreError::Io(std::io::Error::new(
377 std::io::ErrorKind::UnexpectedEof,
378 "truncated batch entry (data length)",
379 )));
380 }
381 let data_len = u32::from_le_bytes([
382 body[cursor],
383 body[cursor + 1],
384 body[cursor + 2],
385 body[cursor + 3],
386 ]) as usize;
387 cursor += 4;
388
389 if cursor + data_len > body.len() {
390 return Err(CheckpointStoreError::Io(std::io::Error::new(
391 std::io::ErrorKind::UnexpectedEof,
392 "truncated batch entry (data)",
393 )));
394 }
395 let data = body[cursor..cursor + data_len].to_vec();
396 cursor += data_len;
397
398 entries.push((key, data));
399 }
400
401 Ok(entries)
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use object_store::memory::InMemory;

    /// Creates a batcher over a fresh in-memory store with the given
    /// flush threshold.
    fn make_batcher(threshold: usize) -> (CheckpointBatcher, Arc<dyn ObjectStore>) {
        let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        let batcher =
            CheckpointBatcher::new(store.clone(), String::new(), Some(threshold)).unwrap();
        (batcher, store)
    }

    /// Reads back a stored object's full contents, panicking if absent.
    fn fetch_object(store: &Arc<dyn ObjectStore>, path: &str) -> bytes::Bytes {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let result = store
                .get_opts(
                    &object_store::path::Path::from(path),
                    object_store::GetOptions::default(),
                )
                .await
                .unwrap();
            result.bytes().await.unwrap()
        })
    }

    #[test]
    fn test_add_tracks_size() {
        let (mut batcher, _store) = make_batcher(1024);
        assert!(batcher.is_empty());
        assert_eq!(batcher.buffer_size(), 0);

        batcher.add("key1".into(), vec![0u8; 100]);
        assert_eq!(batcher.len(), 1);
        assert!(!batcher.is_empty());
        // 4-byte key + 100-byte value + 8 bytes of length prefixes.
        assert_eq!(batcher.buffer_size(), 112);
    }

    #[test]
    fn test_should_flush_at_threshold() {
        let (mut batcher, _store) = make_batcher(200);
        assert!(!batcher.should_flush());

        // One 109-byte entry stays under the 200-byte threshold...
        batcher.add("k".into(), vec![0u8; 100]);
        assert!(!batcher.should_flush());

        // ...and a second one crosses it.
        batcher.add("k".into(), vec![0u8; 100]);
        assert!(batcher.should_flush());
    }

    #[test]
    fn test_flush_empty_is_noop() {
        let (mut batcher, _store) = make_batcher(1024);
        batcher.flush(1).unwrap();
        let snap = batcher.metrics().snapshot();
        assert_eq!(snap.batches_flushed, 0);
        assert_eq!(snap.put_count, 0);
    }

    #[test]
    fn test_flush_writes_object() {
        let (mut batcher, store) = make_batcher(1024 * 1024);

        batcher.add("partition-0/agg".into(), vec![42u8; 256]);
        batcher.add("partition-1/agg".into(), vec![99u8; 128]);
        batcher.flush(7).unwrap();

        assert!(batcher.is_empty());
        assert_eq!(batcher.buffer_size(), 0);

        // Epoch 7 must appear zero-padded in the object name; fetch_object
        // panics (failing the test) if the object was not written.
        let _ = fetch_object(&store, "checkpoints/batch-000007.lz4");
    }

    #[test]
    fn test_lz4_roundtrip() {
        let (mut batcher, store) = make_batcher(1024 * 1024);

        let entries = vec![
            ("partition-0/state".to_string(), vec![1u8; 500]),
            ("partition-1/state".to_string(), vec![2u8; 300]),
            ("partition-2/agg".to_string(), vec![3u8; 200]),
        ];
        for (key, value) in &entries {
            batcher.add(key.clone(), value.clone());
        }
        batcher.flush(42).unwrap();

        let data = fetch_object(&store, "checkpoints/batch-000042.lz4");
        let decoded = decode_batch(&data).unwrap();

        // Order and contents must survive the encode/compress/decode cycle.
        assert_eq!(decoded, entries);
    }

    #[test]
    fn test_metrics_recorded_on_flush() {
        let (mut batcher, _store) = make_batcher(1024 * 1024);

        batcher.add("k1".into(), vec![0u8; 100]);
        batcher.add("k2".into(), vec![0u8; 200]);
        batcher.flush(1).unwrap();

        let snap = batcher.metrics().snapshot();
        assert_eq!(snap.batches_flushed, 1);
        assert_eq!(snap.entries_flushed, 2);
        assert_eq!(snap.put_count, 1);
        assert!(snap.bytes_before_compression > 0);
        assert!(snap.bytes_after_compression > 0);
    }

    #[test]
    fn test_metrics_accumulate_across_flushes() {
        let (mut batcher, _store) = make_batcher(1024 * 1024);

        batcher.add("k1".into(), vec![0u8; 100]);
        batcher.flush(1).unwrap();

        batcher.add("k2".into(), vec![0u8; 200]);
        batcher.add("k3".into(), vec![0u8; 50]);
        batcher.flush(2).unwrap();

        let snap = batcher.metrics().snapshot();
        assert_eq!(snap.batches_flushed, 2);
        assert_eq!(snap.entries_flushed, 3);
        assert_eq!(snap.put_count, 2);
    }

    #[test]
    fn test_compression_reduces_size() {
        let (mut batcher, _store) = make_batcher(1024 * 1024);

        // 10 KB of zeros is trivially compressible.
        batcher.add("big".into(), vec![0u8; 10_000]);
        batcher.flush(1).unwrap();

        let snap = batcher.metrics().snapshot();
        assert!(
            snap.bytes_after_compression < snap.bytes_before_compression,
            "compressed ({}) should be smaller than raw ({})",
            snap.bytes_after_compression,
            snap.bytes_before_compression
        );
    }

    #[test]
    fn test_decode_invalid_magic() {
        // 20 bytes: wrong magic followed by a zero-padded header remainder.
        let bad = b"XXXX\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
        let err = decode_batch(bad).unwrap_err();
        assert!(err.to_string().contains("invalid batch magic"));
    }

    #[test]
    fn test_decode_too_short() {
        let err = decode_batch(b"LCB").unwrap_err();
        assert!(err.to_string().contains("too short"));
    }

    #[test]
    fn test_decode_bad_version() {
        let mut buf = Vec::new();
        buf.extend_from_slice(b"LCB1");
        // Version 99 is unknown; pad out the rest of a 20-byte header.
        buf.extend_from_slice(&99u32.to_le_bytes());
        buf.extend_from_slice(&0u32.to_le_bytes());
        buf.extend_from_slice(&0u32.to_le_bytes());
        buf.extend_from_slice(&0u32.to_le_bytes());
        let err = decode_batch(&buf).unwrap_err();
        assert!(err.to_string().contains("unsupported batch version"));
    }

    #[test]
    fn test_flush_clears_buffer() {
        let (mut batcher, _store) = make_batcher(64);

        batcher.add("a".into(), vec![0u8; 50]);
        batcher.add("b".into(), vec![0u8; 50]);
        assert!(batcher.should_flush());

        batcher.flush(1).unwrap();
        assert!(!batcher.should_flush());
        assert!(batcher.is_empty());
        assert_eq!(batcher.buffer_size(), 0);
    }
}