1use std::collections::HashMap;
22use std::sync::Arc;
23
24use parking_lot::RwLock;
25use tokio::sync::watch;
26use tokio_util::sync::CancellationToken;
27
28use laminar_core::delta::discovery::{
29 Discovery, DiscoveryError, NodeId, NodeInfo, NodeMetadata, NodeState,
30};
31
/// Configuration for Kafka-based cluster node discovery.
///
/// Combines Kafka connection settings, heartbeat/liveness tuning, and the
/// identity the local node announces to its peers.
#[derive(Debug, Clone)]
pub struct KafkaDiscoveryConfig {
    /// Kafka bootstrap servers, e.g. `"localhost:9092"`.
    pub bootstrap_servers: String,
    /// Kafka consumer group id used for discovery.
    pub group_id: String,
    /// Topic on which node heartbeats are exchanged.
    pub discovery_topic: String,
    /// Interval between heartbeats, in milliseconds. Also the base unit for
    /// the suspect/dead thresholds below (see `check_liveness`).
    pub heartbeat_interval_ms: u64,
    /// Missed-heartbeat multiplier after which a peer is marked Suspected.
    pub missed_heartbeat_threshold: u32,
    /// Missed-heartbeat multiplier after which a peer is considered dead
    /// (state Left) and removed from the peer map.
    pub dead_heartbeat_threshold: u32,
    /// Kafka security protocol (e.g. `"plaintext"`); presumably forwarded to
    /// the Kafka client — not consumed in this module. TODO confirm at caller.
    pub security_protocol: String,
    /// Optional SASL mechanism — assumed to be Kafka client config; verify.
    pub sasl_mechanism: Option<String>,
    /// Optional SASL username — assumed to be Kafka client config; verify.
    pub sasl_username: Option<String>,
    /// Optional SASL password — assumed to be Kafka client config; verify.
    pub sasl_password: Option<String>,
    /// Identity of the local node; heartbeats bearing this id are ignored.
    pub node_id: NodeId,
    /// RPC address the local node advertises.
    pub rpc_address: String,
    /// Raft address the local node advertises.
    pub raft_address: String,
    /// Additional metadata the local node advertises.
    pub node_metadata: NodeMetadata,
}
64
65impl Default for KafkaDiscoveryConfig {
66 fn default() -> Self {
67 Self {
68 bootstrap_servers: "localhost:9092".to_string(),
69 group_id: "laminardb-discovery".to_string(),
70 discovery_topic: "_laminardb_discovery".to_string(),
71 heartbeat_interval_ms: 1000,
72 missed_heartbeat_threshold: 3,
73 dead_heartbeat_threshold: 10,
74 security_protocol: "plaintext".to_string(),
75 sasl_mechanism: None,
76 sasl_username: None,
77 sasl_password: None,
78 node_id: NodeId(1),
79 rpc_address: "127.0.0.1:9000".to_string(),
80 raft_address: "127.0.0.1:9001".to_string(),
81 node_metadata: NodeMetadata::default(),
82 }
83 }
84}
85
/// Kafka-backed implementation of the [`Discovery`] trait.
///
/// Tracks peers from heartbeats and broadcasts membership snapshots to
/// subscribers through a `tokio::sync::watch` channel.
pub struct KafkaDiscovery {
    /// Static configuration, including the local node's identity.
    config: KafkaDiscoveryConfig,
    /// Known peers keyed by node id. The local node is never stored here
    /// (`process_heartbeat` ignores heartbeats bearing `config.node_id`).
    peers: Arc<RwLock<HashMap<u64, NodeInfo>>>,
    /// Sender half used to publish membership snapshots.
    membership_tx: watch::Sender<Vec<NodeInfo>>,
    /// Receiver kept around so `membership_watch` can hand out clones.
    membership_rx: watch::Receiver<Vec<NodeInfo>>,
    /// Cancellation token triggered by `stop`.
    cancel: CancellationToken,
    /// Whether `start` has been called (and `stop` has not).
    started: bool,
}
105
106impl std::fmt::Debug for KafkaDiscovery {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct("KafkaDiscovery")
109 .field("config", &self.config)
110 .field("started", &self.started)
111 .finish_non_exhaustive()
112 }
113}
114
115impl KafkaDiscovery {
116 #[must_use]
118 pub fn new(config: KafkaDiscoveryConfig) -> Self {
119 let (membership_tx, membership_rx) = watch::channel(Vec::new());
120 Self {
121 config,
122 peers: Arc::new(RwLock::new(HashMap::new())),
123 membership_tx,
124 membership_rx,
125 cancel: CancellationToken::new(),
126 started: false,
127 }
128 }
129
130 #[must_use]
132 pub fn discovery_topic(&self) -> &str {
133 &self.config.discovery_topic
134 }
135
136 #[must_use]
138 pub fn group_id(&self) -> &str {
139 &self.config.group_id
140 }
141
142 #[must_use]
144 pub fn local_node_info(&self) -> NodeInfo {
145 NodeInfo {
146 id: self.config.node_id,
147 name: format!("node-{}", self.config.node_id.0),
148 rpc_address: self.config.rpc_address.clone(),
149 raft_address: self.config.raft_address.clone(),
150 state: NodeState::Active,
151 metadata: self.config.node_metadata.clone(),
152 last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
153 }
154 }
155
156 #[must_use]
160 pub fn process_heartbeat(&self, info: NodeInfo) -> bool {
161 if info.id == self.config.node_id {
162 return false; }
164
165 let is_new = {
166 let mut peers = self.peers.write();
167 let existing = peers.insert(info.id.0, info);
168 existing.is_none()
169 };
170
171 let peers_vec: Vec<NodeInfo> = self.peers.read().values().cloned().collect();
173 let _ = self.membership_tx.send(peers_vec);
174
175 is_new
176 }
177
178 pub fn check_liveness(&self) {
180 let now = chrono::Utc::now().timestamp_millis();
181 #[allow(clippy::cast_possible_wrap)]
182 let heartbeat_ms = self.config.heartbeat_interval_ms as i64;
183 let suspect_threshold = heartbeat_ms * i64::from(self.config.missed_heartbeat_threshold);
184 let dead_threshold = heartbeat_ms * i64::from(self.config.dead_heartbeat_threshold);
185
186 let mut changed = false;
187 {
188 let mut peers = self.peers.write();
189 for info in peers.values_mut() {
190 let elapsed = now.saturating_sub(info.last_heartbeat_ms);
191 let new_state = if elapsed >= dead_threshold {
192 NodeState::Left
193 } else if elapsed >= suspect_threshold {
194 NodeState::Suspected
195 } else {
196 NodeState::Active
197 };
198 if info.state != new_state {
199 info.state = new_state;
200 changed = true;
201 }
202 }
203
204 peers.retain(|_, info| info.state != NodeState::Left);
206 }
207
208 if changed {
209 let peers_vec: Vec<NodeInfo> = self.peers.read().values().cloned().collect();
210 let _ = self.membership_tx.send(peers_vec);
211 }
212 }
213
214 #[must_use]
216 pub fn serialize_heartbeat(info: &NodeInfo) -> (String, String) {
217 let key = format!("node:{}", info.id.0);
218 let value = serde_json::to_string(info).unwrap_or_default();
219 (key, value)
220 }
221
222 #[must_use]
224 pub fn deserialize_heartbeat(value: &str) -> Option<NodeInfo> {
225 serde_json::from_str(value).ok()
226 }
227}
228
229impl Discovery for KafkaDiscovery {
230 async fn start(&mut self) -> Result<(), DiscoveryError> {
231 if self.started {
232 return Ok(());
233 }
234
235 self.started = true;
246 Ok(())
247 }
248
249 async fn peers(&self) -> Result<Vec<NodeInfo>, DiscoveryError> {
250 if !self.started {
251 return Err(DiscoveryError::NotStarted);
252 }
253 Ok(self.peers.read().values().cloned().collect())
254 }
255
256 async fn announce(&self, info: NodeInfo) -> Result<(), DiscoveryError> {
257 if !self.started {
258 return Err(DiscoveryError::NotStarted);
259 }
260 let _ = self.process_heartbeat(info);
261 Ok(())
262 }
263
264 fn membership_watch(&self) -> watch::Receiver<Vec<NodeInfo>> {
265 self.membership_rx.clone()
266 }
267
268 async fn stop(&mut self) -> Result<(), DiscoveryError> {
269 if !self.started {
270 return Ok(());
271 }
272 self.cancel.cancel();
273 self.started = false;
274 Ok(())
275 }
276}
277
/// Weighted partition assignor: distributes contiguous partition ids across
/// nodes proportionally to each node's weight (e.g. core count).
#[derive(Debug, Clone)]
pub struct LaminarPartitionAssignor {
    /// Node id -> weight used for proportional assignment.
    weights: HashMap<u64, u32>,
}

