// laminar_connectors/kafka/rebalance.rs
1//! Kafka consumer group rebalance state tracking.
2//!
3//! [`RebalanceState`] tracks which topic-partitions are currently
4//! assigned to this consumer and counts rebalance events.
5//!
6//! [`LaminarConsumerContext`] is an rdkafka `ConsumerContext` that
7//! signals a checkpoint request on partition revocation, enabling
8//! the pipeline to persist offsets before ownership changes.
9
10use std::collections::HashSet;
11use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
12use std::sync::{Arc, Mutex};
13
14use rdkafka::consumer::ConsumerContext;
15use rdkafka::ClientContext;
16use tracing::{info, warn};
17
/// Tracks partition assignments across consumer group rebalances.
#[derive(Debug, Clone, Default)]
pub struct RebalanceState {
    /// Currently assigned (topic, partition) pairs.
    assigned: HashSet<(String, i32)>,
    /// Total number of rebalance events.
    rebalance_count: u64,
}

impl RebalanceState {
    /// Starts with no partitions assigned.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Handles a partition assignment event.
    ///
    /// Replaces the current assignment set and increments the rebalance counter.
    pub fn on_assign(&mut self, partitions: &[(String, i32)]) {
        // Reuse the existing allocation instead of rebuilding the set.
        self.assigned.clear();
        self.assigned.extend(partitions.iter().cloned());
        self.rebalance_count += 1;
    }

    /// Handles a partition revocation event.
    ///
    /// Removes the specified partitions from the assignment set.
    pub fn on_revoke(&mut self, partitions: &[(String, i32)]) {
        for pair in partitions {
            // `remove` takes a borrowed key; no need to clone the topic
            // just to build a lookup tuple.
            self.assigned.remove(pair);
        }
    }

    /// Returns the set of currently assigned partitions.
    #[must_use]
    pub fn assigned_partitions(&self) -> &HashSet<(String, i32)> {
        &self.assigned
    }

    /// Returns the total number of rebalance events.
    #[must_use]
    pub fn rebalance_count(&self) -> u64 {
        self.rebalance_count
    }

