// laminar_core/io_uring/buffer_pool.rs

1//! Pre-registered I/O buffer pool for zero-copy operations.
2//!
3//! Registers buffers once at startup to avoid per-operation buffer mapping overhead.
4//! Uses fixed-size buffers for predictable performance.
5#![allow(clippy::disallowed_types)] // cold path: buffer pool setup (io_uring infrastructure)
6
7use io_uring::types::Fd;
8use io_uring::{opcode, IoUring};
9use std::collections::{HashSet, VecDeque};
10use std::os::fd::RawFd;
11
12use super::error::IoUringError;
13
/// A pre-registered buffer pool for `io_uring` operations.
///
/// Buffers are registered with the kernel once at creation time, eliminating
/// the per-operation buffer mapping overhead that would otherwise occur.
///
/// # Example
///
/// ```rust,ignore
/// use laminar_core::io_uring::RegisteredBufferPool;
///
/// let mut pool = RegisteredBufferPool::new(&mut ring, 64 * 1024, 256)?;
///
/// // Acquire a buffer
/// let (idx, buf) = pool.acquire()?;
/// buf[..5].copy_from_slice(b"hello");
///
/// // Use the buffer for I/O
/// pool.submit_write_fixed(fd, idx, 0, 5)?;
///
/// // After completion, release the buffer
/// pool.release(idx);
/// ```
pub struct RegisteredBufferPool {
    /// Pre-allocated buffers. Their addresses were handed to the kernel at
    /// registration time, so these Vecs must never be resized or reallocated.
    buffers: Vec<Vec<u8>>,
    /// Size of each buffer in bytes.
    buffer_size: usize,
    /// Indices of buffers not currently handed out (served FIFO).
    free_list: VecDeque<u16>,
    /// Next `user_data` ID for tracking operations in completions.
    next_id: u64,
    /// Total buffers in pool (fixed at construction).
    total_count: usize,
    /// Buffers currently acquired (handed out and not yet released).
    acquired_count: usize,
    /// Cumulative number of successful buffer acquisitions.
    acquisitions: u64,
    /// Cumulative number of times acquisition failed (pool exhausted).
    exhaustions: u64,
    /// Buffers currently submitted to `io_uring` (debug-only tracking).
    /// Used to detect use-after-free: releasing a buffer that is still in-flight.
    in_flight: HashSet<u16>,
}
57
58impl RegisteredBufferPool {
59    /// Create a new buffer pool and register it with the kernel.
60    ///
61    /// # Arguments
62    ///
63    /// * `ring` - The `io_uring` instance to register buffers with
64    /// * `buffer_size` - Size of each buffer in bytes
65    /// * `buffer_count` - Number of buffers to allocate
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if buffer registration fails.
70    pub fn new(
71        ring: &mut IoUring,
72        buffer_size: usize,
73        buffer_count: usize,
74    ) -> Result<Self, IoUringError> {
75        if buffer_count > u16::MAX as usize {
76            return Err(IoUringError::InvalidConfig(format!(
77                "buffer_count {} exceeds maximum {}",
78                buffer_count,
79                u16::MAX
80            )));
81        }
82
83        // Allocate aligned buffers
84        let buffers: Vec<Vec<u8>> = (0..buffer_count)
85            .map(|_| {
86                // Allocate with page alignment for O_DIRECT compatibility
87                vec![0; buffer_size]
88            })
89            .collect();
90
91        // Create iovec slice for registration
92        let iovecs: Vec<libc::iovec> = buffers
93            .iter()
94            .map(|buf| libc::iovec {
95                iov_base: buf.as_ptr() as *mut _,
96                iov_len: buf.len(),
97            })
98            .collect();
99
100        // Register buffers with the kernel
101        // SAFETY: The iovecs point to valid, owned memory that will outlive the registration.
102        // The buffers Vec owns the memory and is stored in the struct.
103        unsafe {
104            ring.submitter()
105                .register_buffers(&iovecs)
106                .map_err(IoUringError::BufferRegistration)?;
107        }
108
109        #[allow(clippy::cast_possible_truncation)]
110        let free_list = (0..buffer_count as u16).collect();
111
112        Ok(Self {
113            buffers,
114            buffer_size,
115            free_list,
116            next_id: 0,
117            total_count: buffer_count,
118            acquired_count: 0,
119            acquisitions: 0,
120            exhaustions: 0,
121            in_flight: HashSet::new(),
122        })
123    }
124
125    /// Acquire a buffer from the pool.
126    ///
127    /// Returns the buffer index and a mutable reference to the buffer.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if no buffers are available.
132    pub fn acquire(&mut self) -> Result<(u16, &mut [u8]), IoUringError> {
133        let Some(idx) = self.free_list.pop_front() else {
134            self.exhaustions += 1;
135            return Err(IoUringError::BufferPoolExhausted);
136        };
137
138        self.acquired_count += 1;
139        self.acquisitions += 1;
140
141        Ok((idx, &mut self.buffers[idx as usize]))
142    }
143
144    /// Try to acquire a buffer without blocking.
145    ///
146    /// Returns `None` if no buffers are available.
147    #[must_use]
148    pub fn try_acquire(&mut self) -> Option<(u16, &mut [u8])> {
149        let Some(idx) = self.free_list.pop_front() else {
150            self.exhaustions += 1;
151            return None;
152        };
153        self.acquired_count += 1;
154        self.acquisitions += 1;
155        Some((idx, &mut self.buffers[idx as usize]))
156    }
157
158    /// Release a buffer back to the pool.
159    ///
160    /// The buffer must not be in-flight (submitted to `io_uring` but not yet
161    /// completed). In debug builds, this is checked via assertion.
162    ///
163    /// # Panics
164    ///
165    /// Panics in debug builds if the buffer index is invalid or the buffer
166    /// is still in-flight.
167    pub fn release(&mut self, buf_index: u16) {
168        debug_assert!(
169            (buf_index as usize) < self.total_count,
170            "Invalid buffer index"
171        );
172        debug_assert!(
173            !self.in_flight.contains(&buf_index),
174            "Releasing buffer {buf_index} that is still in-flight — \
175             wait for the CQE before releasing"
176        );
177        self.free_list.push_back(buf_index);
178        self.acquired_count = self.acquired_count.saturating_sub(1);
179    }
180
181    /// Mark a buffer as in-flight (submitted to `io_uring`).
182    ///
183    /// Call this after submitting a read/write operation that uses this buffer.
184    /// Call [`Self::complete_in_flight`] when the CQE arrives.
185    pub fn mark_in_flight(&mut self, buf_index: u16) {
186        self.in_flight.insert(buf_index);
187    }
188
189    /// Mark a buffer as no longer in-flight (CQE received).
190    pub fn complete_in_flight(&mut self, buf_index: u16) {
191        self.in_flight.remove(&buf_index);
192    }
193
194    /// Get a reference to a buffer by index.
195    ///
196    /// # Errors
197    ///
198    /// Returns an error if the index is invalid.
199    pub fn get(&self, buf_index: u16) -> Result<&[u8], IoUringError> {
200        self.buffers
201            .get(buf_index as usize)
202            .map(Vec::as_slice)
203            .ok_or(IoUringError::InvalidBufferIndex(buf_index))
204    }
205
206    /// Get a mutable reference to a buffer by index.
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if the index is invalid.
211    pub fn get_mut(&mut self, buf_index: u16) -> Result<&mut [u8], IoUringError> {
212        self.buffers
213            .get_mut(buf_index as usize)
214            .map(Vec::as_mut_slice)
215            .ok_or(IoUringError::InvalidBufferIndex(buf_index))
216    }
217
218    /// Submit a read operation using a registered buffer.
219    ///
220    /// The data will be read into the buffer at the given index.
221    ///
222    /// # Arguments
223    ///
224    /// * `ring` - The `io_uring` instance
225    /// * `fd` - File descriptor to read from
226    /// * `buf_index` - Index of the registered buffer
227    /// * `offset` - File offset to read from
228    /// * `len` - Number of bytes to read
229    ///
230    /// # Returns
231    ///
232    /// The `user_data` value that will identify this operation in completions.
233    ///
234    /// # Errors
235    ///
236    /// Returns an error if the submission queue is full.
237    pub fn submit_read_fixed(
238        &mut self,
239        ring: &mut IoUring,
240        fd: RawFd,
241        buf_index: u16,
242        offset: u64,
243        len: u32,
244    ) -> Result<u64, IoUringError> {
245        // Get user_data first to avoid borrow conflict
246        let user_data = self.next_user_data();
247
248        let buf = self
249            .buffers
250            .get_mut(buf_index as usize)
251            .ok_or(IoUringError::InvalidBufferIndex(buf_index))?;
252
253        let entry = opcode::ReadFixed::new(Fd(fd), buf.as_mut_ptr(), len, buf_index)
254            .offset(offset)
255            .build()
256            .user_data(user_data);
257
258        // SAFETY: We're submitting a valid SQE with a properly registered buffer.
259        // The buffer at buf_index was registered with this ring during pool creation
260        // and remains valid because the pool owns the backing Vec<u8>.
261        unsafe {
262            ring.submission()
263                .push(&entry)
264                .map_err(|_| IoUringError::SubmissionQueueFull)?;
265        }
266
267        self.mark_in_flight(buf_index);
268        Ok(user_data)
269    }
270
271    /// Submit a write operation using a registered buffer.
272    ///
273    /// The data in the buffer at the given index will be written.
274    ///
275    /// # Arguments
276    ///
277    /// * `ring` - The `io_uring` instance
278    /// * `fd` - File descriptor to write to
279    /// * `buf_index` - Index of the registered buffer
280    /// * `offset` - File offset to write to
281    /// * `len` - Number of bytes to write
282    ///
283    /// # Returns
284    ///
285    /// The `user_data` value that will identify this operation in completions.
286    ///
287    /// # Errors
288    ///
289    /// Returns an error if the submission queue is full.
290    pub fn submit_write_fixed(
291        &mut self,
292        ring: &mut IoUring,
293        fd: RawFd,
294        buf_index: u16,
295        offset: u64,
296        len: u32,
297    ) -> Result<u64, IoUringError> {
298        let user_data = self.next_user_data();
299
300        let buf = self
301            .buffers
302            .get(buf_index as usize)
303            .ok_or(IoUringError::InvalidBufferIndex(buf_index))?;
304
305        let entry = opcode::WriteFixed::new(Fd(fd), buf.as_ptr(), len, buf_index)
306            .offset(offset)
307            .build()
308            .user_data(user_data);
309
310        // SAFETY: We're submitting a valid SQE with a properly registered buffer.
311        // The buffer at buf_index was registered with this ring during pool creation
312        // and remains valid because the pool owns the backing Vec<u8>.
313        unsafe {
314            ring.submission()
315                .push(&entry)
316                .map_err(|_| IoUringError::SubmissionQueueFull)?;
317        }
318
319        self.mark_in_flight(buf_index);
320        Ok(user_data)
321    }
322
323    /// Get the size of each buffer.
324    #[must_use]
325    pub const fn buffer_size(&self) -> usize {
326        self.buffer_size
327    }
328
329    /// Get the total number of buffers.
330    #[must_use]
331    pub const fn total_count(&self) -> usize {
332        self.total_count
333    }
334
335    /// Get the number of available buffers.
336    #[must_use]
337    pub fn available_count(&self) -> usize {
338        self.free_list.len()
339    }
340
341    /// Get the number of acquired buffers.
342    #[must_use]
343    pub const fn acquired_count(&self) -> usize {
344        self.acquired_count
345    }
346
347    /// Check if the pool is exhausted.
348    #[must_use]
349    pub fn is_exhausted(&self) -> bool {
350        self.free_list.is_empty()
351    }
352
353    /// Generate the next `user_data` ID.
354    fn next_user_data(&mut self) -> u64 {
355        let id = self.next_id;
356        self.next_id += 1;
357        id
358    }
359
360    /// Get statistics about the buffer pool.
361    #[must_use]
362    pub fn stats(&self) -> BufferPoolStats {
363        BufferPoolStats {
364            total_count: self.total_count,
365            available_count: self.free_list.len(),
366            acquired_count: self.acquired_count,
367            buffer_size: self.buffer_size,
368            total_bytes: self.total_count * self.buffer_size,
369            acquisitions: self.acquisitions,
370            exhaustions: self.exhaustions,
371        }
372    }
373}
374
impl Drop for RegisteredBufferPool {
    fn drop(&mut self) {
        // Buffer memory is heap-allocated Vec<u8> that drops automatically.
        // `unregister_buffers()` is NOT called here because the io_uring ring
        // may already be dropped (field drop order in CoreRingManager: ring drops
        // before pool). The ring's own drop handles kernel-side cleanup.
        // The manager is responsible for draining in-flight operations before
        // dropping either the ring or the pool.
        //
        // Debug-only guard: dropping the pool while buffers are still
        // acquired suggests in-flight I/O may reference the memory freed here.
        debug_assert!(
            self.acquired_count == 0,
            "RegisteredBufferPool dropped with {} buffers still acquired — \
             possible in-flight I/O operations referencing freed memory",
            self.acquired_count,
        );
    }
}
391
/// Point-in-time snapshot of buffer pool statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BufferPoolStats {
    /// Total number of buffers in the pool.
    pub total_count: usize,
    /// Number of buffers currently free.
    pub available_count: usize,
    /// Number of buffers currently handed out.
    pub acquired_count: usize,
    /// Size of each individual buffer in bytes.
    pub buffer_size: usize,
    /// Total bytes allocated across all buffers.
    pub total_bytes: usize,
    /// Cumulative successful acquisitions.
    pub acquisitions: u64,
    /// Cumulative pool exhaustion events (acquire failed).
    pub exhaustions: u64,
}
410
411impl std::fmt::Display for BufferPoolStats {
412    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        write!(
414            f,
415            "BufferPool({} available/{} total, {}KB each, {}MB total)",
416            self.available_count,
417            self.total_count,
418            self.buffer_size / 1024,
419            self.total_bytes / (1024 * 1024)
420        )
421    }
422}
423
#[cfg(test)]
#[allow(
    clippy::manual_let_else,
    clippy::single_match_else,
    clippy::items_after_statements
)]
mod tests {
    use super::*;
    use std::fs::OpenOptions;
    use std::io::Write;
    use tempfile::tempdir;

