// laminar_connectors/kafka/rebalance.rs
use std::collections::HashSet;
11use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
12use std::sync::{Arc, Mutex};
13
14use prometheus::IntCounter;
15use rdkafka::consumer::{Consumer, ConsumerContext};
16use rdkafka::ClientContext;
17use tracing::{info, warn};
18
/// Bookkeeping for the consumer's current partition assignment.
///
/// `assigned` holds every `(topic, partition)` pair this consumer currently
/// owns; `rebalance_count` counts how many assign events have been applied.
#[derive(Debug, Clone, Default)]
pub struct RebalanceState {
    assigned: HashSet<(String, i32)>,
    rebalance_count: u64,
}

impl RebalanceState {
    /// Returns an empty state: no partitions assigned, zero rebalances seen.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a batch of newly assigned partitions and bumps the rebalance
    /// counter once for the whole batch.
    pub fn on_assign(&mut self, partitions: &[(String, i32)]) {
        self.assigned.extend(partitions.iter().cloned());
        self.rebalance_count += 1;
    }

    /// Drops revoked partitions from the assignment set. The rebalance
    /// counter is only advanced by assign events, not revokes.
    pub fn on_revoke(&mut self, partitions: &[(String, i32)]) {
        for revoked in partitions {
            self.assigned.remove(revoked);
        }
    }

    /// Borrows the full set of currently assigned `(topic, partition)` pairs.
    #[must_use]
    pub fn assigned_partitions(&self) -> &HashSet<(String, i32)> {
        &self.assigned
    }

    /// Number of assign events applied so far.
    #[must_use]
    pub fn rebalance_count(&self) -> u64 {
        self.rebalance_count
    }

    /// Whether `topic`/`partition` is currently part of the assignment.
    #[must_use]
    pub fn is_assigned(&self, topic: &str, partition: i32) -> bool {
        self.assigned.contains(&(topic.to_owned(), partition))
    }
}
76
/// Custom rdkafka consumer context that bridges librdkafka callbacks
/// (rebalance, offset-commit) into state shared with the reader task.
pub struct LaminarConsumerContext {
    /// Set to `true` when partitions are revoked so the reader checkpoints
    /// before ownership moves.
    checkpoint_requested: Arc<AtomicBool>,
    /// Count of rebalance callbacks (revoke and assign) seen by this context.
    rebalance_count: AtomicU64,
    /// Shared view of which partitions are currently assigned.
    rebalance_state: Arc<Mutex<RebalanceState>>,
    /// Externally shared rebalance counter, bumped alongside `rebalance_count`.
    rebalance_metric: Arc<AtomicU64>,
    /// Generation counter incremented on every partition revoke.
    revoke_generation: Arc<AtomicU64>,
    /// True while reader backpressure is active; newly assigned partitions
    /// are re-paused in `post_rebalance` when set.
    reader_paused: Arc<AtomicBool>,
    /// Set when a broker offset commit fails so a sync retry is scheduled.
    commit_retry_needed: Arc<AtomicBool>,
    /// Tracked offsets used to seek newly assigned partitions after a rebalance.
    offset_snapshot: Arc<Mutex<super::offsets::OffsetTracker>>,
    /// Counts confirmed broker offset commits.
    commits_counter: IntCounter,
    /// Counts failed broker offset commits.
    commit_failures_counter: IntCounter,
}
119
120impl LaminarConsumerContext {
121 #[must_use]
123 #[allow(clippy::too_many_arguments)]
124 pub fn new(
125 checkpoint_requested: Arc<AtomicBool>,
126 rebalance_state: Arc<Mutex<RebalanceState>>,
127 rebalance_metric: Arc<AtomicU64>,
128 revoke_generation: Arc<AtomicU64>,
129 reader_paused: Arc<AtomicBool>,
130 commit_retry_needed: Arc<AtomicBool>,
131 offset_snapshot: Arc<Mutex<super::offsets::OffsetTracker>>,
132 commits_counter: IntCounter,
133 commit_failures_counter: IntCounter,
134 ) -> Self {
135 Self {
136 checkpoint_requested,
137 rebalance_count: AtomicU64::new(0),
138 rebalance_state,
139 rebalance_metric,
140 revoke_generation,
141 reader_paused,
142 commit_retry_needed,
143 offset_snapshot,
144 commits_counter,
145 commit_failures_counter,
146 }
147 }
148
149 #[must_use]
151 pub fn rebalance_count(&self) -> u64 {
152 self.rebalance_count.load(Ordering::Relaxed)
153 }
154
155 #[must_use]
157 pub fn revoke_generation(&self) -> &Arc<AtomicU64> {
158 &self.revoke_generation
159 }
160
161 fn lock_rebalance_state(&self) -> std::sync::MutexGuard<'_, RebalanceState> {
163 self.rebalance_state.lock().unwrap_or_else(|poisoned| {
164 warn!("rebalance_state mutex poisoned, recovering");
165 poisoned.into_inner()
166 })
167 }
168
169 fn lock_offset_snapshot(&self) -> std::sync::MutexGuard<'_, super::offsets::OffsetTracker> {
171 self.offset_snapshot.lock().unwrap_or_else(|poisoned| {
172 warn!("offset_snapshot mutex poisoned, recovering");
173 poisoned.into_inner()
174 })
175 }
176}
177
// Marker impl: the default ClientContext hooks are sufficient; all custom
// behavior lives in the ConsumerContext impl below.
impl ClientContext for LaminarConsumerContext {}
179
impl ConsumerContext for LaminarConsumerContext {
    /// Called by rdkafka before a rebalance takes effect.
    ///
    /// Revoke: records the lost partitions, advances the revoke generation,
    /// bumps both rebalance counters, and raises `checkpoint_requested` so
    /// the reader persists offsets before ownership changes.
    /// Assign: records the new partitions and bumps the counters.
    /// Error: logged only.
    fn pre_rebalance(
        &self,
        _base_consumer: &rdkafka::consumer::BaseConsumer<Self>,
        rebalance: &rdkafka::consumer::Rebalance<'_>,
    ) {
        use rdkafka::consumer::Rebalance;

        match rebalance {
            Rebalance::Revoke(tpl) => {
                let count = tpl.count();
                info!(
                    partitions_revoked = count,
                    "kafka rebalance: partitions being revoked, requesting checkpoint"
                );
                let partitions: Vec<(String, i32)> = tpl
                    .elements()
                    .iter()
                    .map(|e| (e.topic().to_string(), e.partition()))
                    .collect();
                self.lock_rebalance_state().on_revoke(&partitions);
                // Release ordering so readers that Acquire the generation see
                // the state update above.
                self.revoke_generation
                    .fetch_add(1, std::sync::atomic::Ordering::Release);
                self.rebalance_count
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                self.rebalance_metric
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                // Raised last: ask the reader to checkpoint in-flight offsets.
                self.checkpoint_requested.store(true, Ordering::Release);
            }
            Rebalance::Assign(tpl) => {
                let count = tpl.count();
                info!(
                    partitions_assigned = count,
                    "kafka rebalance: new partitions assigned"
                );
                let partitions: Vec<(String, i32)> = tpl
                    .elements()
                    .iter()
                    .map(|e| (e.topic().to_string(), e.partition()))
                    .collect();
                self.lock_rebalance_state().on_assign(&partitions);
                self.rebalance_count
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                self.rebalance_metric
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }
            Rebalance::Error(msg) => {
                warn!(error = %msg, "kafka rebalance error");
            }
        }
    }

