laminar_connectors/kafka/
rebalance.rs1use std::collections::HashSet;
11use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
12use std::sync::{Arc, Mutex};
13
14use prometheus::IntCounter;
15use rdkafka::consumer::{Consumer, ConsumerContext};
16use rdkafka::ClientContext;
17use tracing::{info, warn};
18
/// Bookkeeping for the Kafka `(topic, partition)` pairs currently assigned to
/// this consumer.
///
/// Assignments and revocations are applied incrementally, so the same type
/// serves both eager (revoke everything, then assign everything) and
/// cooperative (incremental) rebalance protocols.
#[derive(Debug, Clone, Default)]
pub struct RebalanceState {
    /// Currently assigned `(topic, partition)` pairs.
    assigned: HashSet<(String, i32)>,
    /// Number of times [`RebalanceState::on_assign`] has been called.
    rebalance_count: u64,
}

impl RebalanceState {
    /// Creates an empty state: no partitions assigned and a rebalance count of zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records newly assigned partitions and counts one rebalance.
    ///
    /// Partitions that are already assigned are left as-is (set semantics);
    /// the rebalance count is incremented on every call regardless.
    pub fn on_assign(&mut self, partitions: &[(String, i32)]) {
        self.assigned.extend(partitions.iter().cloned());
        self.rebalance_count += 1;
    }

    /// Removes revoked partitions from the assignment.
    ///
    /// Partitions that are not currently assigned are ignored. Revocation does
    /// not change the rebalance count.
    pub fn on_revoke(&mut self, partitions: &[(String, i32)]) {
        // Each slice element is already a `&(String, i32)` key, so the lookup
        // needs no temporary owned key (and no `String` clone).
        for key in partitions {
            self.assigned.remove(key);
        }
    }

    /// Returns the set of currently assigned `(topic, partition)` pairs.
    #[must_use]
    pub fn assigned_partitions(&self) -> &HashSet<(String, i32)> {
        &self.assigned
    }

    /// Returns how many times [`RebalanceState::on_assign`] has been called.
    #[must_use]
    pub fn rebalance_count(&self) -> u64 {
        self.rebalance_count
    }

