laminar_connectors/kafka/
rebalance.rs1use std::collections::HashSet;
11use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
12use std::sync::{Arc, Mutex};
13
14use rdkafka::consumer::ConsumerContext;
15use rdkafka::ClientContext;
16use tracing::{info, warn};
17
/// Bookkeeping for the `(topic, partition)` pairs currently assigned to a
/// consumer, plus a counter of how many assignments have been applied.
#[derive(Debug, Clone, Default)]
pub struct RebalanceState {
    /// Currently tracked `(topic, partition)` pairs.
    assigned: HashSet<(String, i32)>,
    /// Number of times [`RebalanceState::on_assign`] has been called.
    rebalance_count: u64,
}

impl RebalanceState {
    /// Creates an empty state: no tracked partitions and a counter of zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the tracked assignment with `partitions` and increments the
    /// rebalance counter.
    ///
    /// Anything tracked before the call is discarded first, so after the call
    /// the set holds exactly the (de-duplicated) entries of `partitions`.
    pub fn on_assign(&mut self, partitions: &[(String, i32)]) {
        self.assigned.clear();
        self.assigned.extend(partitions.iter().cloned());
        self.rebalance_count += 1;
    }

    /// Stops tracking every entry of `partitions`.
    ///
    /// Entries that are not currently tracked are ignored. The rebalance
    /// counter is left unchanged.
    pub fn on_revoke(&mut self, partitions: &[(String, i32)]) {
        for entry in partitions {
            // The slice elements already have the set's key type, so they can
            // be used for the lookup directly; no need to clone the topic.
            self.assigned.remove(entry);
        }
    }

    /// Returns the set of currently tracked `(topic, partition)` pairs.
    #[must_use]
    pub fn assigned_partitions(&self) -> &HashSet<(String, i32)> {
        &self.assigned
    }

    /// Returns how many times an assignment has been applied via
    /// [`RebalanceState::on_assign`].
    #[must_use]
    pub fn rebalance_count(&self) -> u64 {
        self.rebalance_count
    }

