1use std::hash::{Hash, Hasher};
12use std::time::{Duration, Instant};
13
14use arrow_array::RecordBatch;
15use equivalent::Equivalent;
16use quick_cache::sync::{Cache, DefaultLifecycle};
17use quick_cache::{DefaultHashBuilder, Weighter};
18
19use crate::lookup::table::LookupResult;
20
/// Owned cache key: the raw bytes of a lookup key qualified by a table id.
///
/// Field order is significant: the derived `Hash` feeds `table_id` and then
/// `key` to the hasher, which is the same order the hand-written `Hash` impl
/// on `LookupCacheKeyRef` uses.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct LookupCacheKey {
    /// Table identifier; `LookupMemoryCache` stamps its own `table_id` here
    /// when it builds a key.
    pub table_id: u32,
    /// Raw bytes of the lookup key.
    pub key: Vec<u8>,
}
32
/// Borrowed counterpart of `LookupCacheKey`: the same two fields, but the key
/// bytes are a slice rather than an owned `Vec<u8>`.
///
/// Used to probe and remove cache entries without allocating an owned key.
pub(crate) struct LookupCacheKeyRef<'a> {
    /// Table identifier, compared against `LookupCacheKey::table_id`.
    pub(crate) table_id: u32,
    /// Borrowed key bytes, compared against `LookupCacheKey::key`.
    pub(crate) key: &'a [u8],
}
42
43impl Hash for LookupCacheKeyRef<'_> {
44 fn hash<H: Hasher>(&self, state: &mut H) {
45 self.table_id.hash(state);
49 self.key.hash(state);
50 }
51}
52
53impl Equivalent<LookupCacheKey> for LookupCacheKeyRef<'_> {
54 fn equivalent(&self, other: &LookupCacheKey) -> bool {
55 self.table_id == other.table_id && self.key == other.key.as_slice()
56 }
57}
58
/// Sizing and expiry settings for `LookupMemoryCache`.
#[derive(Debug, Clone, Copy)]
pub struct LookupMemoryCacheConfig {
    /// Total weight budget handed to the underlying cache. Entries are
    /// weighed by the memory size their batch reports, so this acts as a
    /// byte bound on cached batch data.
    pub capacity_bytes: usize,
    /// Optional time-to-live. `None` disables age-based expiry; with
    /// `Some(ttl)` an entry counts as expired once at least `ttl` has
    /// elapsed since it was inserted.
    pub ttl: Option<Duration>,
}
75
76impl Default for LookupMemoryCacheConfig {
77 fn default() -> Self {
78 Self {
79 capacity_bytes: 64 * 1024 * 1024, ttl: None,
81 }
82 }
83}
84
/// Value stored in the cache: a batch together with the instant it was
/// inserted, which is what TTL expiry is measured from.
#[derive(Clone)]
struct CachedBatch {
    /// The cached lookup result.
    batch: RecordBatch,
    /// Set to `Instant::now()` when the entry is inserted.
    inserted_at: Instant,
}
92
/// Stateless weighter that charges each cache entry the memory size its
/// batch reports.
#[derive(Debug, Clone)]
struct BatchWeighter;
97
98impl Weighter<LookupCacheKey, CachedBatch> for BatchWeighter {
99 fn weight(&self, _key: &LookupCacheKey, val: &CachedBatch) -> u64 {
100 val.batch.get_array_memory_size().max(1) as u64
101 }
102}
103
/// Concrete cache type used by this module: owned `LookupCacheKey` keys,
/// timestamped `CachedBatch` values, weighed by `BatchWeighter`.
type BatchCache = Cache<LookupCacheKey, CachedBatch, BatchWeighter>;
105
/// In-memory cache of lookup results for one table, bounded by total batch
/// weight and, optionally, by a per-entry time-to-live.
pub struct LookupMemoryCache {
    /// Weighted cache holding the timestamped batches.
    cache: BatchCache,
    /// Table id stamped into every key this cache builds or probes with.
    table_id: u32,
    /// Optional entry lifetime; `None` means entries never expire by age.
    ttl: Option<Duration>,
}
122
123impl LookupMemoryCache {
124 #[must_use]
126 pub fn new(table_id: u32, config: LookupMemoryCacheConfig) -> Self {
127 let estimated_items = (config.capacity_bytes / 1024).max(64);
130 let cache = BatchCache::with(
131 estimated_items,
132 config.capacity_bytes as u64,
133 BatchWeighter,
134 DefaultHashBuilder::default(),
135 DefaultLifecycle::default(),
136 );
137
138 Self {
139 cache,
140 table_id,
141 ttl: config.ttl,
142 }
143 }
144
145 #[must_use]
147 pub fn with_defaults(table_id: u32) -> Self {
148 Self::new(table_id, LookupMemoryCacheConfig::default())
149 }
150
151 #[must_use]
153 pub fn table_id(&self) -> u32 {
154 self.table_id
155 }
156
157 fn make_key(&self, key: &[u8]) -> LookupCacheKey {
159 LookupCacheKey {
160 table_id: self.table_id,
161 key: key.to_vec(),
162 }
163 }
164
165 #[must_use]
173 pub fn get_cached(&self, key: &[u8]) -> LookupResult {
174 let ref_key = LookupCacheKeyRef {
175 table_id: self.table_id,
176 key,
177 };
178 match self.cache.get(&ref_key) {
179 Some(cached) if self.is_expired(&cached) => {
180 self.cache.remove_if(&ref_key, |v| self.is_expired(v));
181 LookupResult::NotFound
182 }
183 Some(cached) => LookupResult::Hit(cached.batch),
184 None => LookupResult::NotFound,
185 }
186 }
187
188 fn is_expired(&self, entry: &CachedBatch) -> bool {
190 self.ttl
191 .is_some_and(|ttl| entry.inserted_at.elapsed() >= ttl)
192 }
193
194 pub fn insert(&self, key: &[u8], value: RecordBatch) {
196 let cache_key = self.make_key(key);
197 self.cache.insert(
198 cache_key,
199 CachedBatch {
200 batch: value,
201 inserted_at: Instant::now(),
202 },
203 );
204 }
205
206 pub fn invalidate(&self, key: &[u8]) {
208 let ref_key = LookupCacheKeyRef {
209 table_id: self.table_id,
210 key,
211 };
212 self.cache.remove(&ref_key);
213 }
214
215 #[must_use]
217 pub fn len(&self) -> usize {
218 self.cache.len()
219 }
220
221 #[must_use]
223 pub fn is_empty(&self) -> bool {
224 self.cache.is_empty()
225 }
226}
227
228impl std::fmt::Debug for LookupMemoryCache {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 f.debug_struct("LookupMemoryCache")
231 .field("table_id", &self.table_id)
232 .field("ttl", &self.ttl)
233 .field("entries", &self.cache.len())
234 .finish()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Capacity used by every test cache that is not exercising eviction.
    const ROOMY_CAPACITY_BYTES: usize = 64 * 1024;

    /// Builds a one-row batch with a single non-nullable `Utf8` column `v`
    /// holding `val`.
    fn test_batch(val: &str) -> RecordBatch {
        let field = Field::new("v", DataType::Utf8, false);
        let schema = Arc::new(Schema::new(vec![field]));
        let values = StringArray::from(vec![val]);
        RecordBatch::try_new(schema, vec![Arc::new(values)]).unwrap()
    }

    /// Shared constructor behind the helpers below.
    fn roomy_cache(table_id: u32, ttl: Option<Duration>) -> LookupMemoryCache {
        let config = LookupMemoryCacheConfig {
            capacity_bytes: ROOMY_CAPACITY_BYTES,
            ttl,
        };
        LookupMemoryCache::new(table_id, config)
    }

    /// 64 KiB cache for `table_id` with no TTL.
    fn small_cache(table_id: u32) -> LookupMemoryCache {
        roomy_cache(table_id, None)
    }

    /// 64 KiB cache for table 1 whose entries expire after `ttl`.
    fn ttl_cache(ttl: Duration) -> LookupMemoryCache {
        roomy_cache(1, Some(ttl))
    }

    #[test]
    fn test_lookup_cache_hit_miss() {
        let cache = small_cache(1);
        let key = b"key1";

        // Nothing has been inserted yet.
        assert!(cache.get_cached(key).is_not_found());

        cache.insert(key, test_batch("value1"));

        let found = cache.get_cached(key);
        assert!(found.is_hit());
        let batch = found.into_batch().unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    #[test]
    fn test_lookup_cache_eviction() {
        // A 512-byte budget cannot hold 200 batches.
        let config = LookupMemoryCacheConfig {
            capacity_bytes: 512,
            ttl: None,
        };
        let cache = LookupMemoryCache::new(1, config);

        (0..200u8).for_each(|i| {
            let batch = test_batch(&format!("v{i}"));
            cache.insert(&[i], batch);
        });

        let remaining = cache.len();
        assert!(
            remaining < 200,
            "byte bound did not evict: len {}",
            remaining
        );
    }

    #[test]
    fn test_lookup_cache_invalidation() {
        let cache = small_cache(1);
        let key = b"key1";

        cache.insert(key, test_batch("value1"));
        assert!(cache.get_cached(key).is_hit());

        cache.invalidate(key);
        assert!(cache.get_cached(key).is_not_found());
    }

    #[test]
    fn test_lookup_cache_table_id_isolation() {
        let key = b"key1";
        let first = small_cache(1);
        let second = small_cache(2);

        first.insert(key, test_batch("from_a"));
        second.insert(key, test_batch("from_b"));

        let from_first = first.get_cached(key).into_batch().unwrap();
        let from_second = second.get_cached(key).into_batch().unwrap();

        assert_eq!(from_first.num_rows(), 1);
        assert_eq!(from_second.num_rows(), 1);
        // Same key bytes, but each cache returns what was stored in it.
        assert_ne!(from_first, from_second);
    }

    #[test]
    fn test_ttl_zero_expires_immediately() {
        let cache = ttl_cache(Duration::ZERO);

        cache.insert(b"k", test_batch("v"));

        // With a zero TTL the first read already sees the entry as expired...
        assert!(cache.get_cached(b"k").is_not_found());
        // ...and that read removes it.
        assert!(cache.is_empty());
    }

    #[test]
    fn test_ttl_hit_then_expire() {
        let ttl = Duration::from_millis(20);
        let cache = ttl_cache(ttl);

        cache.insert(b"k", test_batch("v"));
        assert!(cache.get_cached(b"k").is_hit());

        // Wait past the TTL; the next read must miss and drop the entry.
        std::thread::sleep(Duration::from_millis(40));
        assert!(cache.get_cached(b"k").is_not_found());
        assert!(cache.is_empty());
    }

    #[test]
    fn test_no_ttl_entry_survives() {
        let cache = small_cache(1);

        cache.insert(b"k", test_batch("v"));
        std::thread::sleep(Duration::from_millis(10));

        // Without a TTL, age alone never turns a hit into a miss.
        assert!(cache.get_cached(b"k").is_hit());
    }
}