Skip to main content

laminar_core/state/
in_process.rs

1//! [`InProcessBackend`] — non-durable [`StateBackend`] backed by an
2//! in-memory hashmap. Used for tests and embedded single-process runs.
3
4use async_trait::async_trait;
5use bytes::Bytes;
6use parking_lot::RwLock;
7use rustc_hash::FxHashMap;
8
9use super::backend::{StateBackend, StateBackendError};
10
11/// In-process, non-durable state backend.
12#[derive(Debug)]
13pub struct InProcessBackend {
14    partials: RwLock<FxHashMap<(u32, u64), Bytes>>,
15    /// Highest epoch for which [`epoch_complete`](StateBackend::epoch_complete)
16    /// observed every requested vnode present — the in-memory analogue of
17    /// the object-store `_COMMIT` marker, surfaced by
18    /// [`latest_committed_epoch`](StateBackend::latest_committed_epoch).
19    committed_high: RwLock<Option<u64>>,
20    vnode_capacity: u32,
21}
22
23impl InProcessBackend {
24    /// Create a new backend sized for `vnode_capacity` vnodes.
25    #[must_use]
26    pub fn new(vnode_capacity: u32) -> Self {
27        Self {
28            partials: RwLock::new(FxHashMap::default()),
29            committed_high: RwLock::new(None),
30            vnode_capacity,
31        }
32    }
33
34    /// Vnode range this backend is configured for.
35    #[must_use]
36    pub fn vnode_capacity(&self) -> u32 {
37        self.vnode_capacity
38    }
39
40    fn check_vnode(&self, v: u32) -> Result<(), StateBackendError> {
41        if v >= self.vnode_capacity {
42            Err(StateBackendError::Io(format!(
43                "vnode {v} out of range (capacity {})",
44                self.vnode_capacity
45            )))
46        } else {
47            Ok(())
48        }
49    }
50}
51
52#[async_trait]
53impl StateBackend for InProcessBackend {
54    /// In-process backend opts out of the split-brain fence — there's
55    /// only one process so the scenario is moot. `assignment_version`
56    /// is accepted and ignored.
57    async fn write_partial(
58        &self,
59        vnode: u32,
60        epoch: u64,
61        _assignment_version: u64,
62        bytes: Bytes,
63    ) -> Result<(), StateBackendError> {
64        self.check_vnode(vnode)?;
65        self.partials.write().insert((vnode, epoch), bytes);
66        Ok(())
67    }
68
69    async fn read_partial(
70        &self,
71        vnode: u32,
72        epoch: u64,
73    ) -> Result<Option<Bytes>, StateBackendError> {
74        self.check_vnode(vnode)?;
75        Ok(self.partials.read().get(&(vnode, epoch)).cloned())
76    }
77
78    async fn epoch_complete(&self, epoch: u64, vnodes: &[u32]) -> Result<bool, StateBackendError> {
79        {
80            let map = self.partials.read();
81            for &v in vnodes {
82                self.check_vnode(v)?;
83                if !map.contains_key(&(v, epoch)) {
84                    return Ok(false);
85                }
86            }
87        }
88        // Every vnode is durable: this epoch is sealed. Record it as the
89        // committed high-water mark so rehydration can find it later.
90        let mut hi = self.committed_high.write();
91        *hi = Some(hi.map_or(epoch, |h| h.max(epoch)));
92        Ok(true)
93    }
94
95    async fn prune_before(&self, before: u64) -> Result<(), StateBackendError> {
96        // Without this, every checkpoint leaks one Bytes per vnode
97        // forever.
98        self.partials
99            .write()
100            .retain(|&(_, epoch), _| epoch >= before);
101        Ok(())
102    }
103
104    async fn latest_committed_epoch(&self) -> Result<Option<u64>, StateBackendError> {
105        Ok(*self.committed_high.read())
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `data` for `(vnode, epoch)` with assignment version 0,
    /// panicking if the backend rejects it.
    async fn put(backend: &InProcessBackend, vnode: u32, epoch: u64, data: &'static [u8]) {
        backend
            .write_partial(vnode, epoch, 0, Bytes::from_static(data))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn write_read_roundtrip() {
        let backend = InProcessBackend::new(4);
        let expected = Bytes::from_static(b"hello");
        backend
            .write_partial(2, 7, 0, expected.clone())
            .await
            .unwrap();

        let stored = backend.read_partial(2, 7).await.unwrap();
        assert_eq!(stored, Some(expected));

        let other_epoch = backend.read_partial(2, 8).await.unwrap();
        assert!(other_epoch.is_none());
    }

    #[tokio::test]
    async fn epoch_complete_requires_every_vnode() {
        let backend = InProcessBackend::new(4);
        let vnodes = [0u32, 1, 2];

        assert!(!backend.epoch_complete(1, &vnodes).await.unwrap());

        put(&backend, 0, 1, b"a").await;
        put(&backend, 1, 1, b"b").await;
        assert!(!backend.epoch_complete(1, &vnodes).await.unwrap());

        put(&backend, 2, 1, b"c").await;
        assert!(backend.epoch_complete(1, &vnodes).await.unwrap());

        // Nothing has been written for epoch 2.
        assert!(!backend.epoch_complete(2, &vnodes).await.unwrap());
    }

    #[tokio::test]
    async fn latest_committed_epoch_follows_epoch_complete() {
        let backend = InProcessBackend::new(4);
        let vnodes = [0u32, 1];
        assert_eq!(backend.latest_committed_epoch().await.unwrap(), None);

        // A speculative gate that answers false must leave the mark alone.
        put(&backend, 0, 2, b"a").await;
        assert!(!backend.epoch_complete(2, &vnodes).await.unwrap());
        assert_eq!(backend.latest_committed_epoch().await.unwrap(), None);

        // Seal epoch 2, then epoch 5; the mark follows the highest sealed epoch.
        put(&backend, 1, 2, b"b").await;
        assert!(backend.epoch_complete(2, &vnodes).await.unwrap());
        assert_eq!(backend.latest_committed_epoch().await.unwrap(), Some(2));

        for &vnode in &vnodes {
            put(&backend, vnode, 5, b"c").await;
        }
        assert!(backend.epoch_complete(5, &vnodes).await.unwrap());
        assert_eq!(backend.latest_committed_epoch().await.unwrap(), Some(5));
    }

    #[tokio::test]
    async fn out_of_range_vnode_errors() {
        let backend = InProcessBackend::new(2);
        let err = backend
            .write_partial(5, 1, 0, Bytes::from_static(b"x"))
            .await
            .unwrap_err();
        assert!(matches!(err, StateBackendError::Io(_)));
    }

    #[test]
    fn state_backend_is_object_safe() {
        use std::sync::Arc;

        let _erased: Arc<dyn StateBackend> = Arc::new(InProcessBackend::new(2));
    }

    #[tokio::test]
    async fn prune_before_drops_old_epochs() {
        let backend = InProcessBackend::new(4);
        for epoch in 1..=5 {
            put(&backend, 0, epoch, b"x").await;
            put(&backend, 1, epoch, b"y").await;
        }

        // Keep epochs >= 4; everything from epochs 1..=3 must disappear.
        backend.prune_before(4).await.unwrap();

        for epoch in 1..=5 {
            let should_survive = epoch >= 4;
            for vnode in [0, 1] {
                let survived = backend.read_partial(vnode, epoch).await.unwrap().is_some();
                assert_eq!(survived, should_survive);
            }
        }
    }
}