    /// Called by rdkafka after each broker offset commit attempt.
    ///
    /// Success bumps the commit counter; failure bumps the failure counter
    /// and flags `commit_retry_needed` so a sync retry is scheduled.
    fn commit_callback(
        &self,
        result: rdkafka::error::KafkaResult<()>,
        offsets: &rdkafka::TopicPartitionList,
    ) {
        match result {
            Ok(()) => {
                self.commits_counter.inc();
                tracing::debug!(
                    partition_count = offsets.count(),
                    "broker offset commit confirmed"
                );
            }
            Err(e) => {
                self.commit_failures_counter.inc();
                self.commit_retry_needed.store(true, Ordering::Release);
                warn!(
                    error = %e,
                    partition_count = offsets.count(),
                    "broker offset commit failed — scheduling sync retry"
                );
            }
        }
    }

    /// Called by rdkafka after a rebalance has been applied.
    ///
    /// For an assign: seeks the newly assigned partitions to the offsets
    /// recorded in `offset_snapshot`, then re-pauses them if reader
    /// backpressure (`reader_paused`) is active. Revoke/error need no
    /// post-processing here.
    fn post_rebalance(
        &self,
        base_consumer: &rdkafka::consumer::BaseConsumer<Self>,
        rebalance: &rdkafka::consumer::Rebalance<'_>,
    ) {
        use rdkafka::consumer::Rebalance;

        if let Rebalance::Assign(tpl) = rebalance {
            let assigned: Vec<(String, i32)> = tpl
                .elements()
                .iter()
                .map(|e| (e.topic().to_string(), e.partition()))
                .collect();

            // Build a seek list from tracked offsets for just these partitions.
            let seek_tpl = self.lock_offset_snapshot().to_seek_tpl(&assigned);

            if seek_tpl.count() > 0 {
                match base_consumer.seek_partitions(seek_tpl, std::time::Duration::from_secs(10)) {
                    Ok(result) => {
                        // seek_partitions reports per-partition results; collect
                        // any that failed so they can be logged together.
                        let errors: Vec<_> = result
                            .elements()
                            .iter()
                            .filter(|e| e.error().is_err())
                            .map(|e| format!("{}[{}]: {:?}", e.topic(), e.partition(), e.error()))
                            .collect();
                        if errors.is_empty() {
                            info!(
                                partition_count = result.count(),
                                "seeked assigned partitions to tracked offsets"
                            );
                        } else {
                            warn!(?errors, "some partitions failed to seek to tracked offsets");
                        }
                    }
                    Err(e) => warn!(
                        error = %e,
                        "failed to seek assigned partitions to tracked offsets"
                    ),
                }
            }

            // New assignments arrive un-paused; if the reader had paused for
            // backpressure, pause them again so data doesn't flow early.
            if self.reader_paused.load(Ordering::Acquire) {
                if let Err(e) = base_consumer.pause(tpl) {
                    warn!(error = %e, "failed to re-pause newly assigned partitions");
                } else {
                    info!(
                        partition_count = tpl.count(),
                        "re-paused newly assigned partitions (reader backpressure active)"
                    );
                }
            }
        }
    }
}
322
#[cfg(test)]
mod tests {
    use super::*;

