// laminar_core/state/in_process.rs

use async_trait::async_trait;
use bytes::Bytes;
use parking_lot::RwLock;
use rustc_hash::FxHashMap;

use super::backend::{StateBackend, StateBackendError};
/// In-memory [`StateBackend`] for tests and single-process deployments.
///
/// Partial state blobs are stored in a map keyed by `(vnode, epoch)` behind
/// a `parking_lot::RwLock`, so reads are concurrent and writes exclusive.
#[derive(Debug)]
pub struct InProcessBackend {
    // (vnode, epoch) -> partial state bytes written for that pair.
    partials: RwLock<FxHashMap<(u32, u64), Bytes>>,
    // Exclusive upper bound on valid vnode ids (valid range: 0..vnode_capacity).
    vnode_capacity: u32,
}
17
18impl InProcessBackend {
19 #[must_use]
21 pub fn new(vnode_capacity: u32) -> Self {
22 Self {
23 partials: RwLock::new(FxHashMap::default()),
24 vnode_capacity,
25 }
26 }
27
28 #[must_use]
30 pub fn vnode_capacity(&self) -> u32 {
31 self.vnode_capacity
32 }
33
34 fn check_vnode(&self, v: u32) -> Result<(), StateBackendError> {
35 if v >= self.vnode_capacity {
36 Err(StateBackendError::Io(format!(
37 "vnode {v} out of range (capacity {})",
38 self.vnode_capacity
39 )))
40 } else {
41 Ok(())
42 }
43 }
44}
45
46#[async_trait]
47impl StateBackend for InProcessBackend {
48 async fn write_partial(
52 &self,
53 vnode: u32,
54 epoch: u64,
55 _assignment_version: u64,
56 bytes: Bytes,
57 ) -> Result<(), StateBackendError> {
58 self.check_vnode(vnode)?;
59 self.partials.write().insert((vnode, epoch), bytes);
60 Ok(())
61 }
62
63 async fn read_partial(
64 &self,
65 vnode: u32,
66 epoch: u64,
67 ) -> Result<Option<Bytes>, StateBackendError> {
68 self.check_vnode(vnode)?;
69 Ok(self.partials.read().get(&(vnode, epoch)).cloned())
70 }
71
72 async fn epoch_complete(&self, epoch: u64, vnodes: &[u32]) -> Result<bool, StateBackendError> {
73 let map = self.partials.read();
74 for &v in vnodes {
75 self.check_vnode(v)?;
76 if !map.contains_key(&(v, epoch)) {
77 return Ok(false);
78 }
79 }
80 Ok(true)
81 }
82
83 async fn prune_before(&self, before: u64) -> Result<(), StateBackendError> {
84 self.partials
87 .write()
88 .retain(|&(_, epoch), _| epoch >= before);
89 Ok(())
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    // A write must be readable back verbatim; unwritten epochs read as None.
    #[tokio::test]
    async fn write_read_roundtrip() {
        let backend = InProcessBackend::new(4);
        let payload = Bytes::from_static(b"hello");
        backend.write_partial(2, 7, 0, payload.clone()).await.unwrap();
        assert_eq!(backend.read_partial(2, 7).await.unwrap(), Some(payload));
        assert!(backend.read_partial(2, 8).await.unwrap().is_none());
    }

    // Completion requires a partial for the epoch under every listed vnode.
    #[tokio::test]
    async fn epoch_complete_requires_every_vnode() {
        let backend = InProcessBackend::new(4);
        let vnodes = [0u32, 1, 2];
        assert!(!backend.epoch_complete(1, &vnodes).await.unwrap());

        // Two of three vnodes written: still incomplete.
        for (v, data) in [(0u32, &b"a"[..]), (1, &b"b"[..])] {
            backend
                .write_partial(v, 1, 0, Bytes::copy_from_slice(data))
                .await
                .unwrap();
        }
        assert!(!backend.epoch_complete(1, &vnodes).await.unwrap());

        // Last vnode written: epoch 1 complete, epoch 2 still isn't.
        backend
            .write_partial(2, 1, 0, Bytes::from_static(b"c"))
            .await
            .unwrap();
        assert!(backend.epoch_complete(1, &vnodes).await.unwrap());
        assert!(!backend.epoch_complete(2, &vnodes).await.unwrap());
    }

    // Writes to vnodes at or beyond capacity must be rejected with Io.
    #[tokio::test]
    async fn out_of_range_vnode_errors() {
        let backend = InProcessBackend::new(2);
        let err = backend
            .write_partial(5, 1, 0, Bytes::from_static(b"x"))
            .await
            .unwrap_err();
        assert!(matches!(err, StateBackendError::Io(_)));
    }

    // The trait must remain usable behind a trait object.
    #[test]
    fn state_backend_is_object_safe() {
        let _: std::sync::Arc<dyn StateBackend> = std::sync::Arc::new(InProcessBackend::new(2));
    }

    // prune_before(n) drops epochs < n and keeps epochs >= n.
    #[tokio::test]
    async fn prune_before_drops_old_epochs() {
        let backend = InProcessBackend::new(4);
        for epoch in 1..=5 {
            for (v, data) in [(0u32, &b"x"[..]), (1, &b"y"[..])] {
                backend
                    .write_partial(v, epoch, 0, Bytes::copy_from_slice(data))
                    .await
                    .unwrap();
            }
        }
        backend.prune_before(4).await.unwrap();
        for epoch in 1..=5 {
            let expect_present = epoch >= 4;
            for v in [0u32, 1] {
                let got = backend.read_partial(v, epoch).await.unwrap();
                assert_eq!(got.is_some(), expect_present);
            }
        }
    }
}