1use async_trait::async_trait;
5use bytes::Bytes;
6use parking_lot::RwLock;
7use rustc_hash::FxHashMap;
8
9use super::backend::{StateBackend, StateBackendError};
10
11#[derive(Debug)]
13pub struct InProcessBackend {
14 partials: RwLock<FxHashMap<(u32, u64), Bytes>>,
15 committed_high: RwLock<Option<u64>>,
20 vnode_capacity: u32,
21}
22
23impl InProcessBackend {
24 #[must_use]
26 pub fn new(vnode_capacity: u32) -> Self {
27 Self {
28 partials: RwLock::new(FxHashMap::default()),
29 committed_high: RwLock::new(None),
30 vnode_capacity,
31 }
32 }
33
34 #[must_use]
36 pub fn vnode_capacity(&self) -> u32 {
37 self.vnode_capacity
38 }
39
40 fn check_vnode(&self, v: u32) -> Result<(), StateBackendError> {
41 if v >= self.vnode_capacity {
42 Err(StateBackendError::Io(format!(
43 "vnode {v} out of range (capacity {})",
44 self.vnode_capacity
45 )))
46 } else {
47 Ok(())
48 }
49 }
50}
51
52#[async_trait]
53impl StateBackend for InProcessBackend {
54 async fn write_partial(
58 &self,
59 vnode: u32,
60 epoch: u64,
61 _assignment_version: u64,
62 bytes: Bytes,
63 ) -> Result<(), StateBackendError> {
64 self.check_vnode(vnode)?;
65 self.partials.write().insert((vnode, epoch), bytes);
66 Ok(())
67 }
68
69 async fn read_partial(
70 &self,
71 vnode: u32,
72 epoch: u64,
73 ) -> Result<Option<Bytes>, StateBackendError> {
74 self.check_vnode(vnode)?;
75 Ok(self.partials.read().get(&(vnode, epoch)).cloned())
76 }
77
78 async fn epoch_complete(&self, epoch: u64, vnodes: &[u32]) -> Result<bool, StateBackendError> {
79 {
80 let map = self.partials.read();
81 for &v in vnodes {
82 self.check_vnode(v)?;
83 if !map.contains_key(&(v, epoch)) {
84 return Ok(false);
85 }
86 }
87 }
88 let mut hi = self.committed_high.write();
91 *hi = Some(hi.map_or(epoch, |h| h.max(epoch)));
92 Ok(true)
93 }
94
95 async fn prune_before(&self, before: u64) -> Result<(), StateBackendError> {
96 self.partials
99 .write()
100 .retain(|&(_, epoch), _| epoch >= before);
101 Ok(())
102 }
103
104 async fn latest_committed_epoch(&self) -> Result<Option<u64>, StateBackendError> {
105 Ok(*self.committed_high.read())
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn write_read_roundtrip() {
        let b = InProcessBackend::new(4);
        let payload = Bytes::from_static(b"hello");
        b.write_partial(2, 7, 0, payload.clone()).await.unwrap();
        let got = b.read_partial(2, 7).await.unwrap().unwrap();
        assert_eq!(got, payload);
        assert!(b.read_partial(2, 8).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn rewrite_replaces_partial() {
        let b = InProcessBackend::new(1);
        b.write_partial(0, 1, 0, Bytes::from_static(b"old"))
            .await
            .unwrap();
        b.write_partial(0, 1, 1, Bytes::from_static(b"new"))
            .await
            .unwrap();
        let got = b.read_partial(0, 1).await.unwrap().unwrap();
        assert_eq!(got, Bytes::from_static(b"new"));
    }

    #[tokio::test]
    async fn epoch_complete_requires_every_vnode() {
        let b = InProcessBackend::new(4);
        let vnodes = [0u32, 1, 2];
        assert!(!b.epoch_complete(1, &vnodes).await.unwrap());
        b.write_partial(0, 1, 0, Bytes::from_static(b"a"))
            .await
            .unwrap();
        b.write_partial(1, 1, 0, Bytes::from_static(b"b"))
            .await
            .unwrap();
        assert!(!b.epoch_complete(1, &vnodes).await.unwrap());
        b.write_partial(2, 1, 0, Bytes::from_static(b"c"))
            .await
            .unwrap();
        assert!(b.epoch_complete(1, &vnodes).await.unwrap());
        assert!(!b.epoch_complete(2, &vnodes).await.unwrap());
    }

    #[tokio::test]
    async fn latest_committed_epoch_follows_epoch_complete() {
        let b = InProcessBackend::new(4);
        let vnodes = [0u32, 1];
        assert_eq!(b.latest_committed_epoch().await.unwrap(), None);

        b.write_partial(0, 2, 0, Bytes::from_static(b"a"))
            .await
            .unwrap();
        assert!(!b.epoch_complete(2, &vnodes).await.unwrap());
        assert_eq!(b.latest_committed_epoch().await.unwrap(), None);

        b.write_partial(1, 2, 0, Bytes::from_static(b"b"))
            .await
            .unwrap();
        assert!(b.epoch_complete(2, &vnodes).await.unwrap());
        assert_eq!(b.latest_committed_epoch().await.unwrap(), Some(2));

        for v in &vnodes {
            b.write_partial(*v, 5, 0, Bytes::from_static(b"c"))
                .await
                .unwrap();
        }
        assert!(b.epoch_complete(5, &vnodes).await.unwrap());
        assert_eq!(b.latest_committed_epoch().await.unwrap(), Some(5));
    }

    #[tokio::test]
    async fn committed_epoch_never_regresses() {
        let b = InProcessBackend::new(2);
        let vnodes = [0u32, 1];
        for &epoch in &[5u64, 3] {
            for &v in &vnodes {
                b.write_partial(v, epoch, 0, Bytes::from_static(b"p"))
                    .await
                    .unwrap();
            }
        }
        assert!(b.epoch_complete(5, &vnodes).await.unwrap());
        // Completing an older epoch later must not pull the watermark back.
        assert!(b.epoch_complete(3, &vnodes).await.unwrap());
        assert_eq!(b.latest_committed_epoch().await.unwrap(), Some(5));
    }

    #[tokio::test]
    async fn out_of_range_vnode_errors() {
        let b = InProcessBackend::new(2);
        let r = b
            .write_partial(5, 1, 0, Bytes::from_static(b"x"))
            .await
            .unwrap_err();
        assert!(matches!(r, StateBackendError::Io(_)));
    }

    #[tokio::test]
    async fn out_of_range_vnode_errors_on_read_and_complete() {
        let b = InProcessBackend::new(2);
        // Capacity is an exclusive bound: vnode 2 is the first invalid id.
        let r = b.read_partial(2, 1).await.unwrap_err();
        assert!(matches!(r, StateBackendError::Io(_)));

        let r = b.epoch_complete(1, &[7]).await.unwrap_err();
        assert!(matches!(r, StateBackendError::Io(_)));
        // A rejected completion must not advance the watermark.
        assert_eq!(b.latest_committed_epoch().await.unwrap(), None);
    }

    #[test]
    fn state_backend_is_object_safe() {
        let _: std::sync::Arc<dyn StateBackend> = std::sync::Arc::new(InProcessBackend::new(2));
    }

    #[tokio::test]
    async fn prune_before_drops_old_epochs() {
        let b = InProcessBackend::new(4);
        for epoch in 1..=5 {
            b.write_partial(0, epoch, 0, Bytes::from_static(b"x"))
                .await
                .unwrap();
            b.write_partial(1, epoch, 0, Bytes::from_static(b"y"))
                .await
                .unwrap();
        }
        b.prune_before(4).await.unwrap();
        for epoch in 1..=3 {
            assert!(b.read_partial(0, epoch).await.unwrap().is_none());
            assert!(b.read_partial(1, epoch).await.unwrap().is_none());
        }
        for epoch in 4..=5 {
            assert!(b.read_partial(0, epoch).await.unwrap().is_some());
            assert!(b.read_partial(1, epoch).await.unwrap().is_some());
        }
    }
}