// laminar_connectors/kafka/discovery.rs
1//! Kafka-based discovery for delta nodes.
2//!
3//! Uses Kafka consumer group protocol for node discovery and membership.
4//! Each node joins a shared consumer group; the group coordinator handles
5//! membership, heartbeats, and rebalancing. This provides zero-infrastructure
6//! discovery when Kafka is already deployed.
7//!
8//! ## How It Works
9//!
10//! 1. Each node creates a Kafka consumer in a shared group (the "discovery group").
11//! 2. Node metadata is published to a dedicated Kafka topic as keyed messages.
12//! 3. A background task polls the topic for membership changes.
13//! 4. Consumer group rebalance callbacks detect join/leave events.
14//!
15//! ## Key Format
16//!
17//! Topic: `_laminardb_discovery` (configurable)
18//! Key: `node:{node_id}`
19//! Value: JSON-serialized `NodeInfo`
20
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use parking_lot::RwLock;
25use tokio::sync::watch;
26use tokio_util::sync::CancellationToken;
27
28use laminar_core::delta::discovery::{
29    Discovery, DiscoveryError, NodeId, NodeInfo, NodeMetadata, NodeState,
30};
31
32/// Configuration for Kafka-based discovery.
33#[derive(Debug, Clone)]
34pub struct KafkaDiscoveryConfig {
35    /// Kafka bootstrap servers.
36    pub bootstrap_servers: String,
37    /// Consumer group ID for discovery.
38    pub group_id: String,
39    /// Topic for discovery messages.
40    pub discovery_topic: String,
41    /// How often to publish this node's heartbeat (ms).
42    pub heartbeat_interval_ms: u64,
43    /// How many missed heartbeats before marking a node as suspected.
44    pub missed_heartbeat_threshold: u32,
45    /// How many missed heartbeats before marking a node as left.
46    pub dead_heartbeat_threshold: u32,
47    /// SASL/SSL configuration for Kafka.
48    pub security_protocol: String,
49    /// SASL mechanism.
50    pub sasl_mechanism: Option<String>,
51    /// SASL username.
52    pub sasl_username: Option<String>,
53    /// SASL password.
54    pub sasl_password: Option<String>,
55    /// This node's ID.
56    pub node_id: NodeId,
57    /// This node's RPC address.
58    pub rpc_address: String,
59    /// This node's Raft address.
60    pub raft_address: String,
61    /// This node's metadata.
62    pub node_metadata: NodeMetadata,
63}
64
65impl Default for KafkaDiscoveryConfig {
66    fn default() -> Self {
67        Self {
68            bootstrap_servers: "localhost:9092".to_string(),
69            group_id: "laminardb-discovery".to_string(),
70            discovery_topic: "_laminardb_discovery".to_string(),
71            heartbeat_interval_ms: 1000,
72            missed_heartbeat_threshold: 3,
73            dead_heartbeat_threshold: 10,
74            security_protocol: "plaintext".to_string(),
75            sasl_mechanism: None,
76            sasl_username: None,
77            sasl_password: None,
78            node_id: NodeId(1),
79            rpc_address: "127.0.0.1:9000".to_string(),
80            raft_address: "127.0.0.1:9001".to_string(),
81            node_metadata: NodeMetadata::default(),
82        }
83    }
84}
85
/// Kafka-based node discovery.
///
/// Uses the Kafka consumer group protocol for membership management.
/// When Kafka is already deployed, this provides discovery without
/// additional infrastructure (no separate gossip or etcd needed).
pub struct KafkaDiscovery {
    /// Configuration.
    config: KafkaDiscoveryConfig,
    /// Known peers (excluding self), keyed by raw node id (`NodeId.0`).
    peers: Arc<RwLock<HashMap<u64, NodeInfo>>>,
    /// Membership watch channel sender; publishes peer snapshots.
    membership_tx: watch::Sender<Vec<NodeInfo>>,
    /// Membership watch channel receiver; cloned out via `membership_watch`.
    membership_rx: watch::Receiver<Vec<NodeInfo>>,
    /// Cancellation token for background tasks (cancelled on `stop`).
    cancel: CancellationToken,
    /// Whether the discovery service has been started.
    started: bool,
}
105
106impl std::fmt::Debug for KafkaDiscovery {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        f.debug_struct("KafkaDiscovery")
109            .field("config", &self.config)
110            .field("started", &self.started)
111            .finish_non_exhaustive()
112    }
113}
114
115impl KafkaDiscovery {
116    /// Create a new Kafka discovery instance.
117    #[must_use]
118    pub fn new(config: KafkaDiscoveryConfig) -> Self {
119        let (membership_tx, membership_rx) = watch::channel(Vec::new());
120        Self {
121            config,
122            peers: Arc::new(RwLock::new(HashMap::new())),
123            membership_tx,
124            membership_rx,
125            cancel: CancellationToken::new(),
126            started: false,
127        }
128    }
129
130    /// Get the discovery topic name.
131    #[must_use]
132    pub fn discovery_topic(&self) -> &str {
133        &self.config.discovery_topic
134    }
135
136    /// Get the group ID.
137    #[must_use]
138    pub fn group_id(&self) -> &str {
139        &self.config.group_id
140    }
141
142    /// Build a `NodeInfo` for this node.
143    #[must_use]
144    pub fn local_node_info(&self) -> NodeInfo {
145        NodeInfo {
146            id: self.config.node_id,
147            name: format!("node-{}", self.config.node_id.0),
148            rpc_address: self.config.rpc_address.clone(),
149            raft_address: self.config.raft_address.clone(),
150            state: NodeState::Active,
151            metadata: self.config.node_metadata.clone(),
152            last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
153        }
154    }
155
156    /// Process a received heartbeat message from another node.
157    ///
158    /// Returns `true` if this is a new node (join event).
159    #[must_use]
160    pub fn process_heartbeat(&self, info: NodeInfo) -> bool {
161        if info.id == self.config.node_id {
162            return false; // Ignore self
163        }
164
165        let is_new = {
166            let mut peers = self.peers.write();
167            let existing = peers.insert(info.id.0, info);
168            existing.is_none()
169        };
170
171        // Update membership watch
172        let peers_vec: Vec<NodeInfo> = self.peers.read().values().cloned().collect();
173        let _ = self.membership_tx.send(peers_vec);
174
175        is_new
176    }
177
178    /// Check for nodes that have missed heartbeats and update their state.
179    pub fn check_liveness(&self) {
180        let now = chrono::Utc::now().timestamp_millis();
181        #[allow(clippy::cast_possible_wrap)]
182        let heartbeat_ms = self.config.heartbeat_interval_ms as i64;
183        let suspect_threshold = heartbeat_ms * i64::from(self.config.missed_heartbeat_threshold);
184        let dead_threshold = heartbeat_ms * i64::from(self.config.dead_heartbeat_threshold);
185
186        let mut changed = false;
187        {
188            let mut peers = self.peers.write();
189            for info in peers.values_mut() {
190                let elapsed = now.saturating_sub(info.last_heartbeat_ms);
191                let new_state = if elapsed >= dead_threshold {
192                    NodeState::Left
193                } else if elapsed >= suspect_threshold {
194                    NodeState::Suspected
195                } else {
196                    NodeState::Active
197                };
198                if info.state != new_state {
199                    info.state = new_state;
200                    changed = true;
201                }
202            }
203
204            // Remove nodes that have been Left for a while
205            peers.retain(|_, info| info.state != NodeState::Left);
206        }
207
208        if changed {
209            let peers_vec: Vec<NodeInfo> = self.peers.read().values().cloned().collect();
210            let _ = self.membership_tx.send(peers_vec);
211        }
212    }
213
214    /// Serialize a `NodeInfo` to a JSON message key and value.
215    #[must_use]
216    pub fn serialize_heartbeat(info: &NodeInfo) -> (String, String) {
217        let key = format!("node:{}", info.id.0);
218        let value = serde_json::to_string(info).unwrap_or_default();
219        (key, value)
220    }
221
222    /// Deserialize a heartbeat message.
223    #[must_use]
224    pub fn deserialize_heartbeat(value: &str) -> Option<NodeInfo> {
225        serde_json::from_str(value).ok()
226    }
227}
228
229impl Discovery for KafkaDiscovery {
230    async fn start(&mut self) -> Result<(), DiscoveryError> {
231        if self.started {
232            return Ok(());
233        }
234
235        // In a full implementation, we would:
236        // 1. Create a Kafka producer for heartbeats
237        // 2. Create a Kafka consumer in the discovery group
238        // 3. Subscribe to the discovery topic
239        // 4. Spawn background heartbeat publisher
240        // 5. Spawn background consumer poller
241        //
242        // For now, mark as started. The actual Kafka integration will
243        // be wired when rdkafka async producer/consumer is connected.
244
245        self.started = true;
246        Ok(())
247    }
248
249    async fn peers(&self) -> Result<Vec<NodeInfo>, DiscoveryError> {
250        if !self.started {
251            return Err(DiscoveryError::NotStarted);
252        }
253        Ok(self.peers.read().values().cloned().collect())
254    }
255
256    async fn announce(&self, info: NodeInfo) -> Result<(), DiscoveryError> {
257        if !self.started {
258            return Err(DiscoveryError::NotStarted);
259        }
260        let _ = self.process_heartbeat(info);
261        Ok(())
262    }
263
264    fn membership_watch(&self) -> watch::Receiver<Vec<NodeInfo>> {
265        self.membership_rx.clone()
266    }
267
268    async fn stop(&mut self) -> Result<(), DiscoveryError> {
269        if !self.started {
270            return Ok(());
271        }
272        self.cancel.cancel();
273        self.started = false;
274        Ok(())
275    }
276}
277
/// Custom partition assignor for `LaminarDB` delta.
///
/// Assigns Kafka partitions weighted by node core count, ensuring
/// that nodes with more cores get proportionally more partitions.
#[derive(Debug, Clone)]
pub struct LaminarPartitionAssignor {
    /// Weights per node (`node_id` to weight).
    weights: HashMap<u64, u32>,
}