    /// Returns `true` if `partition` of `topic` is currently assigned.
    #[must_use]
    pub fn is_assigned(&self, topic: &str, partition: i32) -> bool {
        // The set is keyed by owned tuples, so an owned key is built for lookup.
        self.assigned.contains(&(topic.to_owned(), partition))
    }
}
76
77pub struct LaminarConsumerContext {
87 checkpoint_requested: Arc<AtomicBool>,
88 rebalance_count: AtomicU64,
89 rebalance_state: Arc<Mutex<RebalanceState>>,
91 rebalance_metric: Arc<AtomicU64>,
93 revoke_generation: Arc<AtomicU64>,
100 reader_paused: Arc<AtomicBool>,
104 offset_snapshot: Arc<Mutex<super::offsets::OffsetTracker>>,
108 commits_counter: IntCounter,
112 commit_failures_counter: IntCounter,
114}
115
116impl LaminarConsumerContext {
117 #[must_use]
119 #[allow(clippy::too_many_arguments)]
120 pub fn new(
121 checkpoint_requested: Arc<AtomicBool>,
122 rebalance_state: Arc<Mutex<RebalanceState>>,
123 rebalance_metric: Arc<AtomicU64>,
124 revoke_generation: Arc<AtomicU64>,
125 reader_paused: Arc<AtomicBool>,
126 offset_snapshot: Arc<Mutex<super::offsets::OffsetTracker>>,
127 commits_counter: IntCounter,
128 commit_failures_counter: IntCounter,
129 ) -> Self {
130 Self {
131 checkpoint_requested,
132 rebalance_count: AtomicU64::new(0),
133 rebalance_state,
134 rebalance_metric,
135 revoke_generation,
136 reader_paused,
137 offset_snapshot,
138 commits_counter,
139 commit_failures_counter,
140 }
141 }
142
143 #[must_use]
145 pub fn rebalance_count(&self) -> u64 {
146 self.rebalance_count.load(Ordering::Relaxed)
147 }
148
149 #[must_use]
151 pub fn revoke_generation(&self) -> &Arc<AtomicU64> {
152 &self.revoke_generation
153 }
154
155 fn lock_rebalance_state(&self) -> std::sync::MutexGuard<'_, RebalanceState> {
157 self.rebalance_state.lock().unwrap_or_else(|poisoned| {
158 warn!("rebalance_state mutex poisoned, recovering");
159 poisoned.into_inner()
160 })
161 }
162
163 fn lock_offset_snapshot(&self) -> std::sync::MutexGuard<'_, super::offsets::OffsetTracker> {
165 self.offset_snapshot.lock().unwrap_or_else(|poisoned| {
166 warn!("offset_snapshot mutex poisoned, recovering");
167 poisoned.into_inner()
168 })
169 }
170}
171
// No client-level callbacks are overridden: every `ClientContext` method keeps
// rdkafka's default implementation.
impl ClientContext for LaminarConsumerContext {}
173
174impl ConsumerContext for LaminarConsumerContext {
175 fn pre_rebalance(
176 &self,
177 _base_consumer: &rdkafka::consumer::BaseConsumer<Self>,
178 rebalance: &rdkafka::consumer::Rebalance<'_>,
179 ) {
180 use rdkafka::consumer::Rebalance;
181
182 match rebalance {
183 Rebalance::Revoke(tpl) => {
184 let count = tpl.count();
185 info!(
186 partitions_revoked = count,
187 "kafka rebalance: partitions being revoked, requesting checkpoint"
188 );
189 let partitions: Vec<(String, i32)> = tpl
191 .elements()
192 .iter()
193 .map(|e| (e.topic().to_string(), e.partition()))
194 .collect();
195 self.lock_rebalance_state().on_revoke(&partitions);
196 self.revoke_generation
197 .fetch_add(1, std::sync::atomic::Ordering::Release);
198 self.rebalance_count
199 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
200 self.rebalance_metric
201 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
202 self.checkpoint_requested.store(true, Ordering::Release);
203 }
204 Rebalance::Assign(tpl) => {
205 let count = tpl.count();
206 info!(
207 partitions_assigned = count,
208 "kafka rebalance: new partitions assigned"
209 );
210 let partitions: Vec<(String, i32)> = tpl
212 .elements()
213 .iter()
214 .map(|e| (e.topic().to_string(), e.partition()))
215 .collect();
216 self.lock_rebalance_state().on_assign(&partitions);
217 self.rebalance_count
218 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
219 self.rebalance_metric
220 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
221 }
222 Rebalance::Error(msg) => {
223 warn!(error = %msg, "kafka rebalance error");
224 }
225 }
226 }
227
228 fn commit_callback(
229 &self,
230 result: rdkafka::error::KafkaResult<()>,
231 offsets: &rdkafka::TopicPartitionList,
232 ) {
233 match result {
234 Ok(()) => {
235 self.commits_counter.inc();
236 tracing::debug!(
237 partition_count = offsets.count(),
238 "broker offset commit confirmed"
239 );
240 }
241 Err(e) => {
242 self.commit_failures_counter.inc();
243 warn!(
244 error = %e,
245 partition_count = offsets.count(),
246 "broker offset commit failed (callback)"
247 );
248 }
249 }
250 }
251
252 fn post_rebalance(
253 &self,
254 base_consumer: &rdkafka::consumer::BaseConsumer<Self>,
255 rebalance: &rdkafka::consumer::Rebalance<'_>,
256 ) {
257 use rdkafka::consumer::Rebalance;
258
259 if let Rebalance::Assign(tpl) = rebalance {
260 let assigned: Vec<(String, i32)> = tpl
268 .elements()
269 .iter()
270 .map(|e| (e.topic().to_string(), e.partition()))
271 .collect();
272
273 let seek_tpl = self.lock_offset_snapshot().to_seek_tpl(&assigned);
274
275 if seek_tpl.count() > 0 {
276 match base_consumer.seek_partitions(seek_tpl, std::time::Duration::from_secs(10)) {
278 Ok(result) => {
279 let errors: Vec<_> = result
280 .elements()
281 .iter()
282 .filter(|e| e.error().is_err())
283 .map(|e| format!("{}[{}]: {:?}", e.topic(), e.partition(), e.error()))
284 .collect();
285 if errors.is_empty() {
286 info!(
287 partition_count = result.count(),
288 "seeked assigned partitions to tracked offsets"
289 );
290 } else {
291 warn!(?errors, "some partitions failed to seek to tracked offsets");
292 }
293 }
294 Err(e) => warn!(
295 error = %e,
296 "failed to seek assigned partitions to tracked offsets"
297 ),
298 }
299 }
300
301 if self.reader_paused.load(Ordering::Acquire) {
303 if let Err(e) = base_consumer.pause(tpl) {
304 warn!(error = %e, "failed to re-pause newly assigned partitions");
305 } else {
306 info!(
307 partition_count = tpl.count(),
308 "re-paused newly assigned partitions (reader backpressure active)"
309 );
310 }
311 }
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_assign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[
            ("events".into(), 0),
            ("events".into(), 1),
            ("events".into(), 2),
        ]);

        assert_eq!(state.assigned_partitions().len(), 3);
        assert!(state.is_assigned("events", 0));
        assert!(state.is_assigned("events", 1));
        assert!(state.is_assigned("events", 2));
        assert!(!state.is_assigned("events", 3));
        assert_eq!(state.rebalance_count(), 1);
    }

    #[test]
    fn test_revoke() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        state.on_revoke(&[("events".into(), 1)]);

        assert_eq!(state.assigned_partitions().len(), 1);
        assert!(state.is_assigned("events", 0));
        assert!(!state.is_assigned("events", 1));
    }

    #[test]
    fn test_eager_reassign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        // Eager protocol: the whole assignment is revoked first...
        state.on_revoke(&[("events".into(), 0), ("events".into(), 1)]);
        // ...then a fresh assignment is handed out.
        state.on_assign(&[("events".into(), 2), ("events".into(), 3)]);

        assert_eq!(state.assigned_partitions().len(), 2);
        assert!(!state.is_assigned("events", 0));
        assert!(state.is_assigned("events", 2));
        assert_eq!(state.rebalance_count(), 2);
    }

    #[test]
    fn test_cooperative_assign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        // Cooperative protocol: only the moved partition is revoked, and the
        // new partition is added incrementally.
        state.on_revoke(&[("events".into(), 1)]);
        state.on_assign(&[("events".into(), 2)]);

        assert_eq!(state.assigned_partitions().len(), 2);
        assert!(state.is_assigned("events", 0));
        assert!(!state.is_assigned("events", 1));
        assert!(state.is_assigned("events", 2));
    }

    #[test]
    fn test_empty_state() {
        let state = RebalanceState::new();
        assert_eq!(state.assigned_partitions().len(), 0);
        assert_eq!(state.rebalance_count(), 0);
        assert!(!state.is_assigned("events", 0));
    }

    #[test]
    fn test_repeated_assign_is_idempotent_but_counted() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0)]);
        state.on_assign(&[("events".into(), 0)]);

        assert_eq!(state.assigned_partitions().len(), 1);
        assert!(state.is_assigned("events", 0));
        assert_eq!(state.rebalance_count(), 2);
    }

    #[test]
    fn test_revoke_unassigned_is_noop() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0)]);
        state.on_revoke(&[("other".into(), 0), ("events".into(), 5)]);

        assert_eq!(state.assigned_partitions().len(), 1);
        assert!(state.is_assigned("events", 0));
        // Revocation never changes the rebalance count.
        assert_eq!(state.rebalance_count(), 1);
    }

    #[test]
    fn test_same_partition_different_topics() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("audit".into(), 0)]);
        state.on_revoke(&[("audit".into(), 0)]);

        assert!(state.is_assigned("events", 0));
        assert!(!state.is_assigned("audit", 0));
    }

    fn make_context() -> (Arc<AtomicBool>, LaminarConsumerContext) {
        let flag = Arc::new(AtomicBool::new(false));
        let state = Arc::new(Mutex::new(RebalanceState::new()));
        let metric = Arc::new(AtomicU64::new(0));
        let revoke_gen = Arc::new(AtomicU64::new(0));
        let reader_paused = Arc::new(AtomicBool::new(false));
        let offset_snapshot = Arc::new(Mutex::new(super::super::offsets::OffsetTracker::new()));
        let commits = IntCounter::new("test_commits", "test").unwrap();
        let commit_failures = IntCounter::new("test_commit_failures", "test").unwrap();
        let ctx = LaminarConsumerContext::new(
            Arc::clone(&flag),
            state,
            metric,
            revoke_gen,
            reader_paused,
            offset_snapshot,
            commits,
            commit_failures,
        );
        (flag, ctx)
    }

    #[test]
    fn test_consumer_context_initial_state() {
        let (flag, ctx) = make_context();
        assert!(!flag.load(Ordering::Relaxed));
        assert_eq!(ctx.rebalance_count(), 0);
    }

    #[test]
    fn test_consumer_context_revoke_generation_starts_at_zero() {
        let (_flag, ctx) = make_context();
        assert_eq!(ctx.revoke_generation().load(Ordering::Acquire), 0);
    }

    #[test]
    fn test_consumer_context_shared_flag() {
        let (flag, _ctx) = make_context();

        // The caller's handle and the context's handle are the same flag.
        flag.store(true, Ordering::Relaxed);
        assert!(flag.load(Ordering::Relaxed));

        // Consuming the request clears it.
        assert!(flag.swap(false, Ordering::Relaxed));
        assert!(!flag.load(Ordering::Relaxed));
    }
}