laminar_ai/backends/
rate_limited.rs1use std::num::NonZeroU32;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
15
16use crate::provider::{InferenceProvider, InferenceRequest, InferenceResponse, ProviderError};
17
18pub struct RateLimitedProvider {
20 inner: Arc<dyn InferenceProvider>,
21 limiter: DefaultDirectRateLimiter,
22}
23
24impl RateLimitedProvider {
25 #[must_use]
27 pub fn new(inner: Arc<dyn InferenceProvider>, requests_per_second: NonZeroU32) -> Self {
28 Self {
29 inner,
30 limiter: RateLimiter::direct(Quota::per_second(requests_per_second)),
31 }
32 }
33}
34
35#[async_trait]
36impl InferenceProvider for RateLimitedProvider {
37 async fn infer_batch(
38 &self,
39 request: InferenceRequest,
40 ) -> Result<InferenceResponse, ProviderError> {
41 for _ in 0..request.inputs.len().max(1) {
43 self.limiter.until_ready().await;
44 }
45 self.inner.infer_batch(request).await
46 }
47
48 fn name(&self) -> &'static str {
49 self.inner.name()
50 }
51
52 fn intrinsic_labels(&self, model: &str) -> Option<Vec<String>> {
53 self.inner.intrinsic_labels(model)
54 }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{InferenceOutputs, InferenceParams, Usage};
    use crate::registry::Task;
    use std::time::{Duration, Instant};

    /// Stand-in provider that hands the request's inputs back as text outputs.
    struct Echo;

    #[async_trait]
    impl InferenceProvider for Echo {
        async fn infer_batch(
            &self,
            request: InferenceRequest,
        ) -> Result<InferenceResponse, ProviderError> {
            let outputs = InferenceOutputs::Text(request.inputs);
            Ok(InferenceResponse {
                outputs,
                usage: Usage::ZERO,
            })
        }

        fn name(&self) -> &'static str {
            "echo"
        }
    }

    /// Builds a request carrying `rows` identical one-character inputs.
    fn request(rows: usize) -> InferenceRequest {
        let inputs = vec![String::from("x"); rows];
        InferenceRequest {
            task: Task::Sentiment,
            model: "m".into(),
            inputs,
            params: InferenceParams::default(),
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn burst_beyond_the_rate_is_paced() {
        // 130 rows against a 100-per-second quota: the rows beyond the
        // limiter's initial allowance have to wait for replenishment, so the
        // call must take measurably long. The 200 ms floor is the threshold
        // this test has always asserted.
        const ROWS: usize = 130;
        const MIN_ELAPSED: Duration = Duration::from_millis(200);

        let rate = NonZeroU32::new(100).unwrap();
        let provider = RateLimitedProvider::new(Arc::new(Echo), rate);

        let started = Instant::now();
        let response = provider.infer_batch(request(ROWS)).await.unwrap();
        assert_eq!(response.outputs.len(), 130);

        let elapsed = started.elapsed();
        assert!(elapsed >= MIN_ELAPSED, "burst was not paced: {:?}", elapsed);
    }

    #[tokio::test]
    async fn name_delegates_to_inner() {
        let rate = NonZeroU32::new(1000).unwrap();
        let provider = RateLimitedProvider::new(Arc::new(Echo), rate);
        assert_eq!(provider.name(), "echo");
    }
}