laminar_connectors/kafka/rebalance.rs

//! Kafka consumer group rebalance state tracking.
//!
//! [`RebalanceState`] tracks which topic-partitions are currently
//! assigned to this consumer and counts rebalance events.
//!
//! [`LaminarConsumerContext`] is an rdkafka `ConsumerContext` that
//! signals a checkpoint request on partition revocation, enabling
//! the pipeline to persist offsets before ownership changes.
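//!
//! A minimal sketch of the tracking half in isolation (topic name and
//! partition numbers are illustrative):
//!
//! ```ignore
//! let mut state = RebalanceState::new();
//! state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
//! state.on_revoke(&[("events".into(), 1)]);
//! assert!(state.is_assigned("events", 0));
//! assert!(!state.is_assigned("events", 1));
//! ```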

use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};

use prometheus::IntCounter;
use rdkafka::consumer::{Consumer, ConsumerContext};
use rdkafka::ClientContext;
use tracing::{debug, info, warn};

/// Tracks partition assignments across consumer group rebalances.
#[derive(Debug, Clone, Default)]
pub struct RebalanceState {
    /// Currently assigned (topic, partition) pairs.
    assigned: HashSet<(String, i32)>,
    /// Total number of rebalance events.
    rebalance_count: u64,
}

impl RebalanceState {
    /// Starts with no partitions assigned.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Handles a partition assignment event.
    ///
    /// Additive: inserts new partitions without clearing existing ones.
    /// This is correct for both eager and cooperative rebalance protocols:
    /// - Eager: the preceding `on_revoke(all)` already clears the set.
    /// - Cooperative: `Assign` only contains newly assigned partitions,
    ///   so clearing would lose existing assignments.
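    ///
    /// A small illustration of the additive behavior under a cooperative
    /// incremental assign (topic name is illustrative):
    ///
    /// ```ignore
    /// let mut state = RebalanceState::new();
    /// state.on_assign(&[("events".into(), 0)]);
    /// // A later cooperative Assign carries only the new partition.
    /// state.on_assign(&[("events".into(), 1)]);
    /// assert!(state.is_assigned("events", 0) && state.is_assigned("events", 1));
    /// ```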
    pub fn on_assign(&mut self, partitions: &[(String, i32)]) {
        for (topic, partition) in partitions {
            self.assigned.insert((topic.clone(), *partition));
        }
        self.rebalance_count += 1;
    }

    /// Handles a partition revocation event.
    ///
    /// Removes the specified partitions from the assignment set.
    pub fn on_revoke(&mut self, partitions: &[(String, i32)]) {
        for (topic, partition) in partitions {
            self.assigned.remove(&(topic.clone(), *partition));
        }
    }

    /// Returns the set of currently assigned partitions.
    #[must_use]
    pub fn assigned_partitions(&self) -> &HashSet<(String, i32)> {
        &self.assigned
    }

    /// Returns the total number of rebalance events.
    #[must_use]
    pub fn rebalance_count(&self) -> u64 {
        self.rebalance_count
    }

    /// Returns `true` if the given topic-partition is currently assigned.
    #[must_use]
    pub fn is_assigned(&self, topic: &str, partition: i32) -> bool {
        self.assigned.contains(&(topic.to_string(), partition))
    }
}

/// rdkafka consumer context that signals a checkpoint on partition revocation.
///
/// When a consumer group rebalance revokes partitions from this consumer,
/// the context notifies the pipeline coordinator to trigger an immediate
/// checkpoint before the partitions are reassigned. This prevents offset
/// loss during rebalance.
///
/// Rebalance callbacks run on rdkafka's background thread, so all shared
/// state uses `Arc` + atomic types for thread safety.
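///
/// A sketch of the coordinator-side handshake this enables; the flag is
/// the `checkpoint_requested` handle passed to [`Self::new`], and
/// `trigger_checkpoint` stands in for a hypothetical coordinator hook:
///
/// ```ignore
/// // Swap-clear so one revocation triggers exactly one checkpoint.
/// if checkpoint_requested.swap(false, Ordering::Acquire) {
///     trigger_checkpoint();
/// }
/// ```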
pub struct LaminarConsumerContext {
    checkpoint_requested: Arc<AtomicBool>,
    rebalance_count: AtomicU64,
    /// Shared rebalance state updated on Assign/Revoke events.
    rebalance_state: Arc<Mutex<RebalanceState>>,
    /// Shared rebalance event counter for source-level metrics.
    rebalance_metric: Arc<AtomicU64>,
    /// Monotonically increasing generation bumped on each Revoke event.
    ///
    /// Allows lock-free detection of revoke events from the hot path
    /// (`poll_batch`): the source compares its cached generation against
    /// this value using `Relaxed` ordering, and only locks the mutex when
    /// a change is detected.
    revoke_generation: Arc<AtomicU64>,
    /// Shared flag indicating whether the reader task has paused Kafka
    /// partitions for backpressure. On `Assign`, newly assigned partitions
    /// must be re-paused if this flag is true.
    reader_paused: Arc<AtomicBool>,
    /// Set by `commit_callback` on broker rejection; the reader task
    /// escalates to `CommitMode::Sync` on the next timer tick.
    commit_retry_needed: Arc<AtomicBool>,
    /// Snapshot of consumed offsets, updated once per `poll_batch()` cycle.
    /// Read on Assign to seek newly assigned partitions to last-consumed
    /// offset + 1, preventing duplicates after broker failures.
    offset_snapshot: Arc<Mutex<super::offsets::OffsetTracker>>,
    /// Counter bumped on every broker-confirmed async commit. The immediate
    /// return from `CommitMode::Async` only means "queued"; the real
    /// outcome arrives here via `commit_callback`, so this is the
    /// authoritative success counter for async commits.
    commits_counter: IntCounter,
    /// Counter bumped when the broker rejects an async commit.
    commit_failures_counter: IntCounter,
}

impl LaminarConsumerContext {
    /// Wires checkpoint signaling, partition tracking, and rebalance metrics.
    #[must_use]
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        checkpoint_requested: Arc<AtomicBool>,
        rebalance_state: Arc<Mutex<RebalanceState>>,
        rebalance_metric: Arc<AtomicU64>,
        revoke_generation: Arc<AtomicU64>,
        reader_paused: Arc<AtomicBool>,
        commit_retry_needed: Arc<AtomicBool>,
        offset_snapshot: Arc<Mutex<super::offsets::OffsetTracker>>,
        commits_counter: IntCounter,
        commit_failures_counter: IntCounter,
    ) -> Self {
        Self {
            checkpoint_requested,
            rebalance_count: AtomicU64::new(0),
            rebalance_state,
            rebalance_metric,
            revoke_generation,
            reader_paused,
            commit_retry_needed,
            offset_snapshot,
            commits_counter,
            commit_failures_counter,
        }
    }

    /// Total rebalance events observed.
    #[must_use]
    pub fn rebalance_count(&self) -> u64 {
        self.rebalance_count.load(Ordering::Relaxed)
    }

    /// Returns the shared revoke generation counter.
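    ///
    /// A sketch of the lock-free hot-path check described on the field
    /// (`cached_gen` stands in for a hypothetical local the source keeps
    /// between `poll_batch` calls):
    ///
    /// ```ignore
    /// let gen = ctx.revoke_generation().load(Ordering::Relaxed);
    /// if gen != cached_gen {
    ///     cached_gen = gen;
    ///     // Only now take the mutex to read the updated assignment set.
    ///     let assigned = rebalance_state.lock().unwrap().assigned_partitions().clone();
    /// }
    /// ```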
    #[must_use]
    pub fn revoke_generation(&self) -> &Arc<AtomicU64> {
        &self.revoke_generation
    }

    /// Locks the rebalance state, recovering from poison.
    fn lock_rebalance_state(&self) -> std::sync::MutexGuard<'_, RebalanceState> {
        self.rebalance_state.lock().unwrap_or_else(|poisoned| {
            warn!("rebalance_state mutex poisoned, recovering");
            poisoned.into_inner()
        })
    }

    /// Locks the offset snapshot, recovering from poison.
    fn lock_offset_snapshot(&self) -> std::sync::MutexGuard<'_, super::offsets::OffsetTracker> {
        self.offset_snapshot.lock().unwrap_or_else(|poisoned| {
            warn!("offset_snapshot mutex poisoned, recovering");
            poisoned.into_inner()
        })
    }
}

impl ClientContext for LaminarConsumerContext {}

impl ConsumerContext for LaminarConsumerContext {
    fn pre_rebalance(
        &self,
        _base_consumer: &rdkafka::consumer::BaseConsumer<Self>,
        rebalance: &rdkafka::consumer::Rebalance<'_>,
    ) {
        use rdkafka::consumer::Rebalance;

        match rebalance {
            Rebalance::Revoke(tpl) => {
                let count = tpl.count();
                info!(
                    partitions_revoked = count,
                    "kafka rebalance: partitions being revoked, requesting checkpoint"
                );
                // Update shared rebalance state.
                let partitions: Vec<(String, i32)> = tpl
                    .elements()
                    .iter()
                    .map(|e| (e.topic().to_string(), e.partition()))
                    .collect();
                self.lock_rebalance_state().on_revoke(&partitions);
                self.revoke_generation.fetch_add(1, Ordering::Release);
                self.rebalance_count.fetch_add(1, Ordering::Relaxed);
                self.rebalance_metric.fetch_add(1, Ordering::Relaxed);
                self.checkpoint_requested.store(true, Ordering::Release);
            }
            Rebalance::Assign(tpl) => {
                let count = tpl.count();
                info!(
                    partitions_assigned = count,
                    "kafka rebalance: new partitions assigned"
                );
                // Update shared rebalance state.
                let partitions: Vec<(String, i32)> = tpl
                    .elements()
                    .iter()
                    .map(|e| (e.topic().to_string(), e.partition()))
                    .collect();
                self.lock_rebalance_state().on_assign(&partitions);
                self.rebalance_count.fetch_add(1, Ordering::Relaxed);
                self.rebalance_metric.fetch_add(1, Ordering::Relaxed);
            }
            Rebalance::Error(msg) => {
                warn!(error = %msg, "kafka rebalance error");
            }
        }
    }

    fn commit_callback(
        &self,
        result: rdkafka::error::KafkaResult<()>,
        offsets: &rdkafka::TopicPartitionList,
    ) {
        match result {
            Ok(()) => {
                self.commits_counter.inc();
                debug!(
                    partition_count = offsets.count(),
                    "broker offset commit confirmed"
                );
            }
            Err(e) => {
                self.commit_failures_counter.inc();
                self.commit_retry_needed.store(true, Ordering::Release);
                warn!(
                    error = %e,
                    partition_count = offsets.count(),
                    "broker offset commit failed; scheduling sync retry"
                );
            }
        }
    }
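
    // A sketch of the retry path this callback arms, as it might look on
    // the reader task's timer tick (the flag handle and `offsets_tpl` are
    // hypothetical; `commit` and `CommitMode::Sync` are rdkafka APIs):
    //
    //     if commit_retry_needed.swap(false, Ordering::AcqRel) {
    //         // Escalate to a blocking commit so failures surface here.
    //         if let Err(e) = consumer.commit(&offsets_tpl, CommitMode::Sync) {
    //             warn!(error = %e, "sync offset commit retry failed");
    //         }
    //     }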

    fn post_rebalance(
        &self,
        base_consumer: &rdkafka::consumer::BaseConsumer<Self>,
        rebalance: &rdkafka::consumer::Rebalance<'_>,
    ) {
        use rdkafka::consumer::Rebalance;

        if let Rebalance::Assign(tpl) = rebalance {
            // Seek assigned partitions to tracked offsets so we don't fall
            // back to broker-stored group offsets (which may be stale or
            // reset to earliest after a broker failure).
            //
            // Uses seek_partitions() instead of assign() to avoid clobbering
            // the partition set under cooperative rebalancing, where the tpl
            // contains only NEWLY assigned partitions (not the full set).
            let assigned: Vec<(String, i32)> = tpl
                .elements()
                .iter()
                .map(|e| (e.topic().to_string(), e.partition()))
                .collect();

            let seek_tpl = self.lock_offset_snapshot().to_seek_tpl(&assigned);

            if seek_tpl.count() > 0 {
                // Non-zero timeout: Duration::ZERO returns "In Progress" for
                // every partition.
                match base_consumer.seek_partitions(seek_tpl, std::time::Duration::from_secs(10)) {
                    Ok(result) => {
                        let errors: Vec<_> = result
                            .elements()
                            .iter()
                            .filter(|e| e.error().is_err())
                            .map(|e| format!("{}[{}]: {:?}", e.topic(), e.partition(), e.error()))
                            .collect();
                        if errors.is_empty() {
                            info!(
                                partition_count = result.count(),
                                "seeked assigned partitions to tracked offsets"
                            );
                        } else {
                            warn!(?errors, "some partitions failed to seek to tracked offsets");
                        }
                    }
                    Err(e) => warn!(
                        error = %e,
                        "failed to seek assigned partitions to tracked offsets"
                    ),
                }
            }

            // Re-pause newly assigned partitions if backpressure is active.
            if self.reader_paused.load(Ordering::Acquire) {
                if let Err(e) = base_consumer.pause(tpl) {
                    warn!(error = %e, "failed to re-pause newly assigned partitions");
                } else {
                    info!(
                        partition_count = tpl.count(),
                        "re-paused newly assigned partitions (reader backpressure active)"
                    );
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_assign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[
            ("events".into(), 0),
            ("events".into(), 1),
            ("events".into(), 2),
        ]);

        assert_eq!(state.assigned_partitions().len(), 3);
        assert!(state.is_assigned("events", 0));
        assert!(state.is_assigned("events", 1));
        assert!(state.is_assigned("events", 2));
        assert!(!state.is_assigned("events", 3));
        assert_eq!(state.rebalance_count(), 1);
    }

    #[test]
    fn test_revoke() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        state.on_revoke(&[("events".into(), 1)]);

        assert_eq!(state.assigned_partitions().len(), 1);
        assert!(state.is_assigned("events", 0));
        assert!(!state.is_assigned("events", 1));
    }

    #[test]
    fn test_eager_reassign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        // Eager rebalance: revoke all first, then assign the new set.
        state.on_revoke(&[("events".into(), 0), ("events".into(), 1)]);
        state.on_assign(&[("events".into(), 2), ("events".into(), 3)]);

        assert_eq!(state.assigned_partitions().len(), 2);
        assert!(!state.is_assigned("events", 0));
        assert!(state.is_assigned("events", 2));
        assert_eq!(state.rebalance_count(), 2);
    }

    #[test]
    fn test_cooperative_assign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        // Cooperative: revoke only a subset, then assign a new subset.
        state.on_revoke(&[("events".into(), 1)]);
        state.on_assign(&[("events".into(), 2)]);

        assert_eq!(state.assigned_partitions().len(), 2);
        assert!(state.is_assigned("events", 0)); // retained
        assert!(!state.is_assigned("events", 1)); // revoked
        assert!(state.is_assigned("events", 2)); // newly assigned
    }

    #[test]
    fn test_empty_state() {
        let state = RebalanceState::new();
        assert_eq!(state.assigned_partitions().len(), 0);
        assert_eq!(state.rebalance_count(), 0);
        assert!(!state.is_assigned("events", 0));
    }

    fn make_context() -> (Arc<AtomicBool>, LaminarConsumerContext) {
        let flag = Arc::new(AtomicBool::new(false));
        let state = Arc::new(Mutex::new(RebalanceState::new()));
        let metric = Arc::new(AtomicU64::new(0));
        let revoke_gen = Arc::new(AtomicU64::new(0));
        let reader_paused = Arc::new(AtomicBool::new(false));
        let commit_retry = Arc::new(AtomicBool::new(false));
        let offset_snapshot = Arc::new(Mutex::new(super::super::offsets::OffsetTracker::new()));
        let commits = IntCounter::new("test_commits", "test").unwrap();
        let commit_failures = IntCounter::new("test_commit_failures", "test").unwrap();
        let ctx = LaminarConsumerContext::new(
            Arc::clone(&flag),
            state,
            metric,
            revoke_gen,
            reader_paused,
            commit_retry,
            offset_snapshot,
            commits,
            commit_failures,
        );
        (flag, ctx)
    }

    #[test]
    fn test_consumer_context_initial_state() {
        let (flag, ctx) = make_context();
        assert!(!flag.load(Ordering::Relaxed));
        assert_eq!(ctx.rebalance_count(), 0);
    }

    #[test]
    fn test_consumer_context_shared_flag() {
        let (flag, _ctx) = make_context();

        // Simulate what pre_rebalance(Revoke) does.
        flag.store(true, Ordering::Relaxed);
        assert!(flag.load(Ordering::Relaxed));

        // Coordinator would swap-clear the flag.
        assert!(flag.swap(false, Ordering::Relaxed));
        assert!(!flag.load(Ordering::Relaxed));
    }
}