1use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use bytes::Bytes;
14use object_store::path::Path as OsPath;
15use object_store::{ObjectStore, ObjectStoreExt, PutMode, PutOptions, PutPayload};
16
17use super::backend::{StateBackend, StateBackendError};
18
/// Object-store-backed state backend.
///
/// Partial state for each `(epoch, vnode)` pair lives at
/// `epoch={epoch}/vnode={vnode}/partial.bin`; a per-epoch `_COMMIT` marker
/// records which instance committed the epoch and is used for split-brain
/// detection (see `verify_commit_marker`).
pub struct ObjectStoreBackend {
    // Underlying store (local FS, S3, ...); cloned into spawned HEAD tasks.
    store: Arc<dyn ObjectStore>,
    // Identity of this instance; compared against commit-marker contents to
    // detect a foreign committer.
    instance_id: String,
    // `instance_id` pre-encoded once; reused as the commit-marker payload.
    committer_bytes: Bytes,
    // Exclusive upper bound for valid vnode indices (valid: 0..capacity).
    vnode_capacity: u32,
    // Monotonic fencing token; `write_partial` rejects callers carrying an
    // older assignment version once this is non-zero. Shared via
    // `authoritative_version_handle` so external coordinators can bump it.
    authoritative_version: Arc<AtomicU64>,
}
38
39impl std::fmt::Debug for ObjectStoreBackend {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("ObjectStoreBackend")
42 .field("instance_id", &self.instance_id)
43 .field("vnode_capacity", &self.vnode_capacity)
44 .finish_non_exhaustive()
45 }
46}
47
48impl ObjectStoreBackend {
49 #[must_use]
51 pub fn new(
52 store: Arc<dyn ObjectStore>,
53 instance_id: impl Into<String>,
54 vnode_capacity: u32,
55 ) -> Self {
56 let instance_id = instance_id.into();
57 let committer_bytes = Bytes::from(instance_id.clone().into_bytes());
58 Self {
59 store,
60 instance_id,
61 committer_bytes,
62 vnode_capacity,
63 authoritative_version: Arc::new(AtomicU64::new(0)),
64 }
65 }
66
67 #[must_use]
69 pub fn vnode_capacity(&self) -> u32 {
70 self.vnode_capacity
71 }
72
73 #[must_use]
78 pub fn authoritative_version_handle(&self) -> Arc<AtomicU64> {
79 Arc::clone(&self.authoritative_version)
80 }
81
82 fn check_vnode(&self, v: u32) -> Result<(), StateBackendError> {
83 if v >= self.vnode_capacity {
84 Err(StateBackendError::Io(format!(
85 "vnode {v} out of range (capacity {})",
86 self.vnode_capacity
87 )))
88 } else {
89 Ok(())
90 }
91 }
92
93 fn partial_path(epoch: u64, vnode: u32) -> OsPath {
94 OsPath::from(format!("epoch={epoch}/vnode={vnode}/partial.bin"))
95 }
96
97 fn commit_path(epoch: u64) -> OsPath {
98 OsPath::from(format!("epoch={epoch}/_COMMIT"))
99 }
100}
101
102#[async_trait]
103impl StateBackend for ObjectStoreBackend {
104 async fn write_partial(
105 &self,
106 vnode: u32,
107 epoch: u64,
108 assignment_version: u64,
109 bytes: Bytes,
110 ) -> Result<(), StateBackendError> {
111 self.check_vnode(vnode)?;
112 let authoritative = self.authoritative_version.load(Ordering::Acquire);
118 if authoritative > 0 && assignment_version < authoritative {
119 return Err(StateBackendError::StaleVersion {
120 caller: assignment_version,
121 authoritative,
122 });
123 }
124 let path = Self::partial_path(epoch, vnode);
125 self.store
126 .put(&path, PutPayload::from(bytes))
127 .await
128 .map_err(|e| StateBackendError::Io(e.to_string()))?;
129 Ok(())
130 }
131
132 async fn read_partial(
133 &self,
134 vnode: u32,
135 epoch: u64,
136 ) -> Result<Option<Bytes>, StateBackendError> {
137 self.check_vnode(vnode)?;
138 let path = Self::partial_path(epoch, vnode);
139 match self.store.get(&path).await {
140 Ok(res) => {
141 let b = res
142 .bytes()
143 .await
144 .map_err(|e| StateBackendError::Io(e.to_string()))?;
145 Ok(Some(b))
146 }
147 Err(object_store::Error::NotFound { .. }) => Ok(None),
148 Err(e) => Err(StateBackendError::Io(e.to_string())),
149 }
150 }
151
152 async fn epoch_complete(&self, epoch: u64, vnodes: &[u32]) -> Result<bool, StateBackendError> {
153 let commit = Self::commit_path(epoch);
154 match self.store.head(&commit).await {
159 Ok(_) => return self.verify_commit_marker(&commit).await,
160 Err(object_store::Error::NotFound { .. }) => {}
161 Err(e) => return Err(StateBackendError::Io(e.to_string())),
162 }
163
164 let mut set = tokio::task::JoinSet::new();
169 for &v in vnodes {
170 self.check_vnode(v)?;
171 let store = Arc::clone(&self.store);
172 let path = Self::partial_path(epoch, v);
173 set.spawn(async move { store.head(&path).await });
174 }
175 while let Some(joined) = set.join_next().await {
176 match joined {
177 Ok(Ok(_)) => {}
178 Ok(Err(object_store::Error::NotFound { .. })) => {
179 set.abort_all();
180 return Ok(false);
181 }
182 Ok(Err(e)) => {
183 set.abort_all();
184 return Err(StateBackendError::Io(e.to_string()));
185 }
186 Err(join_err) => {
187 set.abort_all();
188 return Err(StateBackendError::Io(format!(
189 "epoch_complete HEAD task failed: {join_err}"
190 )));
191 }
192 }
193 }
194
195 let payload = PutPayload::from(self.committer_bytes.clone());
197 let opts = PutOptions {
198 mode: PutMode::Create,
199 ..Default::default()
200 };
201 match self.store.put_opts(&commit, payload, opts).await {
202 Ok(_) => Ok(true),
203 Err(object_store::Error::AlreadyExists { .. }) => {
207 self.verify_commit_marker(&commit).await
208 }
209 Err(e) => Err(StateBackendError::Io(e.to_string())),
210 }
211 }
212
213 async fn prune_before(&self, before: u64) -> Result<(), StateBackendError> {
214 use tokio_stream::StreamExt;
215
216 let mut entries = self.store.list(None);
217 let mut victims: Vec<OsPath> = Vec::new();
218 while let Some(entry) = entries.next().await {
219 let entry = entry.map_err(|e| StateBackendError::Io(e.to_string()))?;
220 let loc = entry.location.as_ref();
221 let first = loc.split('/').next().unwrap_or("");
224 let Some(rest) = first.strip_prefix("epoch=") else {
225 continue;
226 };
227 let Ok(epoch) = rest.parse::<u64>() else {
228 continue;
229 };
230 if epoch < before {
231 victims.push(entry.location);
232 }
233 }
234
235 for victim in victims {
240 match self.store.delete(&victim).await {
241 Ok(()) | Err(object_store::Error::NotFound { .. }) => {}
242 Err(e) => tracing::warn!(error = %e, "state backend prune: delete failed"),
243 }
244 }
245 Ok(())
246 }
247
248 fn set_authoritative_version(&self, version: u64) {
249 let mut cur = self.authoritative_version.load(Ordering::Acquire);
251 while version > cur {
252 match self.authoritative_version.compare_exchange(
253 cur,
254 version,
255 Ordering::AcqRel,
256 Ordering::Acquire,
257 ) {
258 Ok(_) => return,
259 Err(observed) => cur = observed,
260 }
261 }
262 }
263
264 fn authoritative_version(&self) -> u64 {
265 self.authoritative_version.load(Ordering::Acquire)
266 }
267}
268
269impl ObjectStoreBackend {
270 async fn verify_commit_marker(&self, commit: &OsPath) -> Result<bool, StateBackendError> {
276 let res = self
277 .store
278 .get(commit)
279 .await
280 .map_err(|e| StateBackendError::Io(e.to_string()))?;
281 let bytes = res
282 .bytes()
283 .await
284 .map_err(|e| StateBackendError::Io(e.to_string()))?;
285 let committer = std::str::from_utf8(&bytes).map_err(|e| {
286 StateBackendError::Serialization(format!("commit marker not utf8: {e}"))
287 })?;
288 if committer == self.instance_id.as_str() {
289 Ok(true)
290 } else {
291 Err(StateBackendError::SplitBrainCommit {
292 committer: committer.to_string(),
293 self_id: self.instance_id.clone(),
294 })
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use object_store::local::LocalFileSystem;
    use tempfile::tempdir;

    // Builds a filesystem-backed ObjectStore rooted at `dir`.
    fn make_store(dir: &std::path::Path) -> Arc<dyn ObjectStore> {
        Arc::new(LocalFileSystem::new_with_prefix(dir).unwrap())
    }

    // A written partial can be read back byte-for-byte.
    #[tokio::test]
    async fn write_read_roundtrip() {
        let dir = tempdir().unwrap();
        let backend = ObjectStoreBackend::new(make_store(dir.path()), "node-0", 4);
        backend
            .write_partial(0, 1, 0, Bytes::from_static(b"hello"))
            .await
            .unwrap();
        let got = backend.read_partial(0, 1).await.unwrap().unwrap();
        assert_eq!(&got[..], b"hello");
    }

    // epoch_complete is false until every vnode has a partial, then commits
    // and stays true on a repeat call (idempotent for the committer).
    #[tokio::test]
    async fn epoch_complete_cas_commit() {
        let dir = tempdir().unwrap();
        let backend = ObjectStoreBackend::new(make_store(dir.path()), "node-0", 4);
        let vnodes = [0u32, 1, 2];

        assert!(!backend.epoch_complete(1, &vnodes).await.unwrap());
        for v in &vnodes {
            backend
                .write_partial(*v, 1, 0, Bytes::from_static(b"y"))
                .await
                .unwrap();
        }
        assert!(backend.epoch_complete(1, &vnodes).await.unwrap());
        // Second call hits the existing marker and verifies it was ours.
        assert!(backend.epoch_complete(1, &vnodes).await.unwrap());
    }

    // When a different instance committed first (marker found by HEAD),
    // epoch_complete must surface SplitBrainCommit — not silently succeed.
    #[tokio::test]
    async fn epoch_complete_detects_split_brain_committer() {
        let dir = tempdir().unwrap();
        let store = make_store(dir.path());
        let winner = ObjectStoreBackend::new(Arc::clone(&store), "winner", 4);
        let loser = ObjectStoreBackend::new(Arc::clone(&store), "loser", 4);

        let vnodes = [0u32, 1];
        for v in &vnodes {
            winner
                .write_partial(*v, 7, 0, Bytes::from_static(b"w"))
                .await
                .unwrap();
        }

        assert!(winner.epoch_complete(7, &vnodes).await.unwrap());

        let err = loser.epoch_complete(7, &vnodes).await.unwrap_err();
        match err {
            StateBackendError::SplitBrainCommit { committer, self_id } => {
                assert_eq!(committer, "winner");
                assert_eq!(self_id, "loser");
            }
            other => panic!("expected SplitBrainCommit, got {other:?}"),
        }

        // The winning instance remains unaffected by the loser's attempt.
        assert!(winner.epoch_complete(7, &vnodes).await.unwrap());
    }

    // Same split-brain detection, but exercised via the CAS-loser path:
    // the marker is planted directly so the loser's put_opts(Create) fails
    // with AlreadyExists and falls back to marker verification.
    #[tokio::test]
    async fn epoch_complete_detects_split_brain_on_cas_loser_path() {
        let dir = tempdir().unwrap();
        let store = make_store(dir.path());
        let winner = ObjectStoreBackend::new(Arc::clone(&store), "winner", 4);
        let loser = ObjectStoreBackend::new(Arc::clone(&store), "loser", 4);

        let vnodes = [0u32, 1];
        for v in &vnodes {
            winner
                .write_partial(*v, 3, 0, Bytes::from_static(b"w"))
                .await
                .unwrap();
        }
        // Plant the winner's marker out-of-band (simulates losing the race).
        let commit = ObjectStoreBackend::commit_path(3);
        store
            .put(&commit, PutPayload::from(Bytes::from_static(b"winner")))
            .await
            .unwrap();

        let err = loser.epoch_complete(3, &vnodes).await.unwrap_err();
        assert!(matches!(
            err,
            StateBackendError::SplitBrainCommit { ref committer, .. }
            if committer == "winner"
        ));
    }

    // Fencing: once a backend learns a non-zero authoritative version, an
    // older assignment_version is rejected; version 0 stays unfenced.
    #[tokio::test]
    async fn stale_version_rejected() {
        let dir = tempdir().unwrap();
        let store = make_store(dir.path());
        let stale = ObjectStoreBackend::new(Arc::clone(&store), "node-stale", 4);
        let fresh = ObjectStoreBackend::new(Arc::clone(&store), "node-fresh", 4);

        fresh.set_authoritative_version(2);

        // Matching version passes the fence.
        fresh
            .write_partial(0, 1, 2, Bytes::from_static(b"fresh"))
            .await
            .unwrap();

        stale.set_authoritative_version(2);
        // Caller claims version 1 against authoritative 2 — rejected.
        let err = stale
            .write_partial(0, 1, 1, Bytes::from_static(b"stale"))
            .await
            .unwrap_err();
        match err {
            StateBackendError::StaleVersion {
                caller,
                authoritative,
            } => {
                assert_eq!(caller, 1);
                assert_eq!(authoritative, 2);
            }
            other => panic!("expected StaleVersion, got {other:?}"),
        }

        // A backend that never learned a version (still 0) is unfenced.
        let unfenced = ObjectStoreBackend::new(Arc::clone(&store), "node-unfenced", 4);
        unfenced
            .write_partial(1, 1, 0, Bytes::from_static(b"ok"))
            .await
            .unwrap();
    }

    // set_authoritative_version only moves forward; lower values are ignored.
    #[test]
    fn authoritative_version_is_monotonic() {
        let dir = tempdir().unwrap();
        let b = ObjectStoreBackend::new(make_store(dir.path()), "node", 2);
        assert_eq!(b.authoritative_version(), 0);
        b.set_authoritative_version(3);
        assert_eq!(b.authoritative_version(), 3);
        // Regression: attempting to lower the version must be a no-op.
        b.set_authoritative_version(1);
        assert_eq!(b.authoritative_version(), 3);
        b.set_authoritative_version(4);
        assert_eq!(b.authoritative_version(), 4);
    }

    // Compile-time check that the backend is usable as Arc<dyn StateBackend>.
    #[tokio::test]
    async fn object_safe_behind_arc() {
        let dir = tempdir().unwrap();
        let _: Arc<dyn StateBackend> =
            Arc::new(ObjectStoreBackend::new(make_store(dir.path()), "node-0", 2));
    }

    // prune_before(4) removes epochs 1..=3 and retains 4..=5.
    #[tokio::test]
    async fn prune_before_deletes_old_epochs() {
        let dir = tempdir().unwrap();
        let backend = ObjectStoreBackend::new(make_store(dir.path()), "node-0", 4);

        for epoch in 1..=5u64 {
            backend
                .write_partial(0, epoch, 0, Bytes::from_static(b"x"))
                .await
                .unwrap();
        }

        backend.prune_before(4).await.unwrap();

        for epoch in 1..=3 {
            assert!(
                backend.read_partial(0, epoch).await.unwrap().is_none(),
                "epoch {epoch} should be pruned",
            );
        }
        for epoch in 4..=5 {
            assert!(
                backend.read_partial(0, epoch).await.unwrap().is_some(),
                "epoch {epoch} should be retained",
            );
        }
    }
}
511}