1#![allow(clippy::disallowed_types)] use io_uring::types::Fd;
8use io_uring::{opcode, IoUring};
9use std::collections::{HashSet, VecDeque};
10use std::os::fd::RawFd;
11
12use super::error::IoUringError;
13
/// Pool of fixed-size buffers registered with the kernel via
/// `io_uring` `register_buffers`, for use with `ReadFixed`/`WriteFixed`
/// zero-copy operations.
pub struct RegisteredBufferPool {
    // Backing storage; each inner Vec is one registered buffer. The Vecs must
    // never be reallocated after registration — the kernel holds their addresses.
    buffers: Vec<Vec<u8>>,
    // Size in bytes of every buffer in the pool.
    buffer_size: usize,
    // Fixed-buffer indices currently available for acquisition.
    free_list: VecDeque<u16>,
    // Monotonic counter used to mint unique SQE `user_data` values.
    next_id: u64,
    // Total number of registered buffers (== buffers.len()).
    total_count: usize,
    // Buffers currently handed out via acquire()/try_acquire().
    acquired_count: usize,
    // Lifetime count of successful acquisitions (stats).
    acquisitions: u64,
    // Lifetime count of acquisition failures on an empty free list (stats).
    exhaustions: u64,
    // Indices with a submitted operation whose CQE has not yet been observed.
    in_flight: HashSet<u16>,
}
57
58impl RegisteredBufferPool {
59 pub fn new(
71 ring: &mut IoUring,
72 buffer_size: usize,
73 buffer_count: usize,
74 ) -> Result<Self, IoUringError> {
75 if buffer_count > u16::MAX as usize {
76 return Err(IoUringError::InvalidConfig(format!(
77 "buffer_count {} exceeds maximum {}",
78 buffer_count,
79 u16::MAX
80 )));
81 }
82
83 let buffers: Vec<Vec<u8>> = (0..buffer_count)
85 .map(|_| {
86 vec![0; buffer_size]
88 })
89 .collect();
90
91 let iovecs: Vec<libc::iovec> = buffers
93 .iter()
94 .map(|buf| libc::iovec {
95 iov_base: buf.as_ptr() as *mut _,
96 iov_len: buf.len(),
97 })
98 .collect();
99
100 unsafe {
104 ring.submitter()
105 .register_buffers(&iovecs)
106 .map_err(IoUringError::BufferRegistration)?;
107 }
108
109 #[allow(clippy::cast_possible_truncation)]
110 let free_list = (0..buffer_count as u16).collect();
111
112 Ok(Self {
113 buffers,
114 buffer_size,
115 free_list,
116 next_id: 0,
117 total_count: buffer_count,
118 acquired_count: 0,
119 acquisitions: 0,
120 exhaustions: 0,
121 in_flight: HashSet::new(),
122 })
123 }
124
125 pub fn acquire(&mut self) -> Result<(u16, &mut [u8]), IoUringError> {
133 let Some(idx) = self.free_list.pop_front() else {
134 self.exhaustions += 1;
135 return Err(IoUringError::BufferPoolExhausted);
136 };
137
138 self.acquired_count += 1;
139 self.acquisitions += 1;
140
141 Ok((idx, &mut self.buffers[idx as usize]))
142 }
143
144 #[must_use]
148 pub fn try_acquire(&mut self) -> Option<(u16, &mut [u8])> {
149 let Some(idx) = self.free_list.pop_front() else {
150 self.exhaustions += 1;
151 return None;
152 };
153 self.acquired_count += 1;
154 self.acquisitions += 1;
155 Some((idx, &mut self.buffers[idx as usize]))
156 }
157
158 pub fn release(&mut self, buf_index: u16) {
168 debug_assert!(
169 (buf_index as usize) < self.total_count,
170 "Invalid buffer index"
171 );
172 debug_assert!(
173 !self.in_flight.contains(&buf_index),
174 "Releasing buffer {buf_index} that is still in-flight — \
175 wait for the CQE before releasing"
176 );
177 self.free_list.push_back(buf_index);
178 self.acquired_count = self.acquired_count.saturating_sub(1);
179 }
180
181 pub fn mark_in_flight(&mut self, buf_index: u16) {
186 self.in_flight.insert(buf_index);
187 }
188
189 pub fn complete_in_flight(&mut self, buf_index: u16) {
191 self.in_flight.remove(&buf_index);
192 }
193
194 pub fn get(&self, buf_index: u16) -> Result<&[u8], IoUringError> {
200 self.buffers
201 .get(buf_index as usize)
202 .map(Vec::as_slice)
203 .ok_or(IoUringError::InvalidBufferIndex(buf_index))
204 }
205
206 pub fn get_mut(&mut self, buf_index: u16) -> Result<&mut [u8], IoUringError> {
212 self.buffers
213 .get_mut(buf_index as usize)
214 .map(Vec::as_mut_slice)
215 .ok_or(IoUringError::InvalidBufferIndex(buf_index))
216 }
217
218 pub fn submit_read_fixed(
238 &mut self,
239 ring: &mut IoUring,
240 fd: RawFd,
241 buf_index: u16,
242 offset: u64,
243 len: u32,
244 ) -> Result<u64, IoUringError> {
245 let user_data = self.next_user_data();
247
248 let buf = self
249 .buffers
250 .get_mut(buf_index as usize)
251 .ok_or(IoUringError::InvalidBufferIndex(buf_index))?;
252
253 let entry = opcode::ReadFixed::new(Fd(fd), buf.as_mut_ptr(), len, buf_index)
254 .offset(offset)
255 .build()
256 .user_data(user_data);
257
258 unsafe {
262 ring.submission()
263 .push(&entry)
264 .map_err(|_| IoUringError::SubmissionQueueFull)?;
265 }
266
267 self.mark_in_flight(buf_index);
268 Ok(user_data)
269 }
270
271 pub fn submit_write_fixed(
291 &mut self,
292 ring: &mut IoUring,
293 fd: RawFd,
294 buf_index: u16,
295 offset: u64,
296 len: u32,
297 ) -> Result<u64, IoUringError> {
298 let user_data = self.next_user_data();
299
300 let buf = self
301 .buffers
302 .get(buf_index as usize)
303 .ok_or(IoUringError::InvalidBufferIndex(buf_index))?;
304
305 let entry = opcode::WriteFixed::new(Fd(fd), buf.as_ptr(), len, buf_index)
306 .offset(offset)
307 .build()
308 .user_data(user_data);
309
310 unsafe {
314 ring.submission()
315 .push(&entry)
316 .map_err(|_| IoUringError::SubmissionQueueFull)?;
317 }
318
319 self.mark_in_flight(buf_index);
320 Ok(user_data)
321 }
322
323 #[must_use]
325 pub const fn buffer_size(&self) -> usize {
326 self.buffer_size
327 }
328
329 #[must_use]
331 pub const fn total_count(&self) -> usize {
332 self.total_count
333 }
334
335 #[must_use]
337 pub fn available_count(&self) -> usize {
338 self.free_list.len()
339 }
340
341 #[must_use]
343 pub const fn acquired_count(&self) -> usize {
344 self.acquired_count
345 }
346
347 #[must_use]
349 pub fn is_exhausted(&self) -> bool {
350 self.free_list.is_empty()
351 }
352
353 fn next_user_data(&mut self) -> u64 {
355 let id = self.next_id;
356 self.next_id += 1;
357 id
358 }
359
360 #[must_use]
362 pub fn stats(&self) -> BufferPoolStats {
363 BufferPoolStats {
364 total_count: self.total_count,
365 available_count: self.free_list.len(),
366 acquired_count: self.acquired_count,
367 buffer_size: self.buffer_size,
368 total_bytes: self.total_count * self.buffer_size,
369 acquisitions: self.acquisitions,
370 exhaustions: self.exhaustions,
371 }
372 }
373}
374
375impl Drop for RegisteredBufferPool {
376 fn drop(&mut self) {
377 debug_assert!(
384 self.acquired_count == 0,
385 "RegisteredBufferPool dropped with {} buffers still acquired — \
386 possible in-flight I/O operations referencing freed memory",
387 self.acquired_count,
388 );
389 }
390}
391
/// Point-in-time snapshot of a [`RegisteredBufferPool`]'s state and
/// lifetime counters, as produced by `RegisteredBufferPool::stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BufferPoolStats {
    /// Total number of registered buffers.
    pub total_count: usize,
    /// Buffers currently free for acquisition.
    pub available_count: usize,
    /// Buffers currently handed out.
    pub acquired_count: usize,
    /// Size in bytes of each buffer.
    pub buffer_size: usize,
    /// `total_count * buffer_size` — total registered memory in bytes.
    pub total_bytes: usize,
    /// Lifetime count of successful acquisitions.
    pub acquisitions: u64,
    /// Lifetime count of acquisition attempts that found the pool empty.
    pub exhaustions: u64,
}
410
411impl std::fmt::Display for BufferPoolStats {
412 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413 write!(
414 f,
415 "BufferPool({} available/{} total, {}KB each, {}MB total)",
416 self.available_count,
417 self.total_count,
418 self.buffer_size / 1024,
419 self.total_bytes / (1024 * 1024)
420 )
421 }
422}
423
#[cfg(test)]
#[allow(
    clippy::manual_let_else,
    clippy::single_match_else,
    clippy::items_after_statements
)]
mod tests {
    use super::*;
    use std::fs::OpenOptions;
    use std::io::Write;
    use tempfile::tempdir;

