Skip to main content

laminar_core/state/
in_process.rs

1//! [`InProcessBackend`] — non-durable [`StateBackend`] backed by an
2//! in-memory hashmap. Used for tests and embedded single-process runs.
3
4use async_trait::async_trait;
5use bytes::Bytes;
6use parking_lot::RwLock;
7use rustc_hash::FxHashMap;
8
9use super::backend::{StateBackend, StateBackendError};
10
11/// In-process, non-durable state backend.
12#[derive(Debug)]
13pub struct InProcessBackend {
14    partials: RwLock<FxHashMap<(u32, u64), Bytes>>,
15    vnode_capacity: u32,
16}
17
18impl InProcessBackend {
19    /// Create a new backend sized for `vnode_capacity` vnodes.
20    #[must_use]
21    pub fn new(vnode_capacity: u32) -> Self {
22        Self {
23            partials: RwLock::new(FxHashMap::default()),
24            vnode_capacity,
25        }
26    }
27
28    /// Vnode range this backend is configured for.
29    #[must_use]
30    pub fn vnode_capacity(&self) -> u32 {
31        self.vnode_capacity
32    }
33
34    fn check_vnode(&self, v: u32) -> Result<(), StateBackendError> {
35        if v >= self.vnode_capacity {
36            Err(StateBackendError::Io(format!(
37                "vnode {v} out of range (capacity {})",
38                self.vnode_capacity
39            )))
40        } else {
41            Ok(())
42        }
43    }
44}
45
46#[async_trait]
47impl StateBackend for InProcessBackend {
48    /// In-process backend opts out of the split-brain fence — there's
49    /// only one process so the scenario is moot. `assignment_version`
50    /// is accepted and ignored.
51    async fn write_partial(
52        &self,
53        vnode: u32,
54        epoch: u64,
55        _assignment_version: u64,
56        bytes: Bytes,
57    ) -> Result<(), StateBackendError> {
58        self.check_vnode(vnode)?;
59        self.partials.write().insert((vnode, epoch), bytes);
60        Ok(())
61    }
62
63    async fn read_partial(
64        &self,
65        vnode: u32,
66        epoch: u64,
67    ) -> Result<Option<Bytes>, StateBackendError> {
68        self.check_vnode(vnode)?;
69        Ok(self.partials.read().get(&(vnode, epoch)).cloned())
70    }
71
72    async fn epoch_complete(&self, epoch: u64, vnodes: &[u32]) -> Result<bool, StateBackendError> {
73        let map = self.partials.read();
74        for &v in vnodes {
75            self.check_vnode(v)?;
76            if !map.contains_key(&(v, epoch)) {
77                return Ok(false);
78            }
79        }
80        Ok(true)
81    }
82
83    async fn prune_before(&self, before: u64) -> Result<(), StateBackendError> {
84        // Without this, every checkpoint leaks one Bytes per vnode
85        // forever.
86        self.partials
87            .write()
88            .retain(|&(_, epoch), _| epoch >= before);
89        Ok(())
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Writes `payload` for `(vnode, epoch)` with assignment version 0,
    /// panicking if the backend rejects the write.
    async fn put(backend: &InProcessBackend, vnode: u32, epoch: u64, payload: &'static [u8]) {
        backend
            .write_partial(vnode, epoch, 0, Bytes::from_static(payload))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn write_read_roundtrip() {
        let backend = InProcessBackend::new(4);
        let expected = Bytes::from_static(b"hello");
        backend
            .write_partial(2, 7, 0, expected.clone())
            .await
            .unwrap();

        let stored = backend.read_partial(2, 7).await.unwrap();
        assert_eq!(stored, Some(expected));

        let absent = backend.read_partial(2, 8).await.unwrap();
        assert!(absent.is_none());
    }

    #[tokio::test]
    async fn epoch_complete_requires_every_vnode() {
        let backend = InProcessBackend::new(4);
        let vnodes = [0u32, 1, 2];

        // Nothing written yet.
        assert!(!backend.epoch_complete(1, &vnodes).await.unwrap());

        // Two of three vnodes present: still incomplete.
        put(&backend, 0, 1, b"a").await;
        put(&backend, 1, 1, b"b").await;
        assert!(!backend.epoch_complete(1, &vnodes).await.unwrap());

        // Last vnode lands: epoch 1 is complete, epoch 2 is untouched.
        put(&backend, 2, 1, b"c").await;
        assert!(backend.epoch_complete(1, &vnodes).await.unwrap());
        assert!(!backend.epoch_complete(2, &vnodes).await.unwrap());
    }

    #[tokio::test]
    async fn out_of_range_vnode_errors() {
        let backend = InProcessBackend::new(2);
        let err = backend
            .write_partial(5, 1, 0, Bytes::from_static(b"x"))
            .await
            .expect_err("vnode 5 is outside a capacity of 2");
        assert!(matches!(err, StateBackendError::Io(_)));
    }

    #[test]
    fn state_backend_is_object_safe() {
        let backend: Arc<dyn StateBackend> = Arc::new(InProcessBackend::new(2));
        drop(backend);
    }

    #[tokio::test]
    async fn prune_before_drops_old_epochs() {
        let backend = InProcessBackend::new(4);
        for epoch in 1..=5 {
            put(&backend, 0, epoch, b"x").await;
            put(&backend, 1, epoch, b"y").await;
        }

        // Cut-off is 4: epochs 1..=3 must disappear, 4 and 5 must survive.
        backend.prune_before(4).await.unwrap();

        for epoch in 1..=5 {
            let should_survive = epoch >= 4;
            for vnode in [0, 1] {
                let present = backend
                    .read_partial(vnode, epoch)
                    .await
                    .unwrap()
                    .is_some();
                assert_eq!(present, should_survive, "vnode {vnode}, epoch {epoch}");
            }
        }
    }
}