    /// Returns `true` if the given topic-partition is currently assigned.
    #[must_use]
    pub fn is_assigned(&self, topic: &str, partition: i32) -> bool {
        // Tuple keys of (String, i32) cannot be borrowed as (&str, i32),
        // so an owned key is built for the O(1) hash lookup.
        self.assigned.contains(&(topic.to_string(), partition))
    }
}
72
/// rdkafka consumer context that signals a checkpoint on partition revocation.
///
/// When a consumer group rebalance revokes partitions from this consumer,
/// the context notifies the pipeline coordinator to trigger an immediate
/// checkpoint before the partitions are reassigned. This prevents offset
/// loss during rebalance.
///
/// Rebalance callbacks run on rdkafka's background thread, so all shared
/// state uses `Arc` + atomic types for thread safety.
pub struct LaminarConsumerContext {
    /// Shared flag stored `true` (with `Release`) on each Revoke event;
    /// the coordinator is expected to observe and clear it to trigger a
    /// checkpoint.
    checkpoint_requested: Arc<AtomicBool>,
    /// Context-local count of rebalance events (both Assign and Revoke).
    rebalance_count: AtomicU64,
    /// Shared rebalance state updated on Assign/Revoke events.
    rebalance_state: Arc<Mutex<RebalanceState>>,
    /// Shared rebalance event counter for source-level metrics.
    rebalance_metric: Arc<AtomicU64>,
    /// Monotonically increasing generation bumped on each Revoke event.
    ///
    /// Allows lock-free detection of revoke events from the hot path
    /// (`poll_batch`) — the source compares its cached generation against
    /// this value using `Relaxed` ordering, and only locks the mutex when
    /// a change is detected.
    revoke_generation: Arc<AtomicU64>,
}
97
98impl LaminarConsumerContext {
99    /// Wires checkpoint signaling, partition tracking, and rebalance metrics.
100    #[must_use]
101    pub fn new(
102        checkpoint_requested: Arc<AtomicBool>,
103        rebalance_state: Arc<Mutex<RebalanceState>>,
104        rebalance_metric: Arc<AtomicU64>,
105        revoke_generation: Arc<AtomicU64>,
106    ) -> Self {
107        Self {
108            checkpoint_requested,
109            rebalance_count: AtomicU64::new(0),
110            rebalance_state,
111            rebalance_metric,
112            revoke_generation,
113        }
114    }
115
116    /// Total rebalance events observed.
117    #[must_use]
118    pub fn rebalance_count(&self) -> u64 {
119        self.rebalance_count.load(Ordering::Relaxed)
120    }
121
122    /// Returns the shared revoke generation counter.
123    #[must_use]
124    pub fn revoke_generation(&self) -> &Arc<AtomicU64> {
125        &self.revoke_generation
126    }
127}
128
// Marker impl: no client-level callbacks are overridden; all custom
// behavior lives in the `ConsumerContext` impl.
impl ClientContext for LaminarConsumerContext {}
130
131impl ConsumerContext for LaminarConsumerContext {
132    fn pre_rebalance(
133        &self,
134        _base_consumer: &rdkafka::consumer::BaseConsumer<Self>,
135        rebalance: &rdkafka::consumer::Rebalance<'_>,
136    ) {
137        use rdkafka::consumer::Rebalance;
138
139        match rebalance {
140            Rebalance::Revoke(tpl) => {
141                let count = tpl.count();
142                info!(
143                    partitions_revoked = count,
144                    "kafka rebalance: partitions being revoked, requesting checkpoint"
145                );
146                // Update shared rebalance state.
147                let partitions: Vec<(String, i32)> = tpl
148                    .elements()
149                    .iter()
150                    .map(|e| (e.topic().to_string(), e.partition()))
151                    .collect();
152                match self.rebalance_state.lock() {
153                    Ok(mut state) => state.on_revoke(&partitions),
154                    Err(poisoned) => {
155                        warn!("rebalance_state mutex poisoned, recovering");
156                        poisoned.into_inner().on_revoke(&partitions);
157                    }
158                }
159                self.revoke_generation
160                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
161                self.rebalance_count
162                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
163                self.rebalance_metric
164                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
165                self.checkpoint_requested.store(true, Ordering::Release);
166            }
167            Rebalance::Assign(tpl) => {
168                let count = tpl.count();
169                info!(
170                    partitions_assigned = count,
171                    "kafka rebalance: new partitions assigned"
172                );
173                // Update shared rebalance state.
174                let partitions: Vec<(String, i32)> = tpl
175                    .elements()
176                    .iter()
177                    .map(|e| (e.topic().to_string(), e.partition()))
178                    .collect();
179                match self.rebalance_state.lock() {
180                    Ok(mut state) => state.on_assign(&partitions),
181                    Err(poisoned) => {
182                        warn!("rebalance_state mutex poisoned, recovering");
183                        poisoned.into_inner().on_assign(&partitions);
184                    }
185                }
186                self.rebalance_count
187                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
188                self.rebalance_metric
189                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
190            }
191            Rebalance::Error(msg) => {
192                warn!(error = %msg, "kafka rebalance error");
193            }
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_assign() {
        let mut state = RebalanceState::new();
        let parts: Vec<(String, i32)> =
            (0..3).map(|p| ("events".to_string(), p)).collect();
        state.on_assign(&parts);

        assert_eq!(state.assigned_partitions().len(), 3);
        for p in 0..3 {
            assert!(state.is_assigned("events", p));
        }
        assert!(!state.is_assigned("events", 3));
        assert_eq!(state.rebalance_count(), 1);
    }

    #[test]
    fn test_revoke() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".to_string(), 0), ("events".to_string(), 1)]);
        state.on_revoke(&[("events".to_string(), 1)]);

        assert!(state.is_assigned("events", 0));
        assert!(!state.is_assigned("events", 1));
        assert_eq!(state.assigned_partitions().len(), 1);
    }

    #[test]
    fn test_reassign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".to_string(), 0), ("events".to_string(), 1)]);
        // A fresh assignment fully replaces the previous set.
        state.on_assign(&[("events".to_string(), 2), ("events".to_string(), 3)]);

        assert!(!state.is_assigned("events", 0));
        assert!(state.is_assigned("events", 2));
        assert_eq!(state.assigned_partitions().len(), 2);
        assert_eq!(state.rebalance_count(), 2);
    }

    #[test]
    fn test_empty_state() {
        let state = RebalanceState::new();
        assert!(state.assigned_partitions().is_empty());
        assert_eq!(state.rebalance_count(), 0);
        assert!(!state.is_assigned("events", 0));
    }

    /// Builds a context along with the shared checkpoint flag it signals.
    fn make_context() -> (Arc<AtomicBool>, LaminarConsumerContext) {
        let flag = Arc::new(AtomicBool::new(false));
        let ctx = LaminarConsumerContext::new(
            Arc::clone(&flag),
            Arc::new(Mutex::new(RebalanceState::new())),
            Arc::new(AtomicU64::new(0)),
            Arc::new(AtomicU64::new(0)),
        );
        (flag, ctx)
    }

    #[test]
    fn test_consumer_context_initial_state() {
        let (flag, ctx) = make_context();
        assert_eq!(ctx.rebalance_count(), 0);
        assert!(!flag.load(Ordering::Relaxed));
    }

    #[test]
    fn test_consumer_context_shared_flag() {
        let (flag, _ctx) = make_context();

        // Simulate what pre_rebalance(Revoke) does.
        flag.store(true, Ordering::Relaxed);
        assert!(flag.load(Ordering::Relaxed));

        // Coordinator would swap-clear the flag.
        assert!(flag.swap(false, Ordering::Relaxed));
        assert!(!flag.load(Ordering::Relaxed));
    }
}
279}