    /// Returns `true` if `(topic, partition)` is currently tracked.
    #[must_use]
    pub fn is_assigned(&self, topic: &str, partition: i32) -> bool {
        // The set is keyed by owned `(String, i32)` tuples, so the probe key
        // has to own its topic string.
        self.assigned.contains(&(topic.to_owned(), partition))
    }
}
72
73pub struct LaminarConsumerContext {
83 checkpoint_requested: Arc<AtomicBool>,
84 rebalance_count: AtomicU64,
85 rebalance_state: Arc<Mutex<RebalanceState>>,
87 rebalance_metric: Arc<AtomicU64>,
89 revoke_generation: Arc<AtomicU64>,
96}
97
98impl LaminarConsumerContext {
99 #[must_use]
101 pub fn new(
102 checkpoint_requested: Arc<AtomicBool>,
103 rebalance_state: Arc<Mutex<RebalanceState>>,
104 rebalance_metric: Arc<AtomicU64>,
105 revoke_generation: Arc<AtomicU64>,
106 ) -> Self {
107 Self {
108 checkpoint_requested,
109 rebalance_count: AtomicU64::new(0),
110 rebalance_state,
111 rebalance_metric,
112 revoke_generation,
113 }
114 }
115
116 #[must_use]
118 pub fn rebalance_count(&self) -> u64 {
119 self.rebalance_count.load(Ordering::Relaxed)
120 }
121
122 #[must_use]
124 pub fn revoke_generation(&self) -> &Arc<AtomicU64> {
125 &self.revoke_generation
126 }
127}
128
129impl ClientContext for LaminarConsumerContext {}
130
131impl ConsumerContext for LaminarConsumerContext {
132 fn pre_rebalance(
133 &self,
134 _base_consumer: &rdkafka::consumer::BaseConsumer<Self>,
135 rebalance: &rdkafka::consumer::Rebalance<'_>,
136 ) {
137 use rdkafka::consumer::Rebalance;
138
139 match rebalance {
140 Rebalance::Revoke(tpl) => {
141 let count = tpl.count();
142 info!(
143 partitions_revoked = count,
144 "kafka rebalance: partitions being revoked, requesting checkpoint"
145 );
146 let partitions: Vec<(String, i32)> = tpl
148 .elements()
149 .iter()
150 .map(|e| (e.topic().to_string(), e.partition()))
151 .collect();
152 match self.rebalance_state.lock() {
153 Ok(mut state) => state.on_revoke(&partitions),
154 Err(poisoned) => {
155 warn!("rebalance_state mutex poisoned, recovering");
156 poisoned.into_inner().on_revoke(&partitions);
157 }
158 }
159 self.revoke_generation
160 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
161 self.rebalance_count
162 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
163 self.rebalance_metric
164 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
165 self.checkpoint_requested.store(true, Ordering::Release);
166 }
167 Rebalance::Assign(tpl) => {
168 let count = tpl.count();
169 info!(
170 partitions_assigned = count,
171 "kafka rebalance: new partitions assigned"
172 );
173 let partitions: Vec<(String, i32)> = tpl
175 .elements()
176 .iter()
177 .map(|e| (e.topic().to_string(), e.partition()))
178 .collect();
179 match self.rebalance_state.lock() {
180 Ok(mut state) => state.on_assign(&partitions),
181 Err(poisoned) => {
182 warn!("rebalance_state mutex poisoned, recovering");
183 poisoned.into_inner().on_assign(&partitions);
184 }
185 }
186 self.rebalance_count
187 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
188 self.rebalance_metric
189 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
190 }
191 Rebalance::Error(msg) => {
192 warn!(error = %msg, "kafka rebalance error");
193 }
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_assign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[
            ("events".into(), 0),
            ("events".into(), 1),
            ("events".into(), 2),
        ]);

        assert_eq!(state.assigned_partitions().len(), 3);
        assert!(state.is_assigned("events", 0));
        assert!(state.is_assigned("events", 1));
        assert!(state.is_assigned("events", 2));
        assert!(!state.is_assigned("events", 3));
        assert_eq!(state.rebalance_count(), 1);
    }

    #[test]
    fn test_revoke() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        state.on_revoke(&[("events".into(), 1)]);

        assert_eq!(state.assigned_partitions().len(), 1);
        assert!(state.is_assigned("events", 0));
        assert!(!state.is_assigned("events", 1));
    }

    #[test]
    fn test_revoke_unassigned_is_noop() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0)]);
        state.on_revoke(&[("events".into(), 7), ("other".into(), 0)]);

        assert_eq!(state.assigned_partitions().len(), 1);
        assert!(state.is_assigned("events", 0));
        // Revocation never changes the assignment counter.
        assert_eq!(state.rebalance_count(), 1);
    }

    #[test]
    fn test_reassign() {
        let mut state = RebalanceState::new();
        state.on_assign(&[("events".into(), 0), ("events".into(), 1)]);
        // A second assignment replaces the first one entirely.
        state.on_assign(&[("events".into(), 2), ("events".into(), 3)]);

        assert_eq!(state.assigned_partitions().len(), 2);
        assert!(!state.is_assigned("events", 0));
        assert!(state.is_assigned("events", 2));
        assert_eq!(state.rebalance_count(), 2);
    }

    #[test]
    fn test_empty_state() {
        let state = RebalanceState::new();
        assert_eq!(state.assigned_partitions().len(), 0);
        assert_eq!(state.rebalance_count(), 0);
        assert!(!state.is_assigned("events", 0));
    }

    /// Builds a context with fresh shared handles and returns the caller's
    /// copy of the checkpoint flag alongside it.
    fn make_context() -> (Arc<AtomicBool>, LaminarConsumerContext) {
        let flag = Arc::new(AtomicBool::new(false));
        let state = Arc::new(Mutex::new(RebalanceState::new()));
        let metric = Arc::new(AtomicU64::new(0));
        let revoke_gen = Arc::new(AtomicU64::new(0));
        let ctx = LaminarConsumerContext::new(Arc::clone(&flag), state, metric, revoke_gen);
        (flag, ctx)
    }

    #[test]
    fn test_consumer_context_initial_state() {
        let (flag, ctx) = make_context();
        assert!(!flag.load(Ordering::Relaxed));
        assert_eq!(ctx.rebalance_count(), 0);
    }

    #[test]
    fn test_consumer_context_shared_flag() {
        let (flag, ctx) = make_context();

        // The context must hold the very same flag the caller kept.
        assert!(Arc::ptr_eq(&flag, &ctx.checkpoint_requested));

        // A request raised through the context's handle is visible to the caller.
        ctx.checkpoint_requested.store(true, Ordering::Release);
        assert!(flag.load(Ordering::Acquire));

        // Clearing it on the caller's side is visible through the context.
        assert!(flag.swap(false, Ordering::AcqRel));
        assert!(!ctx.checkpoint_requested.load(Ordering::Acquire));
    }

    #[test]
    fn test_consumer_context_shares_revoke_generation() {
        let revoke_gen = Arc::new(AtomicU64::new(0));
        let ctx = LaminarConsumerContext::new(
            Arc::new(AtomicBool::new(false)),
            Arc::new(Mutex::new(RebalanceState::new())),
            Arc::new(AtomicU64::new(0)),
            Arc::clone(&revoke_gen),
        );

        assert!(Arc::ptr_eq(ctx.revoke_generation(), &revoke_gen));
        revoke_gen.fetch_add(1, Ordering::Relaxed);
        assert_eq!(ctx.revoke_generation().load(Ordering::Relaxed), 1);
    }
}