    /// Build a small test ring, or `None` when io_uring is unavailable.
    fn create_test_ring() -> Option<IoUring> {
        IoUring::builder().build(32).ok()
    }

    #[test]
    fn test_buffer_pool_creation() {
        let Some(mut ring) = create_test_ring() else {
            tracing::warn!("io_uring not available, skipping test");
            return;
        };

        match RegisteredBufferPool::new(&mut ring, 4096, 16) {
            Ok(pool) => {
                assert_eq!(pool.total_count(), 16);
                assert_eq!(pool.buffer_size(), 4096);
                assert_eq!(pool.available_count(), 16);
                assert_eq!(pool.acquired_count(), 0);
            }
            Err(e) => {
                tracing::warn!("Buffer registration failed: {e}");
            }
        }
    }

    #[test]
    fn test_acquire_release() {
        let Some(mut ring) = create_test_ring() else {
            return;
        };
        let Ok(mut pool) = RegisteredBufferPool::new(&mut ring, 4096, 4) else {
            return;
        };

        // Drain the pool completely.
        let mut held = Vec::new();
        for _ in 0..4 {
            let (slot, _buf) = pool.acquire().unwrap();
            held.push(slot);
        }

        assert_eq!(pool.available_count(), 0);
        assert_eq!(pool.acquired_count(), 4);
        assert!(pool.is_exhausted());

        // Further acquisition must fail while exhausted.
        assert!(pool.acquire().is_err());

        // Returning a single buffer revives the pool.
        pool.release(held[0]);
        assert_eq!(pool.available_count(), 1);
        assert!(!pool.is_exhausted());

        // FIFO free list: the slot just released comes back first.
        let (slot, _) = pool.acquire().unwrap();
        assert_eq!(slot, held[0]);

        // Return everything so the Drop assertion stays quiet.
        for &i in &held[1..] {
            pool.release(i);
        }
        pool.release(slot);
    }