impl LaminarPartitionAssignor {
    /// Creates an assignor from a map of node id to weight (core count).
    #[must_use]
    pub fn new(node_cores: &HashMap<u64, u32>) -> Self {
        Self {
            weights: node_cores.clone(),
        }
    }

    /// Assigns partitions `0..num_partitions` to nodes in proportion to
    /// their weights. Nodes are processed in ascending id order, so the
    /// result is deterministic; the last node absorbs rounding remainders.
    /// Returns an empty map when there are no nodes or all weights are zero.
    #[must_use]
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn assign(&self, num_partitions: u32) -> HashMap<u64, Vec<u32>> {
        if self.weights.is_empty() {
            return HashMap::new();
        }

        let total_weight: u32 = self.weights.values().sum();
        if total_weight == 0 {
            return HashMap::new();
        }

        let mut assignments: HashMap<u64, Vec<u32>> = HashMap::new();
        let mut assigned = 0u32;

        // Sort by node id for a deterministic assignment order.
        let mut nodes: Vec<(u64, u32)> = self.weights.iter().map(|(&k, &v)| (k, v)).collect();
        nodes.sort_by_key(|(id, _)| *id);

        for (i, (node_id, weight)) in nodes.iter().enumerate() {
            // Invariant: assigned <= num_partitions, so this cannot underflow.
            let remaining = num_partitions - assigned;
            let share = if i == nodes.len() - 1 {
                // Last node takes whatever is left (possibly zero).
                remaining
            } else {
                let share_f =
                    f64::from(num_partitions) * f64::from(*weight) / f64::from(total_weight);
                // Cap at the remaining count: several shares rounding up in a
                // row could otherwise exceed `num_partitions`, making the
                // last node's `num_partitions - assigned` underflow u32
                // (a panic in debug builds). E.g. weights {3,3,3,1} with 5
                // partitions rounds to 2+2+2 = 6 before the last node.
                (share_f.round() as u32).min(remaining)
            };

            let partitions: Vec<u32> = (assigned..assigned + share).collect();
            assignments.insert(*node_id, partitions);
            assigned += share;
        }

        assignments
    }

