1use std::sync::atomic::{AtomicU64, Ordering};
16
17use foyer::{Cache, CacheBuilder};
18
19use crate::provider::InferenceParams;
20use crate::registry::Task;
21
/// Lookup key for a cached inference result.
///
/// Equality and hashing are derived over all four fields, so two keys match
/// only when the content hash, model, task, and parameter fingerprint all agree.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AiCacheKey {
    /// 128-bit hash of the input content (see [`content_hash`]).
    pub content_hash: u128,
    /// Numeric identifier of the model the result belongs to.
    pub model_id: u32,
    /// Task the result was computed for.
    pub task: Task,
    /// Fingerprint of the inference parameters (see [`params_version`]).
    pub params_version: u64,
}
42
/// Value stored in the cache, in one of three shapes.
#[derive(Debug, Clone, PartialEq)]
pub enum CachedOutput {
    /// Textual output (the tests in this file store labels such as `"neutral"`).
    Text(String),
    /// A vector of `f32` values.
    // NOTE(review): presumably an embedding — confirm against the producers.
    Vector(Vec<f32>),
    /// A single `f64` score.
    Score(f64),
}
55
56#[must_use]
59pub fn content_hash(input: &str) -> u128 {
60 xxhash_rust::xxh3::xxh3_128(input.as_bytes())
61}
62
63#[must_use]
72pub fn params_version(params: &InferenceParams) -> u64 {
73 use std::hash::{Hash, Hasher};
74 let mut hasher = xxhash_rust::xxh3::Xxh3::new();
75 params.labels.hash(&mut hasher);
76 hasher.finish()
77}
78
/// Sizing options for [`AiResultCache`].
#[derive(Debug, Clone, Copy)]
pub struct AiResultCacheConfig {
    /// Total capacity handed to the cache builder. Entries are weighed by
    /// `entry_weight`, which counts bytes, so this acts as a byte budget.
    pub capacity_bytes: usize,
    /// Number of shards passed to the cache builder.
    pub shards: usize,
}
89
90impl Default for AiResultCacheConfig {
91 fn default() -> Self {
92 Self {
93 capacity_bytes: 64 * 1024 * 1024,
94 shards: 16,
95 }
96 }
97}
98
99fn entry_weight(_key: &AiCacheKey, value: &CachedOutput) -> usize {
102 let payload = match value {
103 CachedOutput::Text(s) => s.len(),
104 CachedOutput::Vector(v) => v.len() * std::mem::size_of::<f32>(),
105 CachedOutput::Score(_) => std::mem::size_of::<f64>(),
106 };
107 payload + std::mem::size_of::<AiCacheKey>() + 32
108}
109
/// Cache of inference results keyed by [`AiCacheKey`], with hit/miss counters.
pub struct AiResultCache {
    /// Underlying `foyer` cache holding the stored outputs.
    cache: Cache<AiCacheKey, CachedOutput>,
    /// Number of `get` calls that found an entry.
    hits: AtomicU64,
    /// Number of `get` calls that found nothing.
    misses: AtomicU64,
}
119
120impl AiResultCache {
121 #[must_use]
123 pub fn new(config: AiResultCacheConfig) -> Self {
124 let cache = CacheBuilder::new(config.capacity_bytes)
125 .with_shards(config.shards)
126 .with_weighter(entry_weight)
127 .build();
128 Self {
129 cache,
130 hits: AtomicU64::new(0),
131 misses: AtomicU64::new(0),
132 }
133 }
134
135 #[must_use]
137 pub fn with_defaults() -> Self {
138 Self::new(AiResultCacheConfig::default())
139 }
140
141 #[must_use]
143 pub fn get(&self, key: &AiCacheKey) -> Option<CachedOutput> {
144 if let Some(entry) = self.cache.get(key) {
145 self.hits.fetch_add(1, Ordering::Relaxed);
146 Some(entry.value().clone())
147 } else {
148 self.misses.fetch_add(1, Ordering::Relaxed);
149 None
150 }
151 }
152
153 pub fn insert(&self, key: AiCacheKey, value: CachedOutput) {
155 self.cache.insert(key, value);
156 }
157
158 #[must_use]
160 pub fn hit_count(&self) -> u64 {
161 self.hits.load(Ordering::Relaxed)
162 }
163
164 #[must_use]
166 pub fn miss_count(&self) -> u64 {
167 self.misses.load(Ordering::Relaxed)
168 }
169
170 #[must_use]
172 pub fn len(&self) -> usize {
173 self.cache.entries()
174 }
175
176 #[must_use]
178 pub fn is_empty(&self) -> bool {
179 self.len() == 0
180 }
181}
182
183impl std::fmt::Debug for AiResultCache {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 f.debug_struct("AiResultCache")
186 .field("len", &self.len())
187 .field("hits", &self.hit_count())
188 .field("misses", &self.miss_count())
189 .finish()
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Task::Sentiment` key from `content`, `model_id`, and `labels`.
    fn key(content: &str, model_id: u32, labels: Option<Vec<String>>) -> AiCacheKey {
        let params = InferenceParams { labels };
        AiCacheKey {
            content_hash: content_hash(content),
            model_id,
            task: Task::Sentiment,
            params_version: params_version(&params),
        }
    }

    /// The fingerprint is stable for one label set and differs between label
    /// sets, including between a label set and the default parameters.
    #[test]
    fn params_version_separates_label_sets() {
        let a = InferenceParams {
            labels: Some(vec!["pos".into(), "neg".into()]),
        };
        let b = InferenceParams {
            labels: Some(vec!["pos".into(), "neg".into(), "neutral".into()]),
        };
        assert_eq!(params_version(&a), params_version(&a));
        assert_ne!(params_version(&a), params_version(&b));
        assert_ne!(
            params_version(&a),
            params_version(&InferenceParams::default())
        );
    }

    /// Identical text cached under two model ids yields two independent entries.
    #[test]
    fn same_text_different_model_does_not_collide() {
        let cache = AiResultCache::with_defaults();
        let finbert = key("flat quarter", 1, None);
        let remote = key("flat quarter", 2, None);
        cache.insert(finbert, CachedOutput::Text("neutral".into()));
        cache.insert(remote, CachedOutput::Text("negative".into()));
        assert_eq!(
            cache.get(&finbert),
            Some(CachedOutput::Text("neutral".into()))
        );
        assert_eq!(
            cache.get(&remote),
            Some(CachedOutput::Text("negative".into()))
        );
        assert_eq!(cache.hit_count(), 2);
    }

    /// Looking up a key that was never inserted returns `None` and bumps only
    /// the miss counter.
    #[test]
    fn absent_key_counts_as_miss() {
        let cache = AiResultCache::with_defaults();
        let absent = key("never inserted", 1, None);
        assert_eq!(cache.get(&absent), None);
        assert_eq!(cache.miss_count(), 1);
        assert_eq!(cache.hit_count(), 0);
    }
}