    #[test]
    fn test_buffer_access() {
        let Some(mut ring) = create_test_ring() else {
            return;
        };
        let Ok(mut pool) = RegisteredBufferPool::new(&mut ring, 1024, 4) else {
            return;
        };

        let (slot, buf) = pool.acquire().unwrap();
        buf[..5].copy_from_slice(b"hello");

        // Contents survive a release; `get` reads the same backing memory.
        pool.release(slot);
        let data = pool.get(slot).unwrap();
        assert_eq!(&data[..5], b"hello");
    }

    #[test]
    fn test_stats() {
        let Some(mut ring) = create_test_ring() else {
            return;
        };
        let Ok(pool) = RegisteredBufferPool::new(&mut ring, 4096, 16) else {
            return;
        };

        let stats = pool.stats();
        assert_eq!(stats.total_count, 16);
        assert_eq!(stats.available_count, 16);
        assert_eq!(stats.acquired_count, 0);
        assert_eq!(stats.buffer_size, 4096);
        assert_eq!(stats.total_bytes, 16 * 4096);
        assert_eq!(stats.acquisitions, 0);
        assert_eq!(stats.exhaustions, 0);

        let rendered = format!("{stats}");
        assert!(rendered.contains("16"));
        assert!(rendered.contains("4KB"));
    }