    // Builds a small ring, or None when io_uring is unavailable (e.g. old
    // kernel or sandboxed CI). Tests skip themselves when this returns None.
    fn create_test_ring() -> Option<IoUring> {
        IoUring::builder().build(32).ok()
    }

    // Construction registers the buffers and initializes all counters.
    #[test]
    fn test_buffer_pool_creation() {
        let mut ring = match create_test_ring() {
            Some(r) => r,
            None => {
                tracing::warn!("io_uring not available, skipping test");
                return;
            }
        };

        let pool = RegisteredBufferPool::new(&mut ring, 4096, 16);
        match pool {
            Ok(p) => {
                assert_eq!(p.total_count(), 16);
                assert_eq!(p.buffer_size(), 4096);
                assert_eq!(p.available_count(), 16);
                assert_eq!(p.acquired_count(), 0);
            }
            Err(e) => {
                // Registration can fail (e.g. RLIMIT_MEMLOCK); treat as a skip.
                tracing::warn!("Buffer registration failed: {e}");
            }
        }
    }

    // Exercises the full acquire → exhaust → release → re-acquire cycle.
    #[test]
    fn test_acquire_release() {
        let mut ring = match create_test_ring() {
            Some(r) => r,
            None => return,
        };

        let mut pool = match RegisteredBufferPool::new(&mut ring, 4096, 4) {
            Ok(p) => p,
            Err(_) => return,
        };

        // Drain the pool completely.
        let mut indices = Vec::new();
        for _ in 0..4 {
            let (idx, _buf) = pool.acquire().unwrap();
            indices.push(idx);
        }

        assert_eq!(pool.available_count(), 0);
        assert_eq!(pool.acquired_count(), 4);
        assert!(pool.is_exhausted());

        // Acquiring from an exhausted pool must fail.
        assert!(pool.acquire().is_err());

        pool.release(indices[0]);
        assert_eq!(pool.available_count(), 1);
        assert!(!pool.is_exhausted());

        // The released index is the only free one, so it comes back first.
        let (idx, _) = pool.acquire().unwrap();
        assert_eq!(idx, indices[0]);

        for &i in &indices[1..] {
            pool.release(i);
        }
        pool.release(idx);
    }

