Skip to main content

laminar_ai/backends/
rate_limited.rs

//! Client-side rate limiting for remote providers.
//!
//! Wraps any [`InferenceProvider`] in a `governor` rate limiter (GCRA, which
//! behaves like a token bucket; `governor` is the de-facto standard Rust
//! limiter) so request bursts are shaped to a steady requests-per-second rate
//! rather than sent unbounded. One cell is acquired per input row — and at least
//! one per call, so an empty batch still counts — before the batch is
//! dispatched. The wait happens on the Ring 1 inference worker (the only place
//! `infer_batch` is awaited), never on Ring 0. The limiter itself is lock-free
//! (atomics), so nothing blocking is reachable from the compute thread.
9
10use std::num::NonZeroU32;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
15
16use crate::provider::{InferenceProvider, InferenceRequest, InferenceResponse, ProviderError};
17
18/// An [`InferenceProvider`] that paces calls to a steady rate.
19pub struct RateLimitedProvider {
20    inner: Arc<dyn InferenceProvider>,
21    limiter: DefaultDirectRateLimiter,
22}
23
24impl RateLimitedProvider {
25    /// Wrap `inner`, limiting to `requests_per_second` (burst up to the same).
26    #[must_use]
27    pub fn new(inner: Arc<dyn InferenceProvider>, requests_per_second: NonZeroU32) -> Self {
28        Self {
29            inner,
30            limiter: RateLimiter::direct(Quota::per_second(requests_per_second)),
31        }
32    }
33}
34
35#[async_trait]
36impl InferenceProvider for RateLimitedProvider {
37    async fn infer_batch(
38        &self,
39        request: InferenceRequest,
40    ) -> Result<InferenceResponse, ProviderError> {
41        // One permit per row: the batch waits until the bucket has paced it.
42        for _ in 0..request.inputs.len().max(1) {
43            self.limiter.until_ready().await;
44        }
45        self.inner.infer_batch(request).await
46    }
47
48    fn name(&self) -> &'static str {
49        self.inner.name()
50    }
51
52    fn intrinsic_labels(&self, model: &str) -> Option<Vec<String>> {
53        self.inner.intrinsic_labels(model)
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{InferenceOutputs, InferenceParams, Usage};
    use crate::registry::Task;
    use std::time::{Duration, Instant};

    /// Inner provider that echoes its inputs back as text outputs.
    struct Echo;

    #[async_trait]
    impl InferenceProvider for Echo {
        async fn infer_batch(
            &self,
            request: InferenceRequest,
        ) -> Result<InferenceResponse, ProviderError> {
            Ok(InferenceResponse {
                outputs: InferenceOutputs::Text(request.inputs),
                usage: Usage::ZERO,
            })
        }
        fn name(&self) -> &'static str {
            "echo"
        }
    }

    /// A sentiment request carrying `rows` identical inputs.
    fn request(rows: usize) -> InferenceRequest {
        InferenceRequest {
            task: Task::Sentiment,
            model: "m".into(),
            inputs: vec!["x".to_string(); rows],
            params: InferenceParams::default(),
        }
    }

    /// `NonZeroU32` helper so test rates read as plain numbers.
    fn rate(per_second: u32) -> NonZeroU32 {
        NonZeroU32::new(per_second).expect("test rates are non-zero")
    }

    /// A burst beyond the per-second budget is delayed, not dropped or sent
    /// unbounded; the output still passes through unchanged.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn burst_beyond_the_rate_is_paced() {
        // 100 rps → one cell every 10 ms, burst 100. 130 rows = 30 over budget,
        // so ideal pacing takes ~300 ms. The bound is 200 ms to tolerate timer
        // jitter while still failing if the burst is sent unpaced (~0 ms).
        let p = RateLimitedProvider::new(Arc::new(Echo), rate(100));
        let start = Instant::now();
        let resp = p.infer_batch(request(130)).await.unwrap();
        let elapsed = start.elapsed();
        assert_eq!(resp.outputs.len(), 130);
        assert!(
            elapsed >= Duration::from_millis(200),
            "burst was not paced: {elapsed:?}"
        );
    }

    /// Each row draws exactly one cell. A batch the size of the burst drains
    /// the bucket; one row fewer leaves a cell behind.
    #[tokio::test]
    async fn each_row_costs_one_cell() {
        // 3 rps → a cell replenishes only every ~333 ms, far longer than this
        // test runs, so `check()` observes the bucket as the batch left it.
        let drained = RateLimitedProvider::new(Arc::new(Echo), rate(3));
        drained.infer_batch(request(3)).await.unwrap();
        assert!(
            drained.limiter.check().is_err(),
            "3 rows should drain a burst of 3"
        );

        let partial = RateLimitedProvider::new(Arc::new(Echo), rate(3));
        partial.infer_batch(request(2)).await.unwrap();
        assert!(
            partial.limiter.check().is_ok(),
            "2 rows should leave one of 3 cells"
        );
    }

    /// An empty batch is still a call: it consumes one cell and passes through.
    #[tokio::test]
    async fn empty_batch_costs_one_cell() {
        // 1 rps, burst 1: the single cell is gone for ~1 s once consumed.
        let p = RateLimitedProvider::new(Arc::new(Echo), rate(1));
        let resp = p.infer_batch(request(0)).await.unwrap();
        assert_eq!(resp.outputs.len(), 0);
        assert!(
            p.limiter.check().is_err(),
            "empty batch should consume the only cell"
        );
    }

    /// Construction and `name()` are synchronous; no runtime is needed.
    #[test]
    fn name_delegates_to_inner() {
        let p = RateLimitedProvider::new(Arc::new(Echo), rate(1000));
        assert_eq!(p.name(), "echo");
    }
}