laminar_db/ai/cache.rs
//! Result cache for AI inference, keyed `(content_hash, model_id, task,
//! params_version)`.
3//!
4//! The key versions on both the model and its parameters, so a local model and
5//! a remote model — or the same model under different parameters — never
6//! collide on the same input text. Local results are deterministic and cacheable
7//! permanently for correctness; remote results are cached as a cost-saver. The
8//! cache itself does not distinguish the two — that policy lives in the caller.
9//!
10//! The cache is an in-memory [`quick_cache::sync::Cache`] with S3-FIFO-style
11//! eviction, the same crate the lookup and schema-registry caches use. A lookup
12//! is a memory op: cheap enough to gate the inference worker from the operator
13//! without doing the model call inline.
14
15use quick_cache::sync::{Cache, DefaultLifecycle};
16use quick_cache::{DefaultHashBuilder, Weighter};
17
18use crate::ai::provider::InferenceParams;
19use crate::ai::registry::Task;
20
/// Cache key. All fields are `Copy`, so lookups need no allocation and no
/// borrowed-key indirection.
///
/// - `content_hash`: xxh3-128 of the input text (see [`content_hash`]).
/// - `model_id`: a stable per-model integer assigned by the caller (the
///   registry), distinguishing models without hashing their names on lookup.
/// - `task`: the task the result was produced for, so one model serving
///   several tasks keeps a separate entry per task.
/// - `params_version`: a hash of the request parameters (see [`params_version`]).
///
/// Two keys are equal only when all four fields are equal (derived
/// `PartialEq`/`Eq`/`Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AiCacheKey {
    /// xxh3-128 hash of the input content.
    pub content_hash: u128,
    /// Stable integer id of the model.
    pub model_id: u32,
    /// The task — a model can serve several (e.g. classify and sentiment), and
    /// they produce different outputs for the same input, so it must key the
    /// cache or results would collide across tasks.
    pub task: Task,
    /// Hash of the request parameters that affect the output.
    pub params_version: u64,
}
41
/// One row's cached inference output. Mirrors the per-row shape of
/// [`crate::ai::provider::InferenceOutputs`] but singular, since the cache is keyed
/// per row of input.
///
/// Only `PartialEq` is derived (not `Eq`): two of the variants carry
/// floating-point payloads, which do not implement `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub enum CachedOutput {
    /// A text output (label, completion, summary, …).
    Text(String),
    /// A numeric vector output (embedding).
    Vector(Vec<f32>),
    /// A scalar score output (`ai_sentiment`, continuous in `[-1, 1]`).
    Score(f64),
}
54
55/// xxh3-128 of the input content. Not cryptographic; a fast, collision-negligible
56/// key for a result cache.
57#[must_use]
58pub fn content_hash(input: &str) -> u128 {
59 xxhash_rust::xxh3::xxh3_128(input.as_bytes())
60}
61
62/// A stable hash of the parameters that change a model's output for the same
63/// input — currently the candidate label set. Computed once per batch (all rows
64/// in a batch share parameters), not per row.
65///
66/// Maintenance contract: every field of [`InferenceParams`] that can change the
67/// output must be folded in here. It is hashed field-by-field rather than via a
68/// derived `Hash` because parameters added later (e.g. an `f32` temperature)
69/// are not `Hash`; fold those in explicitly (e.g. `to_bits()`).
70#[must_use]
71pub fn params_version(params: &InferenceParams) -> u64 {
72 use std::hash::{Hash, Hasher};
73 let mut hasher = xxhash_rust::xxh3::Xxh3::new();
74 params.labels.hash(&mut hasher);
75 hasher.finish()
76}
77
/// Tunables for [`AiResultCache`].
#[derive(Debug, Clone, Copy)]
pub struct AiResultCacheConfig {
    /// Memory budget in bytes. Entries are weighted by payload size, so this
    /// bounds memory directly — an entry count would not, since an embedding
    /// vector is orders of magnitude larger than a one-word label.
    pub capacity_bytes: usize,
}

impl Default for AiResultCacheConfig {
    /// Default budget: 64 MiB.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        const DEFAULT_CAPACITY_BYTES: usize = 64 * MIB;

