1use std::time::{Duration, Instant};
27
28use super::barrier::CheckpointBarrier;
29
/// Settings that control when a checkpoint may fall back from barrier
/// alignment to an unaligned snapshot.
#[derive(Debug, Clone)]
pub struct UnalignedCheckpointConfig {
    /// Whether unaligned checkpointing is enabled at all.
    pub enabled: bool,
    /// How long barrier alignment may take before the checkpointer falls back
    /// to an unaligned snapshot; compared against the time elapsed since the
    /// first barrier of the checkpoint arrived.
    pub alignment_timeout_threshold: Duration,
    /// Upper bound on the total accounted size of buffered in-flight events;
    /// `buffer_event` refuses events that would push the total past it.
    pub max_inflight_buffer_bytes: usize,
    /// When set, the alignment attempt is skipped and the first barrier of a
    /// checkpoint triggers an unaligned snapshot right away (sinks never do).
    pub force_unaligned: bool,
}
42
43impl Default for UnalignedCheckpointConfig {
44 fn default() -> Self {
45 Self {
46 enabled: true,
47 alignment_timeout_threshold: Duration::from_secs(10),
48 max_inflight_buffer_bytes: 256 * 1024 * 1024,
49 force_unaligned: false,
50 }
51 }
52}
53
/// In-flight data captured for one input channel as part of an unaligned
/// snapshot.
#[derive(Debug, Clone)]
pub struct InFlightChannelData {
    /// Index of the input channel the data belongs to.
    pub input_id: usize,
    /// Captured events as byte buffers, presumably one serialized event per
    /// entry. The checkpointer in this file emits this list empty.
    pub events: Vec<Vec<u8>>,
    /// Size in bytes attributed to `events`. The checkpointer in this file
    /// emits `0` here.
    pub size_bytes: usize,
}
64
/// Result of a checkpoint taken without waiting for every input's barrier.
#[derive(Debug)]
pub struct UnalignedSnapshot {
    /// The checkpoint barrier, with the `UNALIGNED` flag added by the
    /// checkpointer.
    pub barrier: CheckpointBarrier,
    /// Serialized operator state. The checkpointer leaves this as `None`;
    /// presumably the caller fills it in — TODO confirm.
    pub operator_state: Option<Vec<u8>>,
    /// One entry per input channel that had not yet delivered its barrier and
    /// had buffered events when the snapshot was taken.
    pub inflight_data: Vec<InFlightChannelData>,
    /// Total accounted size of all buffered events at snapshot time.
    pub total_size_bytes: usize,
    /// Intended to mark snapshots caused by the alignment timeout, as opposed
    /// to forced unaligned mode.
    // NOTE(review): confirm that producers set this accordingly.
    pub was_threshold_triggered: bool,
}
79
/// Instruction handed back to the caller of the checkpointer.
///
/// `T` is the event type carried by the operator's inputs.
#[derive(Debug)]
pub enum UnalignedAction<T> {
    /// Carries an event; presumably "forward it downstream". Not constructed
    /// by the code visible in this file — TODO confirm semantics with callers.
    Forward(T),
    /// No snapshot is due yet; returned by `on_barrier` while barriers are
    /// still outstanding.
    Buffer,
    /// Barrier alignment completed; take a regular snapshot for this barrier.
    AlignedSnapshot(CheckpointBarrier),
    /// Alignment was abandoned; the payload is the unaligned snapshot.
    UnalignedSnapshot(UnalignedSnapshot),
    /// Carries an event; presumably "drain this buffered event". Not
    /// constructed by the code visible in this file — TODO confirm.
    Drain(T),
    /// Carries an `i64`, presumably a watermark timestamp to pass through.
    /// Not constructed by the code visible in this file — TODO confirm.
    WatermarkPassThrough(i64),
}
96
/// Internal phase of the checkpointer's state machine.
#[derive(Debug, Clone, PartialEq, Eq)]
enum State {
    /// No checkpoint is in progress.
    Idle,
    /// Barriers are being collected; alignment may still be abandoned in
    /// favour of an unaligned snapshot once the timeout elapses.
    AligningWithFallback {
        /// When the first barrier of this checkpoint arrived.
        started_at: Instant,
        /// Bitmask of inputs that delivered their barrier (bit `i` = input `i`).
        aligned_inputs: u128,
        /// The barrier recorded when alignment started.
        barrier: CheckpointBarrier,
    },
    /// An unaligned snapshot was already emitted; the remaining inputs still
    /// owe their barrier before the checkpointer returns to `Idle`.
    WaitingForLateBarriers {
        /// Bitmask of inputs whose barrier has arrived (bit `i` = input `i`).
        received_inputs: u128,
        /// The barrier of the checkpoint being completed.
        barrier: CheckpointBarrier,
    },
}
119
/// Tracks checkpoint barriers across the inputs of one operator and decides
/// between an aligned snapshot and an unaligned fallback.
///
/// `T` is the type of the events buffered while a checkpoint is in progress.
pub struct UnalignedCheckpointer<T> {
    /// Number of input channels feeding the operator.
    num_inputs: usize,
    /// Timeout, buffer cap and mode switches.
    config: UnalignedCheckpointConfig,
    /// Current phase of the checkpoint state machine.
    state: State,
    /// Buffered events, one list per input, indexed by input id.
    inflight_buffers: Vec<Vec<T>>,
    /// Running total of the sizes reported for the buffered events.
    buffered_bytes: usize,
    /// Sinks never fall back to unaligned snapshots.
    is_sink: bool,
}
142
143impl<T> UnalignedCheckpointer<T> {
144 #[must_use]
152 pub fn new(num_inputs: usize, config: UnalignedCheckpointConfig, is_sink: bool) -> Self {
153 let inflight_buffers = (0..num_inputs).map(|_| Vec::new()).collect();
154 Self {
155 num_inputs,
156 config,
157 state: State::Idle,
158 inflight_buffers,
159 buffered_bytes: 0,
160 is_sink,
161 }
162 }
163
164 #[must_use]
166 pub fn aligned_count(&self) -> usize {
167 match &self.state {
168 State::AligningWithFallback { aligned_inputs, .. } => {
169 aligned_inputs.count_ones() as usize
170 }
171 State::WaitingForLateBarriers {
172 received_inputs, ..
173 } => received_inputs.count_ones() as usize,
174 State::Idle => 0,
175 }
176 }
177
178 #[must_use]
180 pub fn is_checkpointing(&self) -> bool {
181 !matches!(self.state, State::Idle)
182 }
183
184 #[must_use]
186 pub fn buffered_bytes(&self) -> usize {
187 self.buffered_bytes
188 }
189
190 pub fn on_barrier(&mut self, input_id: usize, barrier: CheckpointBarrier) -> UnalignedAction<T>
195 where
196 T: std::fmt::Debug,
197 {
198 match &self.state {
199 State::Idle => {
200 let mut aligned_inputs = 0u128;
202 aligned_inputs |= 1u128 << input_id;
203
204 if self.num_inputs == 1 || aligned_inputs.count_ones() as usize == self.num_inputs {
205 self.state = State::Idle;
207 self.clear_buffers();
208 return UnalignedAction::AlignedSnapshot(barrier);
209 }
210
211 if self.config.force_unaligned && !self.is_sink {
212 return self.trigger_unaligned(barrier, aligned_inputs);
214 }
215
216 self.state = State::AligningWithFallback {
217 started_at: Instant::now(),
218 aligned_inputs,
219 barrier,
220 };
221
222 UnalignedAction::Buffer
223 }
224 State::AligningWithFallback {
225 started_at,
226 aligned_inputs,
227 barrier: pending_barrier,
228 } => {
229 let mut aligned = *aligned_inputs;
230 aligned |= 1u128 << input_id;
231 let started = *started_at;
232 let pending = *pending_barrier;
233
234 if aligned.count_ones() as usize == self.num_inputs {
235 self.state = State::Idle;
237 self.clear_buffers();
238 return UnalignedAction::AlignedSnapshot(pending);
239 }
240
241 if started.elapsed() >= self.config.alignment_timeout_threshold && !self.is_sink {
243 return self.trigger_unaligned(pending, aligned);
244 }
245
246 self.state = State::AligningWithFallback {
247 started_at: started,
248 aligned_inputs: aligned,
249 barrier: pending,
250 };
251
252 UnalignedAction::Buffer
253 }
254 State::WaitingForLateBarriers {
255 received_inputs,
256 barrier: pending_barrier,
257 } => {
258 let mut received = *received_inputs;
259 received |= 1u128 << input_id;
260 let pending = *pending_barrier;
261
262 if received.count_ones() as usize == self.num_inputs {
263 self.state = State::Idle;
265 self.clear_buffers();
266 } else {
267 self.state = State::WaitingForLateBarriers {
268 received_inputs: received,
269 barrier: pending,
270 };
271 }
272
273 UnalignedAction::Buffer
275 }
276 }
277 }
278
279 pub fn check_timeout(&mut self) -> Option<UnalignedAction<T>>
284 where
285 T: std::fmt::Debug,
286 {
287 if self.is_sink {
288 return None;
289 }
290
291 match &self.state {
292 State::AligningWithFallback {
293 started_at,
294 aligned_inputs,
295 barrier,
296 } => {
297 if started_at.elapsed() >= self.config.alignment_timeout_threshold {
298 let barrier = *barrier;
299 let aligned = *aligned_inputs;
300 Some(self.trigger_unaligned(barrier, aligned))
301 } else {
302 None
303 }
304 }
305 _ => None,
306 }
307 }
308
309 fn trigger_unaligned(
311 &mut self,
312 barrier: CheckpointBarrier,
313 aligned_inputs: u128,
314 ) -> UnalignedAction<T>
315 where
316 T: std::fmt::Debug,
317 {
318 if self.buffered_bytes > self.config.max_inflight_buffer_bytes {
320 self.state = State::Idle;
322 self.clear_buffers();
323 return UnalignedAction::AlignedSnapshot(barrier);
325 }
326
327 let mut inflight_data = Vec::new();
328 for input_id in 0..self.num_inputs {
329 if aligned_inputs & (1u128 << input_id) == 0 {
330 let events = std::mem::take(&mut self.inflight_buffers[input_id]);
332 if !events.is_empty() {
333 inflight_data.push(InFlightChannelData {
334 input_id,
335 events: Vec::new(), size_bytes: 0,
337 });
338 let _ = events; }
343 }
344 }
345
346 let total_size = self.buffered_bytes;
347 let unaligned_barrier = CheckpointBarrier {
348 checkpoint_id: barrier.checkpoint_id,
349 epoch: barrier.epoch,
350 flags: barrier.flags | super::barrier::flags::UNALIGNED,
351 };
352
353 let snapshot = UnalignedSnapshot {
354 barrier: unaligned_barrier,
355 operator_state: None, inflight_data,
357 total_size_bytes: total_size,
358 was_threshold_triggered: true,
359 };
360
361 self.state = State::WaitingForLateBarriers {
362 received_inputs: aligned_inputs,
363 barrier,
364 };
365 self.buffered_bytes = 0;
366
367 UnalignedAction::UnalignedSnapshot(snapshot)
368 }
369
370 pub fn buffer_event(&mut self, input_id: usize, event: T, size_bytes: usize) -> bool {
375 if self.buffered_bytes + size_bytes > self.config.max_inflight_buffer_bytes {
376 return false;
377 }
378 if input_id < self.inflight_buffers.len() {
379 self.inflight_buffers[input_id].push(event);
380 self.buffered_bytes += size_bytes;
381 }
382 true
383 }
384
385 fn clear_buffers(&mut self) {
387 for buf in &mut self.inflight_buffers {
388 buf.clear();
389 }
390 self.buffered_bytes = 0;
391 }
392}
393
394impl<T: std::fmt::Debug> std::fmt::Debug for UnalignedCheckpointer<T> {
395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 f.debug_struct("UnalignedCheckpointer")
397 .field("num_inputs", &self.num_inputs)
398 .field("state", &self.state)
399 .field("buffered_bytes", &self.buffered_bytes)
400 .field("is_sink", &self.is_sink)
401 .finish_non_exhaustive()
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint::barrier::flags;

    /// Baseline configuration shared by the tests: 100 ms timeout, 1 MiB cap.
    fn default_config() -> UnalignedCheckpointConfig {
        UnalignedCheckpointConfig {
            enabled: true,
            alignment_timeout_threshold: Duration::from_millis(100),
            max_inflight_buffer_bytes: 1024 * 1024,
            force_unaligned: false,
        }
    }

    /// Baseline configuration with a 1 ms alignment timeout.
    fn short_timeout_config() -> UnalignedCheckpointConfig {
        UnalignedCheckpointConfig {
            alignment_timeout_threshold: Duration::from_millis(1),
            ..default_config()
        }
    }

    /// Checkpointer over `String` events.
    fn checkpointer(
        num_inputs: usize,
        config: UnalignedCheckpointConfig,
        is_sink: bool,
    ) -> UnalignedCheckpointer<String> {
        UnalignedCheckpointer::new(num_inputs, config, is_sink)
    }

    /// The barrier used throughout: checkpoint 1, epoch 1.
    fn barrier_one() -> CheckpointBarrier {
        CheckpointBarrier::new(1, 1)
    }

    /// Sleeps long enough for the 1 ms alignment timeout to elapse.
    fn let_timeout_elapse() {
        std::thread::sleep(Duration::from_millis(5));
    }

    /// Unwraps an action into its unaligned snapshot or fails the test.
    fn expect_unaligned(action: UnalignedAction<String>) -> UnalignedSnapshot {
        match action {
            UnalignedAction::UnalignedSnapshot(snap) => snap,
            other => panic!("expected UnalignedSnapshot, got {other:?}"),
        }
    }

    /// Like [`expect_unaligned`], for the optional result of `check_timeout`.
    fn expect_timeout_snapshot(action: Option<UnalignedAction<String>>) -> UnalignedSnapshot {
        match action {
            Some(inner) => expect_unaligned(inner),
            None => panic!("assertion failed: action.is_some()"),
        }
    }

    #[test]
    fn test_aligned_fast_path() {
        let mut ckpt = checkpointer(2, default_config(), false);
        let barrier = barrier_one();

        // First barrier only starts alignment.
        let first = ckpt.on_barrier(0, barrier);
        assert!(matches!(first, UnalignedAction::Buffer));
        assert!(ckpt.is_checkpointing());

        // Second barrier completes it.
        let second = ckpt.on_barrier(1, barrier);
        assert!(matches!(second, UnalignedAction::AlignedSnapshot(b) if b.checkpoint_id == 1));
        assert!(!ckpt.is_checkpointing());
    }

    #[test]
    fn test_single_input_immediate_aligned() {
        let mut ckpt = checkpointer(1, default_config(), false);

        let action = ckpt.on_barrier(0, barrier_one());

        assert!(matches!(action, UnalignedAction::AlignedSnapshot(_)));
        assert!(!ckpt.is_checkpointing());
    }

    #[test]
    fn test_timeout_triggers_unaligned() {
        let mut ckpt = checkpointer(3, short_timeout_config(), false);

        let action = ckpt.on_barrier(0, barrier_one());
        assert!(matches!(action, UnalignedAction::Buffer));

        let_timeout_elapse();

        let snap = expect_timeout_snapshot(ckpt.check_timeout());
        assert!(snap.barrier.is_unaligned());
        assert!(snap.was_threshold_triggered);
    }

    #[test]
    fn test_inflight_capture() {
        let mut ckpt = checkpointer(2, short_timeout_config(), false);

        ckpt.on_barrier(0, barrier_one());

        // Input 1 has not delivered its barrier yet; its events are in flight.
        assert!(ckpt.buffer_event(1, String::from("event-1"), 7));
        assert!(ckpt.buffer_event(1, String::from("event-2"), 7));
        assert_eq!(ckpt.buffered_bytes(), 14);

        let_timeout_elapse();

        let snap = expect_timeout_snapshot(ckpt.check_timeout());
        assert!(!snap.inflight_data.is_empty());
        assert_eq!(snap.inflight_data[0].input_id, 1);
    }

    #[test]
    fn test_max_buffer_exceeded() {
        let config = UnalignedCheckpointConfig {
            max_inflight_buffer_bytes: 10,
            ..default_config()
        };
        let mut ckpt = checkpointer(2, config, false);

        ckpt.on_barrier(0, barrier_one());

        // Exactly at the cap is still accepted; one byte more is refused.
        assert!(ckpt.buffer_event(1, String::from("12345"), 5));
        assert!(ckpt.buffer_event(1, String::from("12345"), 5));
        assert!(!ckpt.buffer_event(1, String::from("x"), 1));
    }

    #[test]
    fn test_force_unaligned_mode() {
        let config = UnalignedCheckpointConfig {
            force_unaligned: true,
            ..default_config()
        };
        let mut ckpt = checkpointer(3, config, false);

        let snap = expect_unaligned(ckpt.on_barrier(0, barrier_one()));

        assert!(snap.barrier.is_unaligned());
    }

    #[test]
    fn test_sink_cannot_use_unaligned() {
        let config = UnalignedCheckpointConfig {
            force_unaligned: true,
            ..short_timeout_config()
        };
        let mut ckpt = checkpointer(2, config, true);

        // Forced mode is ignored for sinks.
        let action = ckpt.on_barrier(0, barrier_one());
        assert!(matches!(action, UnalignedAction::Buffer));

        // The timeout is ignored as well.
        let_timeout_elapse();
        assert!(ckpt.check_timeout().is_none());
    }

    #[test]
    fn test_late_barriers_complete_cycle() {
        let mut ckpt = checkpointer(3, short_timeout_config(), false);
        let barrier = barrier_one();

        ckpt.on_barrier(0, barrier);

        let_timeout_elapse();
        let fallback = ckpt.check_timeout().unwrap();
        assert!(matches!(fallback, UnalignedAction::UnalignedSnapshot(_)));

        // Two inputs still owe their barrier.
        assert!(ckpt.is_checkpointing());

        let after_second = ckpt.on_barrier(1, barrier);
        assert!(matches!(after_second, UnalignedAction::Buffer));
        assert!(ckpt.is_checkpointing());

        let after_third = ckpt.on_barrier(2, barrier);
        assert!(matches!(after_third, UnalignedAction::Buffer));
        assert!(!ckpt.is_checkpointing());
    }

    #[test]
    fn test_unaligned_flag_set_on_barrier() {
        let barrier = CheckpointBarrier {
            checkpoint_id: 1,
            epoch: 1,
            flags: flags::NONE | flags::UNALIGNED,
        };

        assert!(barrier.is_unaligned());
        assert!(!barrier.is_full_snapshot());
        assert!(!barrier.is_drain());
    }
}