    // Assigning a batch populates the set and bumps the counter once.
    #[test]
    fn test_assign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[
            ("events".into(), 0),
            ("events".into(), 1),
            ("events".into(), 2),
        ]);

        assert_eq!(state.assigned_partitions().len(), 3);
        assert!(state.is_assigned("events", 0));
        assert!(state.is_assigned("events", 1));
        assert!(state.is_assigned("events", 2));
        assert!(!state.is_assigned("events", 3));
        assert_eq!(state.rebalance_count(), 1);
    }

    // Revoking removes only the named partitions.
    #[test]
    fn test_revoke() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        state.on_revoke(&[("events".into(), 1)]);

        assert_eq!(state.assigned_partitions().len(), 1);
        assert!(state.is_assigned("events", 0));
        assert!(!state.is_assigned("events", 1));
    }

    // Eager protocol: full revoke followed by a full reassign.
    #[test]
    fn test_eager_reassign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        state.on_revoke(&[("events".into(), 0), ("events".into(), 1)]);
        state.on_assign(&[("events".into(), 2), ("events".into(), 3)]);

        assert_eq!(state.assigned_partitions().len(), 2);
        assert!(!state.is_assigned("events", 0));
        assert!(state.is_assigned("events", 2));
        assert_eq!(state.rebalance_count(), 2);
    }

    // Cooperative protocol: incremental revoke/assign keeps untouched
    // partitions in place.
    #[test]
    fn test_cooperative_assign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        state.on_revoke(&[("events".into(), 1)]);
        state.on_assign(&[("events".into(), 2)]);

        assert_eq!(state.assigned_partitions().len(), 2);
        assert!(state.is_assigned("events", 0));
        assert!(!state.is_assigned("events", 1));
        assert!(state.is_assigned("events", 2));
    }

    // A fresh state reports nothing assigned and a zero counter.
    #[test]
    fn test_empty_state() {
        let state = RebalanceState::new();
        assert_eq!(state.assigned_partitions().len(), 0);
        assert_eq!(state.rebalance_count(), 0);
        assert!(!state.is_assigned("events", 0));
    }

    // Builds a context plus the shared checkpoint flag so tests can assert
    // on both sides of the Arc.
    fn make_context() -> (Arc<AtomicBool>, LaminarConsumerContext) {
        let flag = Arc::new(AtomicBool::new(false));
        let state = Arc::new(Mutex::new(RebalanceState::new()));
        let metric = Arc::new(AtomicU64::new(0));
        let revoke_gen = Arc::new(AtomicU64::new(0));
        let reader_paused = Arc::new(AtomicBool::new(false));
        let commit_retry = Arc::new(AtomicBool::new(false));
        let offset_snapshot = Arc::new(Mutex::new(super::super::offsets::OffsetTracker::new()));
        let commits = IntCounter::new("test_commits", "test").unwrap();
        let commit_failures = IntCounter::new("test_commit_failures", "test").unwrap();
        let ctx = LaminarConsumerContext::new(
            Arc::clone(&flag),
            state,
            metric,
            revoke_gen,
            reader_paused,
            commit_retry,
            offset_snapshot,
            commits,
            commit_failures,
        );
        (flag, ctx)
    }

    #[test]
    fn test_consumer_context_initial_state() {
        let (flag, ctx) = make_context();
        assert!(!flag.load(Ordering::Relaxed));
        assert_eq!(ctx.rebalance_count(), 0);
    }

    // The checkpoint flag handed to the context is the same Arc the caller
    // keeps, so writes on either side are visible to the other.
    #[test]
    fn test_consumer_context_shared_flag() {
        let (flag, _ctx) = make_context();

        flag.store(true, Ordering::Relaxed);
        assert!(flag.load(Ordering::Relaxed));

        assert!(flag.swap(false, Ordering::Relaxed));
        assert!(!flag.load(Ordering::Relaxed));
    }
}