Skip to main content

laminar_ai/
cache.rs

//! Result cache for AI inference, keyed `(content_hash, model_id, task,
//! params_version)`.
//!
//! The key versions on the model, the task, and the request parameters, so a
//! local model and a remote model — or the same model under a different task
//! or different parameters — never collide on the same input text. Local
//! results are deterministic and cacheable permanently for correctness; remote
//! results are cached as a cost-saver. The cache itself does not distinguish
//! the two — that policy lives in the caller.
//!
//! The cache is an in-memory [`foyer::Cache`] with S3-FIFO eviction, the same
//! crate the lookup and schema-registry caches use. A lookup is a memory op:
//! cheap enough to gate the inference worker from the operator without doing the
//! model call inline.
14
15use std::sync::atomic::{AtomicU64, Ordering};
16
17use foyer::{Cache, CacheBuilder};
18
19use crate::provider::InferenceParams;
20use crate::registry::Task;
21
/// Cache key. All fields are `Copy`, so lookups need no allocation and no
/// borrowed-key indirection.
///
/// - `content_hash`: xxh3-128 of the input text (see [`content_hash`]).
/// - `model_id`: a stable per-model integer assigned by the caller (the
///   registry), distinguishing models without hashing their names on lookup.
/// - `task`: the task being served; one model can serve several tasks whose
///   outputs differ for the same input, so the task is part of the key.
/// - `params_version`: a hash of the request parameters (see [`params_version`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AiCacheKey {
    /// xxh3-128 hash of the input content.
    pub content_hash: u128,
    /// Stable integer id of the model.
    pub model_id: u32,
    /// The task — a model can serve several (e.g. classify and sentiment), and
    /// they produce different outputs for the same input, so it must key the
    /// cache or results would collide across tasks.
    pub task: Task,
    /// Hash of the request parameters that affect the output.
    pub params_version: u64,
}
42
/// One row's cached inference output. Mirrors the per-row shape of
/// [`crate::provider::InferenceOutputs`] but singular, since the cache is keyed
/// per row of input.
///
/// The variant determines how much an entry counts against the cache budget:
/// text by its byte length, vectors by element count, scores by a fixed size.
#[derive(Debug, Clone, PartialEq)]
pub enum CachedOutput {
    /// A text output (label, completion, summary, …).
    Text(String),
    /// A numeric vector output (embedding).
    Vector(Vec<f32>),
    /// A scalar score output (`ai_sentiment`, continuous in `[-1, 1]`).
    Score(f64),
}
55
56/// xxh3-128 of the input content. Not cryptographic; a fast, collision-negligible
57/// key for a result cache.
58#[must_use]
59pub fn content_hash(input: &str) -> u128 {
60    xxhash_rust::xxh3::xxh3_128(input.as_bytes())
61}
62
63/// A stable hash of the parameters that change a model's output for the same
64/// input — currently the candidate label set. Computed once per batch (all rows
65/// in a batch share parameters), not per row.
66///
67/// Maintenance contract: every field of [`InferenceParams`] that can change the
68/// output must be folded in here. It is hashed field-by-field rather than via a
69/// derived `Hash` because parameters added later (e.g. an `f32` temperature)
70/// are not `Hash`; fold those in explicitly (e.g. `to_bits()`).
71#[must_use]
72pub fn params_version(params: &InferenceParams) -> u64 {
73    use std::hash::{Hash, Hasher};
74    let mut hasher = xxhash_rust::xxh3::Xxh3::new();
75    params.labels.hash(&mut hasher);
76    hasher.finish()
77}
78
/// Tunables for constructing an [`AiResultCache`].
#[derive(Debug, Clone, Copy)]
pub struct AiResultCacheConfig {
    /// Memory budget in bytes. Entries are weighted by payload size, so this
    /// bounds memory directly — an entry count would not, since an embedding
    /// vector is orders of magnitude larger than a one-word label.
    pub capacity_bytes: usize,
    /// Number of shards for concurrent access (power of 2).
    pub shards: usize,
}

impl Default for AiResultCacheConfig {
    /// A 64 MiB budget spread over 16 shards.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        const DEFAULT_CAPACITY_BYTES: usize = 64 * MIB;
        const DEFAULT_SHARDS: usize = 16;