impl LaminarPartitionAssignor {
    /// Create a new assignor with core-weighted nodes.
    #[must_use]
    pub fn new(node_cores: &HashMap<u64, u32>) -> Self {
        Self {
            weights: node_cores.clone(),
        }
    }

    /// Assign partitions to nodes proportionally by weight.
    ///
    /// Nodes receive contiguous partition ranges in ascending node-id order
    /// (deterministic). Rounding is applied to the *cumulative* quota rather
    /// than to each node's share, which guarantees the shares sum to exactly
    /// `num_partitions`. (Per-node rounding could over-assign and make the
    /// final node's `num_partitions - assigned` underflow `u32` — e.g. five
    /// equal-weight nodes with three partitions each round 0.6 up to 1.)
    ///
    /// Returns an empty map when there are no nodes or total weight is zero.
    #[must_use]
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn assign(&self, num_partitions: u32) -> HashMap<u64, Vec<u32>> {
        if self.weights.is_empty() {
            return HashMap::new();
        }

        let total_weight: u32 = self.weights.values().sum();
        if total_weight == 0 {
            return HashMap::new();
        }

        // Sort nodes for deterministic assignment.
        let mut nodes: Vec<(u64, u32)> = self.weights.iter().map(|(&k, &v)| (k, v)).collect();
        nodes.sort_by_key(|(id, _)| *id);

        let mut assignments: HashMap<u64, Vec<u32>> = HashMap::new();
        let mut assigned = 0u32;
        // Ideal (fractional) number of partitions owed to all nodes seen so far.
        let mut cumulative_quota = 0.0_f64;

        for (i, (node_id, weight)) in nodes.iter().enumerate() {
            let share = if i == nodes.len() - 1 {
                // Last node absorbs whatever rounding left over; the rounded
                // cumulative quota never exceeds num_partitions, so this
                // subtraction cannot underflow.
                num_partitions - assigned
            } else {
                cumulative_quota +=
                    f64::from(num_partitions) * f64::from(*weight) / f64::from(total_weight);
                // Round the running total, not the individual share; `min`
                // guards against any floating-point drift past the total.
                (cumulative_quota.round() as u32).min(num_partitions) - assigned
            };

            let partitions: Vec<u32> = (assigned..assigned + share).collect();
            assignments.insert(*node_id, partitions);
            assigned += share;
        }

        assignments
    }

