Skip to main content

laminar_db/ai/
cache.rs

//! Result cache for AI inference, keyed `(content_hash, model_id, task,
//! params_version)`.
3//!
//! The key covers the model, the task, and the request parameters, so a local
//! model and a remote model — or the same model under a different task or
//! different parameters — never collide on the same input text. Local results
//! are deterministic, so caching them indefinitely is safe; remote results are
//! cached as a cost-saver. The cache itself does not distinguish the two — that
//! policy lives in the caller.
9//!
10//! The cache is an in-memory [`quick_cache::sync::Cache`] with S3-FIFO-style
11//! eviction, the same crate the lookup and schema-registry caches use. A lookup
12//! is a memory op: cheap enough to gate the inference worker from the operator
13//! without doing the model call inline.
14
15use quick_cache::sync::{Cache, DefaultLifecycle};
16use quick_cache::{DefaultHashBuilder, Weighter};
17
18use crate::ai::provider::InferenceParams;
19use crate::ai::registry::Task;
20
/// Cache key. All fields are `Copy`, so lookups need no allocation and no
/// borrowed-key indirection.
///
/// - `content_hash`: xxh3-128 of the input text (see [`content_hash`]).
/// - `model_id`: a stable per-model integer assigned by the caller (the
///   registry), distinguishing models without hashing their names on lookup.
/// - `task`: the task the result was produced for; the same model and input
///   can yield different outputs under different tasks.
/// - `params_version`: a hash of the request parameters (see [`params_version`]).
///
/// Two keys are equal only when all four fields match (derived `PartialEq` /
/// `Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AiCacheKey {
    /// xxh3-128 hash of the input content.
    pub content_hash: u128,
    /// Stable integer id of the model.
    pub model_id: u32,
    /// The task — a model can serve several (e.g. classify and sentiment), and
    /// they produce different outputs for the same input, so it must key the
    /// cache or results would collide across tasks.
    pub task: Task,
    /// Hash of the request parameters that affect the output.
    pub params_version: u64,
}
41
/// One row's cached inference output. Mirrors the per-row shape of
/// [`crate::ai::provider::InferenceOutputs`] but singular, since the cache is keyed
/// per row of input.
///
/// Derives `PartialEq` but not `Eq`: the `Vector` and `Score` variants carry
/// floating-point payloads, which do not implement `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub enum CachedOutput {
    /// A text output (label, completion, summary, …).
    Text(String),
    /// A numeric vector output (embedding).
    Vector(Vec<f32>),
    /// A scalar score output (`ai_sentiment`, continuous in `[-1, 1]`).
    Score(f64),
}
54
55/// xxh3-128 of the input content. Not cryptographic; a fast, collision-negligible
56/// key for a result cache.
57#[must_use]
58pub fn content_hash(input: &str) -> u128 {
59    xxhash_rust::xxh3::xxh3_128(input.as_bytes())
60}
61
62/// A stable hash of the parameters that change a model's output for the same
63/// input — currently the candidate label set. Computed once per batch (all rows
64/// in a batch share parameters), not per row.
65///
66/// Maintenance contract: every field of [`InferenceParams`] that can change the
67/// output must be folded in here. It is hashed field-by-field rather than via a
68/// derived `Hash` because parameters added later (e.g. an `f32` temperature)
69/// are not `Hash`; fold those in explicitly (e.g. `to_bits()`).
70#[must_use]
71pub fn params_version(params: &InferenceParams) -> u64 {
72    use std::hash::{Hash, Hasher};
73    let mut hasher = xxhash_rust::xxh3::Xxh3::new();
74    params.labels.hash(&mut hasher);
75    hasher.finish()
76}
77
/// Tunables for [`AiResultCache`].
#[derive(Debug, Clone, Copy)]
pub struct AiResultCacheConfig {
    /// Upper bound on the cache's total weight, in bytes. Because each entry
    /// is weighted by the size of its payload, this limits memory directly; a
    /// plain entry count would not, as an embedding vector is orders of
    /// magnitude larger than a one-word label.
    pub capacity_bytes: usize,
}