    #[test]
    fn test_buffer_pool_stats_initial() {
        let Some(mut ring) = create_test_ring() else {
            return;
        };
        let Ok(pool) = RegisteredBufferPool::new(&mut ring, 4096, 8) else {
            return;
        };

        let stats = pool.stats();
        assert_eq!(stats.total_count, 8);
        assert_eq!(stats.available_count, 8);
        assert_eq!(stats.acquired_count, 0);
        assert_eq!(stats.acquisitions, 0);
        assert_eq!(stats.exhaustions, 0);
    }

    #[test]
    fn test_buffer_pool_stats_after_acquire() {
        let Some(mut ring) = create_test_ring() else {
            return;
        };
        let Ok(mut pool) = RegisteredBufferPool::new(&mut ring, 4096, 4) else {
            return;
        };

        let (first, _) = pool.acquire().unwrap();
        let stats = pool.stats();
        assert_eq!(stats.acquisitions, 1);
        assert_eq!(stats.exhaustions, 0);
        assert_eq!(stats.acquired_count, 1);
        assert_eq!(stats.available_count, 3);

        pool.release(first);
        let (a, _) = pool.acquire().unwrap();
        let (b, _) = pool.acquire().unwrap();
        let stats = pool.stats();
        assert_eq!(stats.acquisitions, 3);
        assert_eq!(stats.exhaustions, 0);

        // Return everything so the Drop assertion stays quiet.
        pool.release(a);
        pool.release(b);
    }