    /// Weight recorded for `node_id`, or 0 if the node is unknown.
    #[must_use]
    pub fn weight_for(&self, node_id: u64) -> u32 {
        self.weights.get(&node_id).copied().unwrap_or(0)
    }
}
341
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must match the values documented on KafkaDiscoveryConfig.
    #[test]
    fn test_config_default() {
        let config = KafkaDiscoveryConfig::default();
        assert_eq!(config.bootstrap_servers, "localhost:9092");
        assert_eq!(config.group_id, "laminardb-discovery");
        assert_eq!(config.discovery_topic, "_laminardb_discovery");
        assert_eq!(config.heartbeat_interval_ms, 1000);
        assert_eq!(config.missed_heartbeat_threshold, 3);
    }

    // A fresh instance is not started and exposes its configured names.
    #[test]
    fn test_kafka_discovery_new() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        assert!(!discovery.started);
        assert_eq!(discovery.discovery_topic(), "_laminardb_discovery");
        assert_eq!(discovery.group_id(), "laminardb-discovery");
    }

    // local_node_info reflects the configured identity/addresses and is Active.
    #[test]
    fn test_local_node_info() {
        let config = KafkaDiscoveryConfig {
            node_id: NodeId(42),
            rpc_address: "10.0.0.1:9000".to_string(),
            raft_address: "10.0.0.1:9001".to_string(),
            ..KafkaDiscoveryConfig::default()
        };
        let discovery = KafkaDiscovery::new(config);
        let info = discovery.local_node_info();

        assert_eq!(info.id, NodeId(42));
        assert_eq!(info.rpc_address, "10.0.0.1:9000");
        assert_eq!(info.raft_address, "10.0.0.1:9001");
        assert_eq!(info.state, NodeState::Active);
    }

    // First heartbeat from an unknown node returns true and stores the peer.
    #[test]
    fn test_process_heartbeat_new_node() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        let info = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
        };

        assert!(discovery.process_heartbeat(info));
        assert_eq!(discovery.peers.read().len(), 1);
    }

    // A heartbeat carrying the local node id must not be recorded as a peer.
    #[test]
    fn test_process_heartbeat_ignores_self() {
        let config = KafkaDiscoveryConfig {
            node_id: NodeId(1),
            ..KafkaDiscoveryConfig::default()
        };
        let discovery = KafkaDiscovery::new(config);
        let info = NodeInfo {
            id: NodeId(1), name: "node-1".to_string(),
            rpc_address: "10.0.0.1:9000".to_string(),
            raft_address: "10.0.0.1:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
        };

        assert!(!discovery.process_heartbeat(info));
        assert_eq!(discovery.peers.read().len(), 0);
    }

    // A repeat heartbeat returns false (known node) but refreshes the entry.
    #[test]
    fn test_process_heartbeat_update() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        let info1 = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 1000,
        };
        let info2 = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 2000,
        };

        assert!(discovery.process_heartbeat(info1));
        assert!(!discovery.process_heartbeat(info2)); assert_eq!(discovery.peers.read().len(), 1);

        // The stored entry should carry the newer heartbeat timestamp.
        let peer = discovery.peers.read().get(&2).cloned().unwrap();
        assert_eq!(peer.last_heartbeat_ms, 2000);
    }

    // Round-trip: serialize to (key, JSON) and parse back losslessly.
    #[test]
    fn test_serialize_deserialize_heartbeat() {
        let info = NodeInfo {
            id: NodeId(5),
            name: "node-5".to_string(),
            rpc_address: "10.0.0.5:9000".to_string(),
            raft_address: "10.0.0.5:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 12345,
        };

        let (key, value) = KafkaDiscovery::serialize_heartbeat(&info);
        assert_eq!(key, "node:5");
        assert!(!value.is_empty());

        let deserialized = KafkaDiscovery::deserialize_heartbeat(&value).unwrap();
        assert_eq!(deserialized.id, NodeId(5));
        assert_eq!(deserialized.rpc_address, "10.0.0.5:9000");
        assert_eq!(deserialized.last_heartbeat_ms, 12345);
    }

    // Malformed or incomplete JSON must yield None, not a panic.
    #[test]
    fn test_deserialize_invalid() {
        assert!(KafkaDiscovery::deserialize_heartbeat("not json").is_none());
        assert!(KafkaDiscovery::deserialize_heartbeat("{}").is_none());
    }

    // 500ms stale with a 100ms interval and threshold 3 (=300ms) => Suspected,
    // but below the dead threshold (=1000ms) so the peer is retained.
    #[test]
    fn test_check_liveness_suspects_stale() {
        let config = KafkaDiscoveryConfig {
            heartbeat_interval_ms: 100,
            missed_heartbeat_threshold: 3,
            dead_heartbeat_threshold: 10,
            ..KafkaDiscoveryConfig::default()
        };
        let discovery = KafkaDiscovery::new(config);

        let now = chrono::Utc::now().timestamp_millis();
        let info = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: now - 500, };
        let _ = discovery.process_heartbeat(info);

        discovery.check_liveness();

        let peers = discovery.peers.read();
        let peer = peers.get(&2).unwrap();
        assert_eq!(peer.state, NodeState::Suspected);
    }

    // 1000ms stale with dead threshold 5 * 100ms (=500ms) => removed entirely.
    #[test]
    fn test_check_liveness_removes_dead() {
        let config = KafkaDiscoveryConfig {
            heartbeat_interval_ms: 100,
            missed_heartbeat_threshold: 3,
            dead_heartbeat_threshold: 5,
            ..KafkaDiscoveryConfig::default()
        };
        let discovery = KafkaDiscovery::new(config);

        let now = chrono::Utc::now().timestamp_millis();
        let info = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: now - 1000, };
        let _ = discovery.process_heartbeat(info);

        discovery.check_liveness();

        assert_eq!(discovery.peers.read().len(), 0);
    }

    // Trait methods must reject calls before start().
    #[tokio::test]
    async fn test_discovery_trait_not_started() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        let result = discovery.peers().await;
        assert!(matches!(result, Err(DiscoveryError::NotStarted)));
    }

    // start -> peers works; after stop, peers errors with NotStarted again.
    #[tokio::test]
    async fn test_discovery_trait_start_stop() {
        let mut discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        discovery.start().await.unwrap();

        let peers = discovery.peers().await.unwrap();
        assert!(peers.is_empty());

        discovery.stop().await.unwrap();

        let result = discovery.peers().await;
        assert!(matches!(result, Err(DiscoveryError::NotStarted)));
    }

    // announce() routes through process_heartbeat and stores the peer.
    #[tokio::test]
    async fn test_discovery_announce() {
        let mut discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        discovery.start().await.unwrap();

        let info = NodeInfo {
            id: NodeId(3),
            name: "node-3".to_string(),
            rpc_address: "10.0.0.3:9000".to_string(),
            raft_address: "10.0.0.3:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
        };
        discovery.announce(info).await.unwrap();

        let peers = discovery.peers().await.unwrap();
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].id, NodeId(3));
    }

    // The watch channel starts empty and reflects heartbeat-driven updates.
    #[test]
    fn test_membership_watch() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        let rx = discovery.membership_watch();
        assert!(rx.borrow().is_empty());

        let info = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
        };
        let _ = discovery.process_heartbeat(info);

        assert_eq!(rx.borrow().len(), 1);
    }

    // Equal weights split partitions evenly.
    #[test]
    fn test_assignor_equal_weights() {
        let mut cores = HashMap::new();
        cores.insert(1, 4);
        cores.insert(2, 4);
        let assignor = LaminarPartitionAssignor::new(&cores);

        let result = assignor.assign(8);
        assert_eq!(result.len(), 2);
        assert_eq!(result[&1].len(), 4);
        assert_eq!(result[&2].len(), 4);
    }

    // 2:1 weights over 12 partitions => roughly 8 vs 4; all 12 assigned.
    #[test]
    fn test_assignor_weighted() {
        let mut cores = HashMap::new();
        cores.insert(1, 8); cores.insert(2, 4); let assignor = LaminarPartitionAssignor::new(&cores);

        let result = assignor.assign(12);
        let total: usize = result.values().map(Vec::len).sum();
        assert_eq!(total, 12);
        assert!(result[&1].len() >= 7); assert!(result[&2].len() >= 3); }

    // A single node receives every partition.
    #[test]
    fn test_assignor_single_node() {
        let mut cores = HashMap::new();
        cores.insert(1, 4);
        let assignor = LaminarPartitionAssignor::new(&cores);

        let result = assignor.assign(16);
        assert_eq!(result[&1].len(), 16);
    }

    // No nodes => empty assignment map.
    #[test]
    fn test_assignor_empty() {
        let assignor = LaminarPartitionAssignor::new(&HashMap::new());
        let result = assignor.assign(16);
        assert!(result.is_empty());
    }

    // Assignment must be deterministic across repeated calls (id-sorted).
    #[test]
    fn test_assignor_deterministic() {
        let mut cores = HashMap::new();
        cores.insert(1, 4);
        cores.insert(2, 8);
        cores.insert(3, 4);
        let assignor = LaminarPartitionAssignor::new(&cores);

        let r1 = assignor.assign(32);
        let r2 = assignor.assign(32);
        assert_eq!(r1, r2);
    }

    // weight_for returns the stored weight, or 0 for unknown nodes.
    #[test]
    fn test_assignor_weight_for() {
        let mut cores = HashMap::new();
        cores.insert(1, 4);
        cores.insert(2, 8);
        let assignor = LaminarPartitionAssignor::new(&cores);

        assert_eq!(assignor.weight_for(1), 4);
        assert_eq!(assignor.weight_for(2), 8);
        assert_eq!(assignor.weight_for(99), 0);
    }
}