        Self {
            capacity_bytes: DEFAULT_CAPACITY_BYTES,
        }
    }
}
94
95/// Weight of one cache entry: its payload bytes plus fixed key/bookkeeping
96/// overhead, so tiny entries still count against the budget.
97#[derive(Debug, Clone)]
98struct OutputWeighter;
99
100impl Weighter<AiCacheKey, CachedOutput> for OutputWeighter {
101 fn weight(&self, _key: &AiCacheKey, value: &CachedOutput) -> u64 {
102 let payload = match value {
103 CachedOutput::Text(s) => s.len(),
104 CachedOutput::Vector(v) => v.len() * std::mem::size_of::<f32>(),
105 CachedOutput::Score(_) => std::mem::size_of::<f64>(),
106 };
107 (payload + std::mem::size_of::<AiCacheKey>() + 32) as u64
108 }
109}
110
/// `quick_cache`-backed in-memory cache of per-row inference results.
///
/// `quick_cache::sync::Cache` is internally sharded with per-shard locks held
/// only for the duration of a map operation, so [`AiResultCache`] is
/// `Send + Sync`.
pub struct AiResultCache {
    // Underlying store; entries are weighted by `OutputWeighter`.
    cache: Cache<AiCacheKey, CachedOutput, OutputWeighter>,
}
119
120impl AiResultCache {
121 /// Create a cache with the given configuration.
122 #[must_use]
123 pub fn new(config: AiResultCacheConfig) -> Self {
124 // Estimated entry count only sizes internal tables; within an order
125 // of magnitude is fine. Assume ~256 B/entry (labels are small,
126 // embeddings large; the byte weighter enforces the real bound).
127 let estimated_items = (config.capacity_bytes / 256).max(64);
128 let cache = Cache::with(
129 estimated_items,
130 config.capacity_bytes as u64,
131 OutputWeighter,
132 DefaultHashBuilder::default(),
133 DefaultLifecycle::default(),
134 );
135 Self { cache }
136 }
137
138 /// Create a cache with default configuration.
139 #[must_use]
140 pub fn with_defaults() -> Self {
141 Self::new(AiResultCacheConfig::default())
142 }
143
144 /// Look up a cached result.
145 #[must_use]
146 pub fn get(&self, key: &AiCacheKey) -> Option<CachedOutput> {
147 self.cache.get(key)
148 }
149
150 /// Insert or update a cached result.
151 pub fn insert(&self, key: AiCacheKey, value: CachedOutput) {
152 self.cache.insert(key, value);
153 }
154
155 /// Number of entries currently cached.
156 #[must_use]
157 pub fn len(&self) -> usize {
158 self.cache.len()
159 }
160
161 /// Whether the cache is empty.
162 #[must_use]
163 pub fn is_empty(&self) -> bool {
164 self.len() == 0
165 }
166}
167
168impl std::fmt::Debug for AiResultCache {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 f.debug_struct("AiResultCache")
171 .field("len", &self.len())
172 .finish()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a key for `content` under `model_id` with the given label set.
    /// The task is fixed to `Task::Sentiment`; these tests vary only the
    /// model and the parameters.
    fn key(content: &str, model_id: u32, labels: Option<Vec<String>>) -> AiCacheKey {
        let params = InferenceParams { labels };
        AiCacheKey {
            content_hash: content_hash(content),
            model_id,
            task: Task::Sentiment,
            params_version: params_version(&params),
        }
    }

    #[test]
    fn params_version_separates_label_sets() {
        let a = InferenceParams {
            labels: Some(vec!["pos".into(), "neg".into()]),
        };
        let b = InferenceParams {
            labels: Some(vec!["pos".into(), "neg".into(), "neutral".into()]),
        };
        // Deterministic for identical parameters…
        assert_eq!(params_version(&a), params_version(&a));
        // …and distinct for a different label set or for no labels at all.
        assert_ne!(params_version(&a), params_version(&b));
        assert_ne!(
            params_version(&a),
            params_version(&InferenceParams::default())
        );
    }

    #[test]
    fn same_text_different_model_does_not_collide() {
        let cache = AiResultCache::with_defaults();
        let finbert = key("flat quarter", 1, None);
        let remote = key("flat quarter", 2, None);
        cache.insert(finbert, CachedOutput::Text("neutral".into()));
        cache.insert(remote, CachedOutput::Text("negative".into()));
        assert_eq!(
            cache.get(&finbert),
            Some(CachedOutput::Text("neutral".into()))
        );
        assert_eq!(
            cache.get(&remote),
            Some(CachedOutput::Text("negative".into()))
        );
    }
}
222}