    #[test]
    fn test_buffer_pool_exhaustion_counter() {
        let Some(mut ring) = create_test_ring() else {
            return;
        };
        let Ok(mut pool) = RegisteredBufferPool::new(&mut ring, 4096, 2) else {
            return;
        };

        // Exhaust the pool.
        let _ = pool.acquire().unwrap();
        let _ = pool.acquire().unwrap();

        // Both failure paths bump the exhaustion counter.
        assert!(pool.acquire().is_err());
        assert!(pool.try_acquire().is_none());

        let stats = pool.stats();
        assert_eq!(stats.acquisitions, 2);
        assert_eq!(stats.exhaustions, 2); // One from acquire(), one from try_acquire()
        assert_eq!(stats.acquired_count, 2);
        assert_eq!(stats.available_count, 0);

        // Return everything so the Drop assertion stays quiet.
        pool.release(0);
        pool.release(1);
    }

    #[test]
    fn test_write_read_fixed() {
        let Some(mut ring) = create_test_ring() else {
            return;
        };
        let Ok(mut pool) = RegisteredBufferPool::new(&mut ring, 4096, 4) else {
            return;
        };

        // Prepare a zero-filled temp file to write into.
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.dat");
        let mut file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(true)
            .open(&path)
            .unwrap();
        file.write_all(&[0u8; 4096]).unwrap();
        file.flush().unwrap();
        drop(file);

        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(&path)
            .unwrap();
        use std::os::unix::io::AsRawFd;
        let fd = file.as_raw_fd();

        // Stage data in a registered buffer and submit a fixed write.
        let (slot, buf) = pool.acquire().unwrap();
        buf[..5].copy_from_slice(b"hello");
        let user_data = pool.submit_write_fixed(&mut ring, fd, slot, 0, 5).unwrap();

        ring.submit_and_wait(1).unwrap();

        // The completion must match our submission and report success.
        let mut cq = ring.completion();
        let cqe = cq.next().unwrap();
        assert_eq!(cqe.user_data(), user_data);
        assert!(cqe.result() >= 0);
        drop(cq);

        // Mark buffer as no longer in-flight before releasing.
        pool.complete_in_flight(slot);
        pool.release(slot);
    }
}