impl Default for AiResultCacheConfig {
    /// The default budget is 64 MiB.
    fn default() -> Self {
        const DEFAULT_CAPACITY_BYTES: usize = 64 << 20;
        Self {
            capacity_bytes: DEFAULT_CAPACITY_BYTES,
        }
    }
}
94
95/// Weight of one cache entry: its payload bytes plus fixed key/bookkeeping
96/// overhead, so tiny entries still count against the budget.
97#[derive(Debug, Clone)]
98struct OutputWeighter;
99
100impl Weighter<AiCacheKey, CachedOutput> for OutputWeighter {
101    fn weight(&self, _key: &AiCacheKey, value: &CachedOutput) -> u64 {
102        let payload = match value {
103            CachedOutput::Text(s) => s.len(),
104            CachedOutput::Vector(v) => v.len() * std::mem::size_of::<f32>(),
105            CachedOutput::Score(_) => std::mem::size_of::<f64>(),
106        };
107        (payload + std::mem::size_of::<AiCacheKey>() + 32) as u64
108    }
109}
110
/// `quick_cache`-backed in-memory cache of per-row inference results.
///
/// Maps an [`AiCacheKey`] to the [`CachedOutput`] produced for it.
///
/// `quick_cache::sync::Cache` is internally sharded with per-shard locks held
/// only for the duration of a map operation, so [`AiResultCache`] is
/// `Send + Sync`.
pub struct AiResultCache {
    // Underlying cache; entries are weighted by `OutputWeighter`.
    cache: Cache<AiCacheKey, CachedOutput, OutputWeighter>,
}
119
120impl AiResultCache {
121    /// Create a cache with the given configuration.
122    #[must_use]
123    pub fn new(config: AiResultCacheConfig) -> Self {
124        // Estimated entry count only sizes internal tables; within an order
125        // of magnitude is fine. Assume ~256 B/entry (labels are small,
126        // embeddings large; the byte weighter enforces the real bound).
127        let estimated_items = (config.capacity_bytes / 256).max(64);
128        let cache = Cache::with(
129            estimated_items,
130            config.capacity_bytes as u64,
131            OutputWeighter,
132            DefaultHashBuilder::default(),
133            DefaultLifecycle::default(),
134        );
135        Self { cache }
136    }
137
138    /// Create a cache with default configuration.
139    #[must_use]
140    pub fn with_defaults() -> Self {
141        Self::new(AiResultCacheConfig::default())
142    }
143
144    /// Look up a cached result.
145    #[must_use]
146    pub fn get(&self, key: &AiCacheKey) -> Option<CachedOutput> {
147        self.cache.get(key)
148    }
149
150    /// Insert or update a cached result.
151    pub fn insert(&self, key: AiCacheKey, value: CachedOutput) {
152        self.cache.insert(key, value);
153    }
154
155    /// Number of entries currently cached.
156    #[must_use]
157    pub fn len(&self) -> usize {
158        self.cache.len()
159    }
160
161    /// Whether the cache is empty.
162    #[must_use]
163    pub fn is_empty(&self) -> bool {
164        self.len() == 0
165    }
166}
167
168impl std::fmt::Debug for AiResultCache {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        f.debug_struct("AiResultCache")
171            .field("len", &self.len())
172            .finish()
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a sentiment-task key for `content` under `model_id`, with the
    /// parameter hash derived from `labels`.
    fn key(content: &str, model_id: u32, labels: Option<Vec<String>>) -> AiCacheKey {
        let version = params_version(&InferenceParams { labels });
        AiCacheKey {
            content_hash: content_hash(content),
            model_id,
            task: Task::Sentiment,
            params_version: version,
        }
    }

    /// The parameter hash is repeatable, and differs between distinct label
    /// sets as well as between a label set and the default parameters.
    #[test]
    fn params_version_separates_label_sets() {
        let with_labels = |names: &[&str]| InferenceParams {
            labels: Some(names.iter().map(|name| String::from(*name)).collect()),
        };
        let two = with_labels(&["pos", "neg"]);
        let three = with_labels(&["pos", "neg", "neutral"]);

        let two_version = params_version(&two);
        assert_eq!(two_version, params_version(&two));
        assert_ne!(two_version, params_version(&three));
        assert_ne!(two_version, params_version(&InferenceParams::default()));
    }

    /// Identical text cached under two model ids keeps two independent results.
    #[test]
    fn same_text_different_model_does_not_collide() {
        let cache = AiResultCache::with_defaults();
        let text = "flat quarter";
        let finbert = key(text, 1, None);
        let remote = key(text, 2, None);

        cache.insert(finbert, CachedOutput::Text("neutral".into()));
        cache.insert(remote, CachedOutput::Text("negative".into()));

        let expected_local = Some(CachedOutput::Text("neutral".into()));
        let expected_remote = Some(CachedOutput::Text("negative".into()));
        assert_eq!(cache.get(&finbert), expected_local);
        assert_eq!(cache.get(&remote), expected_remote);
    }
}