        Self {
            shards: DEFAULT_SHARDS,
            capacity_bytes: DEFAULT_CAPACITY_BYTES,
        }
    }
}
98
99/// Weight of one cache entry: its payload bytes plus fixed key/bookkeeping
100/// overhead, so tiny entries still count against the budget.
101fn entry_weight(_key: &AiCacheKey, value: &CachedOutput) -> usize {
102    let payload = match value {
103        CachedOutput::Text(s) => s.len(),
104        CachedOutput::Vector(v) => v.len() * std::mem::size_of::<f32>(),
105        CachedOutput::Score(_) => std::mem::size_of::<f64>(),
106    };
107    payload + std::mem::size_of::<AiCacheKey>() + 32
108}
109
/// foyer-backed in-memory cache of per-row inference results.
///
/// `foyer::Cache` is internally sharded and lock-free on the read path, so
/// [`AiResultCache`] is `Send + Sync`.
///
/// Besides the cache itself, it keeps two atomic tallies of lookup outcomes
/// so callers can observe hit and miss counts.
pub struct AiResultCache {
    /// Underlying store mapping each [`AiCacheKey`] to its [`CachedOutput`].
    cache: Cache<AiCacheKey, CachedOutput>,
    /// Number of lookups that found an entry.
    hits: AtomicU64,
    /// Number of lookups that found nothing.
    misses: AtomicU64,
}
119
120impl AiResultCache {
121    /// Create a cache with the given configuration.
122    #[must_use]
123    pub fn new(config: AiResultCacheConfig) -> Self {
124        let cache = CacheBuilder::new(config.capacity_bytes)
125            .with_shards(config.shards)
126            .with_weighter(entry_weight)
127            .build();
128        Self {
129            cache,
130            hits: AtomicU64::new(0),
131            misses: AtomicU64::new(0),
132        }
133    }
134
135    /// Create a cache with default configuration.
136    #[must_use]
137    pub fn with_defaults() -> Self {
138        Self::new(AiResultCacheConfig::default())
139    }
140
141    /// Look up a cached result, recording a hit or miss.
142    #[must_use]
143    pub fn get(&self, key: &AiCacheKey) -> Option<CachedOutput> {
144        if let Some(entry) = self.cache.get(key) {
145            self.hits.fetch_add(1, Ordering::Relaxed);
146            Some(entry.value().clone())
147        } else {
148            self.misses.fetch_add(1, Ordering::Relaxed);
149            None
150        }
151    }
152
153    /// Insert or update a cached result.
154    pub fn insert(&self, key: AiCacheKey, value: CachedOutput) {
155        self.cache.insert(key, value);
156    }
157
158    /// Total cache hits since creation.
159    #[must_use]
160    pub fn hit_count(&self) -> u64 {
161        self.hits.load(Ordering::Relaxed)
162    }
163
164    /// Total cache misses since creation.
165    #[must_use]
166    pub fn miss_count(&self) -> u64 {
167        self.misses.load(Ordering::Relaxed)
168    }
169
170    /// Number of entries currently cached.
171    #[must_use]
172    pub fn len(&self) -> usize {
173        self.cache.entries()
174    }
175
176    /// Whether the cache is empty.
177    #[must_use]
178    pub fn is_empty(&self) -> bool {
179        self.len() == 0
180    }
181}
182
183impl std::fmt::Debug for AiResultCache {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct("AiResultCache")
186            .field("len", &self.len())
187            .field("hits", &self.hit_count())
188            .field("misses", &self.miss_count())
189            .finish()
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a sentiment-task key for `content` under `model_id`, with the
    /// parameter version derived from `labels`.
    fn key(content: &str, model_id: u32, labels: Option<Vec<String>>) -> AiCacheKey {
        let version = params_version(&InferenceParams { labels });
        AiCacheKey {
            content_hash: content_hash(content),
            model_id,
            task: Task::Sentiment,
            params_version: version,
        }
    }

    /// The parameter hash is repeatable and differs between label sets,
    /// including between a label set and the default parameters.
    #[test]
    fn params_version_separates_label_sets() {
        let two_labels = InferenceParams {
            labels: Some(vec!["pos".into(), "neg".into()]),
        };
        let three_labels = InferenceParams {
            labels: Some(vec!["pos".into(), "neg".into(), "neutral".into()]),
        };
        let defaults = InferenceParams::default();

        let two = params_version(&two_labels);
        assert_eq!(two, params_version(&two_labels));
        assert_ne!(two, params_version(&three_labels));
        assert_ne!(two, params_version(&defaults));
    }

    /// Two models caching the same text keep separate entries, and each
    /// successful lookup counts as a hit.
    #[test]
    fn same_text_different_model_does_not_collide() {
        let cache = AiResultCache::with_defaults();
        let expectations = [
            (key("flat quarter", 1, None), "neutral"),
            (key("flat quarter", 2, None), "negative"),
        ];

        for (cache_key, label) in &expectations {
            cache.insert(*cache_key, CachedOutput::Text((*label).into()));
        }
        for (cache_key, label) in &expectations {
            assert_eq!(
                cache.get(cache_key),
                Some(CachedOutput::Text((*label).into()))
            );
        }
        assert_eq!(cache.hit_count(), 2);
    }
}