    /// Get the weight for a node (0 if the node is unknown).
    #[must_use]
    pub fn weight_for(&self, node_id: u64) -> u32 {
        self.weights.get(&node_id).copied().unwrap_or(0)
    }
}
341
#[cfg(test)]
mod tests {
    use super::*;

    // ── Configuration defaults ──

    #[test]
    fn test_config_default() {
        let config = KafkaDiscoveryConfig::default();
        assert_eq!(config.bootstrap_servers, "localhost:9092");
        assert_eq!(config.group_id, "laminardb-discovery");
        assert_eq!(config.discovery_topic, "_laminardb_discovery");
        assert_eq!(config.heartbeat_interval_ms, 1000);
        assert_eq!(config.missed_heartbeat_threshold, 3);
    }

    // A fresh instance is not started and exposes its config accessors.
    #[test]
    fn test_kafka_discovery_new() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        assert!(!discovery.started);
        assert_eq!(discovery.discovery_topic(), "_laminardb_discovery");
        assert_eq!(discovery.group_id(), "laminardb-discovery");
    }

    // local_node_info reflects the configured identity and addresses.
    #[test]
    fn test_local_node_info() {
        let config = KafkaDiscoveryConfig {
            node_id: NodeId(42),
            rpc_address: "10.0.0.1:9000".to_string(),
            raft_address: "10.0.0.1:9001".to_string(),
            ..KafkaDiscoveryConfig::default()
        };
        let discovery = KafkaDiscovery::new(config);
        let info = discovery.local_node_info();

        assert_eq!(info.id, NodeId(42));
        assert_eq!(info.rpc_address, "10.0.0.1:9000");
        assert_eq!(info.raft_address, "10.0.0.1:9001");
        assert_eq!(info.state, NodeState::Active);
    }

    // ── Heartbeat processing ──

    // First heartbeat from an unknown node is a join event.
    #[test]
    fn test_process_heartbeat_new_node() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        let info = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
        };

        assert!(discovery.process_heartbeat(info));
        assert_eq!(discovery.peers.read().len(), 1);
    }

    // A heartbeat carrying our own node id must not register as a peer.
    #[test]
    fn test_process_heartbeat_ignores_self() {
        let config = KafkaDiscoveryConfig {
            node_id: NodeId(1),
            ..KafkaDiscoveryConfig::default()
        };
        let discovery = KafkaDiscovery::new(config);
        let info = NodeInfo {
            id: NodeId(1), // same as local
            name: "node-1".to_string(),
            rpc_address: "10.0.0.1:9000".to_string(),
            raft_address: "10.0.0.1:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
        };

        assert!(!discovery.process_heartbeat(info));
        assert_eq!(discovery.peers.read().len(), 0);
    }

    // A second heartbeat for a known node refreshes it rather than re-joining.
    #[test]
    fn test_process_heartbeat_update() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        let info1 = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 1000,
        };
        let info2 = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 2000,
        };

        assert!(discovery.process_heartbeat(info1));
        assert!(!discovery.process_heartbeat(info2)); // Update, not new
        assert_eq!(discovery.peers.read().len(), 1);

        let peer = discovery.peers.read().get(&2).cloned().unwrap();
        assert_eq!(peer.last_heartbeat_ms, 2000);
    }

    // ── Wire format round-trip ──

    #[test]
    fn test_serialize_deserialize_heartbeat() {
        let info = NodeInfo {
            id: NodeId(5),
            name: "node-5".to_string(),
            rpc_address: "10.0.0.5:9000".to_string(),
            raft_address: "10.0.0.5:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: 12345,
        };

        let (key, value) = KafkaDiscovery::serialize_heartbeat(&info);
        assert_eq!(key, "node:5");
        assert!(!value.is_empty());

        let deserialized = KafkaDiscovery::deserialize_heartbeat(&value).unwrap();
        assert_eq!(deserialized.id, NodeId(5));
        assert_eq!(deserialized.rpc_address, "10.0.0.5:9000");
        assert_eq!(deserialized.last_heartbeat_ms, 12345);
    }

    // Malformed or incomplete JSON must yield None, never panic.
    #[test]
    fn test_deserialize_invalid() {
        assert!(KafkaDiscovery::deserialize_heartbeat("not json").is_none());
        assert!(KafkaDiscovery::deserialize_heartbeat("{}").is_none());
    }

    // ── Liveness checking ──

    // A peer past the suspect threshold but not the dead threshold is Suspected.
    #[test]
    fn test_check_liveness_suspects_stale() {
        let config = KafkaDiscoveryConfig {
            heartbeat_interval_ms: 100,
            missed_heartbeat_threshold: 3,
            dead_heartbeat_threshold: 10,
            ..KafkaDiscoveryConfig::default()
        };
        let discovery = KafkaDiscovery::new(config);

        // Add a peer with old heartbeat (well past suspect threshold)
        let now = chrono::Utc::now().timestamp_millis();
        let info = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: now - 500, // 500ms old with 100ms interval, 3 threshold = 300ms
        };
        let _ = discovery.process_heartbeat(info);

        discovery.check_liveness();

        let peers = discovery.peers.read();
        let peer = peers.get(&2).unwrap();
        assert_eq!(peer.state, NodeState::Suspected);
    }

    // A peer past the dead threshold is dropped from the peer map entirely.
    #[test]
    fn test_check_liveness_removes_dead() {
        let config = KafkaDiscoveryConfig {
            heartbeat_interval_ms: 100,
            missed_heartbeat_threshold: 3,
            dead_heartbeat_threshold: 5,
            ..KafkaDiscoveryConfig::default()
        };
        let discovery = KafkaDiscovery::new(config);

        let now = chrono::Utc::now().timestamp_millis();
        let info = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: now - 1000, // 1000ms old with 100ms*5 = 500ms dead threshold
        };
        let _ = discovery.process_heartbeat(info);

        discovery.check_liveness();

        // Dead nodes should be removed
        assert_eq!(discovery.peers.read().len(), 0);
    }

    // ── Discovery trait lifecycle ──

    // Calling peers() before start() is an error.
    #[tokio::test]
    async fn test_discovery_trait_not_started() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        let result = discovery.peers().await;
        assert!(matches!(result, Err(DiscoveryError::NotStarted)));
    }

    // start/stop toggle availability of the trait API.
    #[tokio::test]
    async fn test_discovery_trait_start_stop() {
        let mut discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        discovery.start().await.unwrap();

        let peers = discovery.peers().await.unwrap();
        assert!(peers.is_empty());

        discovery.stop().await.unwrap();

        let result = discovery.peers().await;
        assert!(matches!(result, Err(DiscoveryError::NotStarted)));
    }

    // announce() routes through heartbeat processing and registers the peer.
    #[tokio::test]
    async fn test_discovery_announce() {
        let mut discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        discovery.start().await.unwrap();

        let info = NodeInfo {
            id: NodeId(3),
            name: "node-3".to_string(),
            rpc_address: "10.0.0.3:9000".to_string(),
            raft_address: "10.0.0.3:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
        };
        discovery.announce(info).await.unwrap();

        let peers = discovery.peers().await.unwrap();
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].id, NodeId(3));
    }

    // The watch channel reflects membership changes without polling.
    #[test]
    fn test_membership_watch() {
        let discovery = KafkaDiscovery::new(KafkaDiscoveryConfig::default());
        let rx = discovery.membership_watch();
        assert!(rx.borrow().is_empty());

        let info = NodeInfo {
            id: NodeId(2),
            name: "node-2".to_string(),
            rpc_address: "10.0.0.2:9000".to_string(),
            raft_address: "10.0.0.2:9001".to_string(),
            state: NodeState::Active,
            metadata: NodeMetadata::default(),
            last_heartbeat_ms: chrono::Utc::now().timestamp_millis(),
        };
        let _ = discovery.process_heartbeat(info);

        assert_eq!(rx.borrow().len(), 1);
    }

    // ── LaminarPartitionAssignor tests ──

    #[test]
    fn test_assignor_equal_weights() {
        let mut cores = HashMap::new();
        cores.insert(1, 4);
        cores.insert(2, 4);
        let assignor = LaminarPartitionAssignor::new(&cores);

        let result = assignor.assign(8);
        assert_eq!(result.len(), 2);
        assert_eq!(result[&1].len(), 4);
        assert_eq!(result[&2].len(), 4);
    }

    #[test]
    fn test_assignor_weighted() {
        let mut cores = HashMap::new();
        cores.insert(1, 8); // 2/3 weight
        cores.insert(2, 4); // 1/3 weight
        let assignor = LaminarPartitionAssignor::new(&cores);

        let result = assignor.assign(12);
        // Node 1 should get ~8 partitions, node 2 ~4
        let total: usize = result.values().map(Vec::len).sum();
        assert_eq!(total, 12);
        assert!(result[&1].len() >= 7); // ~8
        assert!(result[&2].len() >= 3); // ~4
    }

    #[test]
    fn test_assignor_single_node() {
        let mut cores = HashMap::new();
        cores.insert(1, 4);
        let assignor = LaminarPartitionAssignor::new(&cores);

        let result = assignor.assign(16);
        assert_eq!(result[&1].len(), 16);
    }

    #[test]
    fn test_assignor_empty() {
        let assignor = LaminarPartitionAssignor::new(&HashMap::new());
        let result = assignor.assign(16);
        assert!(result.is_empty());
    }

    // Same inputs must always produce the same assignment (sorted node order).
    #[test]
    fn test_assignor_deterministic() {
        let mut cores = HashMap::new();
        cores.insert(1, 4);
        cores.insert(2, 8);
        cores.insert(3, 4);
        let assignor = LaminarPartitionAssignor::new(&cores);

        let r1 = assignor.assign(32);
        let r2 = assignor.assign(32);
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_assignor_weight_for() {
        let mut cores = HashMap::new();
        cores.insert(1, 4);
        cores.insert(2, 8);
        let assignor = LaminarPartitionAssignor::new(&cores);

        assert_eq!(assignor.weight_for(1), 4);
        assert_eq!(assignor.weight_for(2), 8);
        assert_eq!(assignor.weight_for(99), 0);
    }
}