//! SWIM-based membership protocol implementation. //! //! This module implements a SWIM-like protocol for cluster membership: //! //! - **Ping**: Direct health check to random peer //! - **Indirect Probe**: Ask K peers to check unresponsive node //! - **Suspicion**: Mark unresponsive nodes as suspect //! - **Gossip**: Piggyback membership updates on protocol messages use dashmap::DashMap; use metrics::{counter, gauge}; use parking_lot::RwLock; use rand::seq::SliceRandom; use std::collections::VecDeque; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; use std::time::Instant; use tokio::sync::broadcast; use tracing::{debug, info, instrument, warn}; use crate::config::SwimConfig; use crate::membership::types::{MembershipEntry, MembershipEvent, NodeId, NodeInfo, NodeState}; use crate::Result; /// SWIM-based cluster membership manager. /// /// Manages the list of known cluster members, detects failures via probing, /// and disseminates membership changes via gossip. pub struct SwimMembership { /// This node's information. local_node: RwLock, /// Known cluster members (excluding self). members: DashMap, /// Nodes currently under suspicion. suspects: DashMap, /// Event broadcaster for membership changes. event_tx: broadcast::Sender, /// Configuration. config: SwimConfig, /// Lamport clock for ordering events. lamport_clock: AtomicU64, /// Queue of membership updates to gossip. gossip_queue: RwLock>, /// Whether the membership protocol is running. running: AtomicBool, /// Whether this node has joined a cluster. joined: AtomicBool, } impl SwimMembership { /// Creates a new SWIM membership manager. pub fn new(local_node: NodeInfo, config: SwimConfig) -> Self { let (event_tx, _) = broadcast::channel(1024); Self { local_node: RwLock::new(local_node), members: DashMap::new(), suspects: DashMap::new(), event_tx, config, lamport_clock: AtomicU64::new(0), gossip_queue: RwLock::new(VecDeque::with_capacity(1000)), running: AtomicBool::new(false), joined: AtomicBool::new(false), } } /// Returns this node's ID. pub fn local_id(&self) -> NodeId { self.local_node.read().id } /// Returns this node's information. pub fn local_info(&self) -> NodeInfo { self.local_node.read().clone() } /// Updates this node's information. pub fn update_local_info(&self, info: NodeInfo) { let mut local = self.local_node.write(); *local = info; } /// Joins the cluster by contacting seed nodes via gRPC ping. /// /// # Algorithm /// /// 1. For each seed, attempt a `Ping` RPC to verify reachability /// 2. If at least one seed is reachable, mark as joined /// 3. If no seeds are reachable, start as an isolated node (not an error — /// gossip and anti-entropy will sync state once the network recovers) /// /// # Errors /// /// Never returns an error — isolated startup is acceptable. #[instrument(skip(self), fields(seed_count = seeds.len()))] pub async fn join(&self, seeds: Vec) -> Result<()> { if seeds.is_empty() { // No seeds = this is the first node (bootstrap) info!("No seed nodes, bootstrapping as first node"); self.joined.store(true, Ordering::SeqCst); return Ok(()); } info!("Joining cluster via seeds"); let local_id = self.local_id(); let local_rpc_addr = self.local_info().rpc_addr; let mut contacted = 0usize; for seed_addr in &seeds { // Skip our own RPC address to avoid self-pinging if *seed_addr == local_rpc_addr { continue; } let addr = format!("http://{}", seed_addr); let client = match stemedb_rpc::SyncClient::connect(&addr).await { Ok(c) => c, Err(e) => { warn!(seed = %seed_addr, error = %e, "Cannot connect to seed, skipping"); continue; } }; let ping = stemedb_rpc::proto::PingRequest { node_id: local_id.as_bytes().to_vec(), updates: Vec::new(), }; match client.ping(ping).await { Ok(resp) => { if resp.node_id.len() >= 16 { let mut seed_id_bytes = [0u8; 16]; seed_id_bytes.copy_from_slice(&resp.node_id[..16]); let seed_node_id = NodeId::from_bytes(seed_id_bytes); // Use addresses from PingResponse, falling back to seed_addr let seed_rpc: std::net::SocketAddr = resp.rpc_addr.parse().unwrap_or(*seed_addr); let seed_api: std::net::SocketAddr = resp.api_addr.parse().unwrap_or_else(|_| { std::net::SocketAddr::new( seed_addr.ip(), seed_addr.port().saturating_sub(2), ) }); let seed_info = NodeInfo::new(seed_node_id, seed_rpc, seed_api); self.alive_node(seed_node_id, seed_info); info!( seed = %seed_addr, seed_id = %seed_node_id.short_hex(), "Registered seed in membership table" ); } else { info!( seed = %seed_addr, "Seed reachable but returned short node_id" ); } contacted += 1; } Err(e) => { warn!(seed = %seed_addr, error = %e, "Seed ping failed"); } } } if contacted == 0 { warn!("No seeds reachable — starting as isolated node (anti-entropy will sync later)"); } else { info!(contacted, "Joined cluster via seeds"); } self.joined.store(true, Ordering::SeqCst); Ok(()) } /// Gracefully leaves the cluster. /// /// Broadcasts a leave message so other nodes mark us as Left rather than Dead. #[instrument(skip(self))] pub async fn leave(&self) -> Result<()> { if !self.joined.load(Ordering::SeqCst) { return Ok(()); } info!("Leaving cluster gracefully"); // Broadcast leave to all known members let local_id = self.local_id(); let _ = self.event_tx.send(MembershipEvent::NodeLeft(local_id)); self.joined.store(false, Ordering::SeqCst); self.running.store(false, Ordering::SeqCst); Ok(()) } /// Returns all currently known alive members. pub fn members(&self) -> Vec { self.members .iter() .filter(|entry| entry.state == NodeState::Alive) .map(|entry| entry.node.clone()) .collect() } /// Returns all members including suspects. pub fn all_members(&self) -> Vec<(NodeInfo, NodeState)> { self.members.iter().map(|entry| (entry.node.clone(), entry.state)).collect() } /// Returns the count of alive members. pub fn member_count(&self) -> usize { self.members.iter().filter(|e| e.state == NodeState::Alive).count() } /// Checks if a specific node is a known member. pub fn is_member(&self, node_id: NodeId) -> bool { self.members.get(&node_id).map(|e| e.state == NodeState::Alive).unwrap_or(false) } /// Gets information about a specific node. pub fn get_member(&self, node_id: NodeId) -> Option { self.members.get(&node_id).map(|e| e.node.clone()) } /// Subscribes to membership events. pub fn subscribe(&self) -> broadcast::Receiver { self.event_tx.subscribe() } /// Processes a membership update from a remote node. /// /// Merges the update into our local state if it's newer. #[instrument(skip(self, entry), fields(node_id = %entry.node.id.short_hex()))] pub fn process_membership_update(&self, entry: MembershipEntry) { let node_id = entry.node.id; // Don't process updates about ourselves if node_id == self.local_id() { return; } // Update Lamport clock self.lamport_clock.fetch_max(entry.lamport_time + 1, Ordering::SeqCst); // Check if we should accept this update (extract data then drop lock) let should_update = { if let Some(existing) = self.members.get(&node_id) { if entry.is_newer_than(&existing) { Some(Some(existing.state)) // newer → update with old state } else { debug!( existing_gen = existing.node.incarnation, incoming_gen = entry.node.incarnation, "Ignoring older membership update" ); None // stale → skip } } else { Some(None) // new node → update with no old state } }; // DashMap Ref dropped here let old_state = match should_update { Some(old) => old, None => return, }; let new_state = entry.state; let node_info = entry.node.clone(); self.members.insert(node_id, entry); // Emit appropriate event and record metrics match (old_state, new_state) { (None, NodeState::Alive) => { info!(node = %node_id.short_hex(), "Node joined"); let _ = self.event_tx.send(MembershipEvent::NodeJoined(node_info)); counter!("stemedb_membership_events_total", "type" => "joined").increment(1); } (Some(NodeState::Alive), NodeState::Suspect) => { warn!(node = %node_id.short_hex(), "Node suspected"); let _ = self.event_tx.send(MembershipEvent::NodeSuspected(node_id)); self.suspects.insert(node_id, Instant::now()); counter!("stemedb_membership_events_total", "type" => "suspected").increment(1); } (Some(_), NodeState::Dead) => { warn!(node = %node_id.short_hex(), "Node failed"); let _ = self.event_tx.send(MembershipEvent::NodeFailed(node_id)); self.suspects.remove(&node_id); counter!("stemedb_membership_events_total", "type" => "failed").increment(1); } (Some(_), NodeState::Left) => { info!(node = %node_id.short_hex(), "Node left"); let _ = self.event_tx.send(MembershipEvent::NodeLeft(node_id)); self.suspects.remove(&node_id); counter!("stemedb_membership_events_total", "type" => "left").increment(1); } (Some(NodeState::Suspect), NodeState::Alive) => { info!(node = %node_id.short_hex(), "Node recovered"); let _ = self.event_tx.send(MembershipEvent::NodeUpdated(node_info)); self.suspects.remove(&node_id); counter!("stemedb_membership_events_total", "type" => "recovered").increment(1); } (Some(_), _) => { // Other updates let _ = self.event_tx.send(MembershipEvent::NodeUpdated(node_info)); } (None, _) => { // First time seeing this node in non-alive state, ignore } } // Update cluster node gauges self.update_node_gauges(); } /// Updates the Prometheus gauges for cluster node counts. fn update_node_gauges(&self) { let mut alive_count = 0usize; let mut suspect_count = 0usize; for entry in self.members.iter() { match entry.state { NodeState::Alive => alive_count += 1, NodeState::Suspect => suspect_count += 1, _ => {} } } gauge!("stemedb_cluster_nodes_alive").set(alive_count as f64); gauge!("stemedb_cluster_nodes_suspect").set(suspect_count as f64); gauge!("stemedb_cluster_nodes_total").set((alive_count + suspect_count) as f64); } /// Marks a node as suspected (failed to respond to probe). #[instrument(skip(self))] pub fn suspect_node(&self, node_id: NodeId) { // IMPORTANT: Clone the entry and drop the RefMut BEFORE calling update_node_gauges. // DashMap::get_mut holds a shard write lock; update_node_gauges calls iter() which // acquires read locks on all shards. parking_lot write locks are non-reentrant — // calling iter() while get_mut's RefMut is alive deadlocks on the same shard. let gossip_entry = { if let Some(mut entry) = self.members.get_mut(&node_id) { if entry.state == NodeState::Alive { entry.state = NodeState::Suspect; entry.lamport_time = self.tick(); info!(node = %node_id.short_hex(), "Marking node as suspect"); let _ = self.event_tx.send(MembershipEvent::NodeSuspected(node_id)); self.suspects.insert(node_id, Instant::now()); counter!("stemedb_membership_events_total", "type" => "suspected").increment(1); Some(entry.clone()) } else { None } } else { None } }; // RefMut dropped here — safe to iterate the map now if let Some(entry) = gossip_entry { self.update_node_gauges(); self.queue_gossip(entry); } } /// Marks a node as dead (suspicion timeout expired). #[instrument(skip(self))] pub fn fail_node(&self, node_id: NodeId) { // IMPORTANT: same deadlock hazard as suspect_node — drop RefMut before update_node_gauges. let gossip_entry = { if let Some(mut entry) = self.members.get_mut(&node_id) { if entry.state == NodeState::Suspect { entry.state = NodeState::Dead; entry.lamport_time = self.tick(); warn!(node = %node_id.short_hex(), "Marking node as dead"); let _ = self.event_tx.send(MembershipEvent::NodeFailed(node_id)); self.suspects.remove(&node_id); counter!("stemedb_membership_events_total", "type" => "failed").increment(1); Some(entry.clone()) } else { None } } else { None } }; // RefMut dropped here if let Some(entry) = gossip_entry { self.update_node_gauges(); self.queue_gossip(entry); } } /// Marks a node as alive (responded to probe or refuted suspicion). #[instrument(skip(self))] pub fn alive_node(&self, node_id: NodeId, info: NodeInfo) { // Never add ourselves to the members map — self is tracked separately if node_id == self.local_id() { return; } let lamport = self.tick(); // IMPORTANT: same deadlock hazard — drop RefMut from get_mut before update_node_gauges. let result = { match self.members.get_mut(&node_id) { Some(mut entry) => { // Only update if incarnation is higher or equal if info.incarnation >= entry.node.incarnation { let was_suspect = entry.state == NodeState::Suspect; entry.node = info.clone(); entry.state = NodeState::Alive; entry.lamport_time = lamport; self.suspects.remove(&node_id); if was_suspect { counter!("stemedb_membership_events_total", "type" => "recovered") .increment(1); } Some((entry.clone(), MembershipEvent::NodeUpdated(info))) } else { None } } None => { // New node — insert() releases any lock immediately, so update_node_gauges // is safe to call right after. let entry = MembershipEntry::new(info.clone(), NodeState::Alive, lamport); self.members.insert(node_id, entry.clone()); self.queue_gossip(entry); counter!("stemedb_membership_events_total", "type" => "joined").increment(1); self.update_node_gauges(); let _ = self.event_tx.send(MembershipEvent::NodeJoined(info)); return; } } }; // RefMut dropped here if let Some((entry, event)) = result { self.update_node_gauges(); self.queue_gossip(entry); let _ = self.event_tx.send(event); } } /// Selects a random member for probing. pub fn select_probe_target(&self) -> Option { let candidates: Vec<_> = self .members .iter() .filter(|e| e.state == NodeState::Alive) .map(|e| e.node.id) .collect(); if candidates.is_empty() { return None; } let mut rng = rand::thread_rng(); candidates.choose(&mut rng).copied() } /// Selects K random members for indirect probing. pub fn select_indirect_targets(&self, exclude: NodeId) -> Vec { let candidates: Vec<_> = self .members .iter() .filter(|e| e.state == NodeState::Alive && e.node.id != exclude) .map(|e| e.node.id) .collect(); if candidates.is_empty() { return Vec::new(); } let mut rng = rand::thread_rng(); candidates.choose_multiple(&mut rng, self.config.indirect_probe_count).copied().collect() } /// Checks suspicion timeouts and promotes suspects to dead. pub fn check_suspicion_timeouts(&self) { let timeout = self.config.suspicion_timeout; let now = Instant::now(); let expired: Vec<_> = self .suspects .iter() .filter(|entry| now.duration_since(*entry.value()) > timeout) .map(|entry| *entry.key()) .collect(); for node_id in expired { self.fail_node(node_id); } } /// Gets pending gossip messages (up to max_count). pub fn get_gossip_batch(&self, max_count: usize) -> Vec { let mut queue = self.gossip_queue.write(); let count = max_count.min(queue.len()); queue.drain(..count).collect() } /// Queues a membership entry for gossip. fn queue_gossip(&self, entry: MembershipEntry) { let mut queue = self.gossip_queue.write(); if queue.len() < self.config.gossip_queue_size { queue.push_back(entry); } } /// Increments and returns the Lamport clock. fn tick(&self) -> u64 { self.lamport_clock.fetch_add(1, Ordering::SeqCst) + 1 } /// Returns whether this node has joined a cluster. pub fn is_joined(&self) -> bool { self.joined.load(Ordering::SeqCst) } /// Starts the background SWIM protocol tasks. /// /// This spawns background tasks for: /// - Periodic probing /// - Suspicion timeout checking /// - Gossip dissemination /// /// Marks the protocol as running. /// /// Background probe/gossip tasks are not yet spawned internally. /// The protocol logic is currently driven externally via /// `check_suspicion_timeouts()`, `select_probe_target()`, and /// `get_gossip_batch()`. pub fn start(&self) { self.running.store(true, Ordering::SeqCst); } /// Spawns background SWIM protocol tasks (probe, suspicion check, gossip). /// /// Must be called after `start()` and `join()`. Spawns 3 tokio tasks that /// run until `stop()` is called. pub fn spawn_background_tasks(self: &Arc) { let membership = Arc::clone(self); let probe_interval = membership.config.probe_interval; // 1. Probe loop — pings a random member every probe_interval let m = Arc::clone(&membership); tokio::spawn(async move { let mut ticker = tokio::time::interval(probe_interval); loop { ticker.tick().await; if !m.is_running() { break; } let target_id = match m.select_probe_target() { Some(id) => id, None => continue, }; let target_info = match m.get_member(target_id) { Some(info) => info, None => continue, }; let addr = format!("http://{}", target_info.rpc_addr); let local_id = m.local_id(); let gossip_batch = m.get_gossip_batch(5); let updates: Vec = gossip_batch .iter() .map(|entry| stemedb_rpc::proto::MembershipUpdate { node_id: entry.node.id.as_bytes().to_vec(), rpc_addr: entry.node.rpc_addr.to_string(), api_addr: entry.node.api_addr.to_string(), state: match entry.state { NodeState::Alive => 0, NodeState::Suspect => 1, NodeState::Dead => 2, NodeState::Left => 3, }, lamport_time: entry.lamport_time, incarnation: entry.node.incarnation, }) .collect(); match stemedb_rpc::SyncClient::connect(&addr).await { Ok(client) => { let ping = stemedb_rpc::proto::PingRequest { node_id: local_id.as_bytes().to_vec(), updates, }; match client.ping(ping).await { Ok(resp) => { if resp.node_id.len() >= 16 { let mut id_bytes = [0u8; 16]; id_bytes.copy_from_slice(&resp.node_id[..16]); let peer_id = NodeId::from_bytes(id_bytes); let peer_rpc: std::net::SocketAddr = resp.rpc_addr.parse().unwrap_or(target_info.rpc_addr); let peer_api: std::net::SocketAddr = resp.api_addr.parse().unwrap_or(target_info.api_addr); m.alive_node( peer_id, NodeInfo::new(peer_id, peer_rpc, peer_api), ); } // Process piggybacked membership updates for update in resp.updates { if update.node_id.len() >= 16 { let mut id_bytes = [0u8; 16]; id_bytes.copy_from_slice(&update.node_id[..16]); let upd_id = NodeId::from_bytes(id_bytes); if upd_id == m.local_id() { continue; } let upd_rpc: std::net::SocketAddr = match update.rpc_addr.parse() { Ok(a) => a, Err(_) => continue, }; let upd_api: std::net::SocketAddr = match update.api_addr.parse() { Ok(a) => a, Err(_) => continue, }; let mut node_info = NodeInfo::new(upd_id, upd_rpc, upd_api); node_info.incarnation = update.incarnation; let state = match update.state { 0 => NodeState::Alive, 1 => NodeState::Suspect, 2 => NodeState::Dead, _ => NodeState::Left, }; let entry = MembershipEntry::new( node_info, state, update.lamport_time, ); m.process_membership_update(entry); } } } Err(_) => { m.suspect_node(target_id); } } } Err(_) => { m.suspect_node(target_id); } } } }); // 2. Suspicion checker — promotes suspects to dead let m = Arc::clone(&membership); tokio::spawn(async move { let mut ticker = tokio::time::interval(std::time::Duration::from_secs(1)); loop { ticker.tick().await; if !m.is_running() { break; } m.check_suspicion_timeouts(); } }); // 3. Gossip disseminator — sends batched updates to random peers let m = Arc::clone(&membership); let gossip_interval = membership.config.gossip_interval; tokio::spawn(async move { let mut ticker = tokio::time::interval(gossip_interval); loop { ticker.tick().await; if !m.is_running() { break; } let batch = m.get_gossip_batch(5); if batch.is_empty() { continue; } // Pick a random alive member to send gossip to let target = match m.select_probe_target() { Some(id) => id, None => continue, }; let target_info = match m.get_member(target) { Some(info) => info, None => continue, }; let addr = format!("http://{}", target_info.rpc_addr); let updates: Vec = batch .iter() .map(|entry| stemedb_rpc::proto::MembershipUpdate { node_id: entry.node.id.as_bytes().to_vec(), rpc_addr: entry.node.rpc_addr.to_string(), api_addr: entry.node.api_addr.to_string(), state: match entry.state { NodeState::Alive => 0, NodeState::Suspect => 1, NodeState::Dead => 2, NodeState::Left => 3, }, lamport_time: entry.lamport_time, incarnation: entry.node.incarnation, }) .collect(); if let Ok(client) = stemedb_rpc::SyncClient::connect(&addr).await { let ping = stemedb_rpc::proto::PingRequest { node_id: m.local_id().as_bytes().to_vec(), updates, }; let _ = client.ping(ping).await; } } }); } /// Stops the background SWIM protocol tasks. pub fn stop(&self) { self.running.store(false, Ordering::SeqCst); } /// Returns whether the protocol is running. pub fn is_running(&self) -> bool { self.running.load(Ordering::SeqCst) } } impl std::fmt::Debug for SwimMembership { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SwimMembership") .field("local_id", &self.local_id().short_hex()) .field("member_count", &self.member_count()) .field("joined", &self.joined.load(Ordering::SeqCst)) .field("running", &self.running.load(Ordering::SeqCst)) .finish() } } #[cfg(test)] #[path = "swim_tests.rs"] mod tests;