    // Data written through the acquired slice is visible through get().
    #[test]
    fn test_buffer_access() {
        let mut ring = match create_test_ring() {
            Some(r) => r,
            None => return,
        };

        let mut pool = match RegisteredBufferPool::new(&mut ring, 1024, 4) {
            Ok(p) => p,
            Err(_) => return,
        };

        let (idx, buf) = pool.acquire().unwrap();
        buf[..5].copy_from_slice(b"hello");

        pool.release(idx);

        // release() does not zero the buffer; contents persist.
        let data = pool.get(idx).unwrap();
        assert_eq!(&data[..5], b"hello");
    }

    // stats() and its Display impl on a fresh pool.
    #[test]
    fn test_stats() {
        let mut ring = match create_test_ring() {
            Some(r) => r,
            None => return,
        };

        let pool = match RegisteredBufferPool::new(&mut ring, 4096, 16) {
            Ok(p) => p,
            Err(_) => return,
        };

        let stats = pool.stats();
        assert_eq!(stats.total_count, 16);
        assert_eq!(stats.available_count, 16);
        assert_eq!(stats.acquired_count, 0);
        assert_eq!(stats.buffer_size, 4096);
        assert_eq!(stats.total_bytes, 16 * 4096);
        assert_eq!(stats.acquisitions, 0);
        assert_eq!(stats.exhaustions, 0);

        let display = format!("{stats}");
        assert!(display.contains("16"));
        assert!(display.contains("4KB"));
    }

    // Counters start at zero on a fresh pool.
    #[test]
    fn test_buffer_pool_stats_initial() {
        let mut ring = match create_test_ring() {
            Some(r) => r,
            None => return,
        };

        let pool = match RegisteredBufferPool::new(&mut ring, 4096, 8) {
            Ok(p) => p,
            Err(_) => return,
        };

        let stats = pool.stats();
        assert_eq!(stats.total_count, 8);
        assert_eq!(stats.available_count, 8);
        assert_eq!(stats.acquired_count, 0);
        assert_eq!(stats.acquisitions, 0);
        assert_eq!(stats.exhaustions, 0);
    }

    // acquisitions is a lifetime counter: it keeps growing across
    // release/re-acquire cycles.
    #[test]
    fn test_buffer_pool_stats_after_acquire() {
        let mut ring = match create_test_ring() {
            Some(r) => r,
            None => return,
        };

        let mut pool = match RegisteredBufferPool::new(&mut ring, 4096, 4) {
            Ok(p) => p,
            Err(_) => return,
        };

        let (idx, _) = pool.acquire().unwrap();
        let stats = pool.stats();
        assert_eq!(stats.acquisitions, 1);
        assert_eq!(stats.exhaustions, 0);
        assert_eq!(stats.acquired_count, 1);
        assert_eq!(stats.available_count, 3);

        pool.release(idx);
        let (a, _) = pool.acquire().unwrap();
        let (b, _) = pool.acquire().unwrap();
        let stats = pool.stats();
        assert_eq!(stats.acquisitions, 3);
        assert_eq!(stats.exhaustions, 0);

        pool.release(a);
        pool.release(b);
    }

    // Both acquire() and try_acquire() bump the exhaustion counter on failure.
    #[test]
    fn test_buffer_pool_exhaustion_counter() {
        let mut ring = match create_test_ring() {
            Some(r) => r,
            None => return,
        };

        let mut pool = match RegisteredBufferPool::new(&mut ring, 4096, 2) {
            Ok(p) => p,
            Err(_) => return,
        };

        let _ = pool.acquire().unwrap();
        let _ = pool.acquire().unwrap();

        assert!(pool.acquire().is_err());
        assert!(pool.try_acquire().is_none());

        let stats = pool.stats();
        assert_eq!(stats.acquisitions, 2);
        assert_eq!(stats.exhaustions, 2);
        assert_eq!(stats.acquired_count, 2);
        assert_eq!(stats.available_count, 0);

        // Indices are handed out in order, so the two acquired above are 0 and 1.
        pool.release(0);
        pool.release(1);
    }

    // End-to-end: submit a WriteFixed against a real temp file and reap its CQE.
    #[test]
    fn test_write_read_fixed() {
        let mut ring = match create_test_ring() {
            Some(r) => r,
            None => return,
        };

        let mut pool = match RegisteredBufferPool::new(&mut ring, 4096, 4) {
            Ok(p) => p,
            Err(_) => return,
        };

        let dir = tempdir().unwrap();
        let path = dir.path().join("test.dat");
        let mut file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(true)
            .open(&path)
            .unwrap();

        // Pre-size the file so the fixed write at offset 0 has room.
        file.write_all(&[0u8; 4096]).unwrap();
        file.flush().unwrap();
        drop(file);

        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(&path)
            .unwrap();
        use std::os::unix::io::AsRawFd;
        let fd = file.as_raw_fd();

        let (idx, buf) = pool.acquire().unwrap();
        buf[..5].copy_from_slice(b"hello");

        let user_data = pool.submit_write_fixed(&mut ring, fd, idx, 0, 5).unwrap();

        ring.submit_and_wait(1).unwrap();

        // The CQE must carry back our user_data and a non-negative result.
        let mut cq = ring.completion();
        let cqe = cq.next().unwrap();
        assert_eq!(cqe.user_data(), user_data);
        assert!(cqe.result() >= 0);
        drop(cq);

        // Completion must be acknowledged before release() accepts the buffer.
        pool.complete_in_flight(idx);
        pool.release(idx);
    }
}
691}