1use std::hash::{Hash, Hasher};
11use std::sync::atomic::{AtomicU64, Ordering};
12
13use arrow_array::RecordBatch;
14use equivalent::Equivalent;
15use foyer::{Cache, CacheBuilder};
16
17use crate::lookup::table::{LookupResult, LookupTable};
18
/// Owned cache key: a table identifier plus the raw lookup-key bytes.
///
/// Derives `Hash`/`Eq` so it can serve as the key type of the foyer cache.
/// `LookupCacheKeyRef` mirrors this layout so lookups can probe the cache
/// without allocating an owned `Vec<u8>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct LookupCacheKey {
    /// Identifier of the table this entry belongs to.
    pub table_id: u32,
    /// Raw lookup-key bytes.
    pub key: Vec<u8>,
}
30
/// Borrowed view of a [`LookupCacheKey`], used to probe the cache without
/// copying the key bytes into an owned `Vec<u8>` (see the `Hash` and
/// `Equivalent` impls below).
pub(crate) struct LookupCacheKeyRef<'a> {
    /// Identifier of the table being probed.
    pub(crate) table_id: u32,
    /// Borrowed lookup-key bytes.
    pub(crate) key: &'a [u8],
}
40
impl Hash for LookupCacheKeyRef<'_> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // INVARIANT: must produce the same hash as the #[derive(Hash)] on
        // LookupCacheKey, which hashes fields in declaration order
        // (table_id, then key). `&[u8]` and `Vec<u8>` hash identically
        // (both go through slice hashing), so a borrowed probe and its
        // owned counterpart always land in the same bucket.
        self.table_id.hash(state);
        self.key.hash(state);
    }
}
50
51impl Equivalent<LookupCacheKey> for LookupCacheKeyRef<'_> {
52 fn equivalent(&self, other: &LookupCacheKey) -> bool {
53 self.table_id == other.table_id && self.key == other.key.as_slice()
54 }
55}
56
/// Tuning knobs for the in-memory foyer cache.
#[derive(Debug, Clone, Copy)]
pub struct FoyerMemoryCacheConfig {
    /// Capacity passed straight to `CacheBuilder::new`.
    /// NOTE(review): presumably a maximum entry count under foyer's default
    /// weighter — confirm against the foyer version in use.
    pub capacity: usize,
    /// Number of internal cache shards (more shards → less lock contention).
    pub shards: usize,
}
65
66impl Default for FoyerMemoryCacheConfig {
67 fn default() -> Self {
68 Self {
69 capacity: 256 * 1024, shards: 16,
71 }
72 }
73}
74
/// In-memory lookup cache backed by `foyer::Cache`, scoped to a single table.
///
/// Hit/miss counters use relaxed atomics — cheap, approximate statistics only.
pub struct FoyerMemoryCache {
    // Underlying sharded cache mapping owned keys to record batches.
    cache: Cache<LookupCacheKey, RecordBatch>,
    // Table this instance serves; baked into every cache key it creates.
    table_id: u32,
    // Lifetime hit counter (Relaxed ordering — stats only, not synchronization).
    hits: AtomicU64,
    // Lifetime miss counter (Relaxed ordering — stats only, not synchronization).
    misses: AtomicU64,
}
90
91impl FoyerMemoryCache {
92 #[must_use]
94 pub fn new(table_id: u32, config: FoyerMemoryCacheConfig) -> Self {
95 let cache = CacheBuilder::new(config.capacity)
96 .with_shards(config.shards)
97 .build();
98
99 Self {
100 cache,
101 table_id,
102 hits: AtomicU64::new(0),
103 misses: AtomicU64::new(0),
104 }
105 }
106
107 #[must_use]
109 pub fn with_defaults(table_id: u32) -> Self {
110 Self::new(table_id, FoyerMemoryCacheConfig::default())
111 }
112
113 #[must_use]
115 pub fn hit_count(&self) -> u64 {
116 self.hits.load(Ordering::Relaxed)
117 }
118
119 #[must_use]
121 pub fn miss_count(&self) -> u64 {
122 self.misses.load(Ordering::Relaxed)
123 }
124
125 #[must_use]
127 #[allow(clippy::cast_precision_loss)]
128 pub fn hit_ratio(&self) -> f64 {
129 let hits = self.hits.load(Ordering::Relaxed);
130 let misses = self.misses.load(Ordering::Relaxed);
131 let total = hits + misses;
132 if total == 0 {
133 0.0
134 } else {
135 hits as f64 / total as f64
136 }
137 }
138
139 #[must_use]
141 pub fn table_id(&self) -> u32 {
142 self.table_id
143 }
144
145 fn make_key(&self, key: &[u8]) -> LookupCacheKey {
147 LookupCacheKey {
148 table_id: self.table_id,
149 key: key.to_vec(),
150 }
151 }
152}
153
154impl LookupTable for FoyerMemoryCache {
155 fn get_cached(&self, key: &[u8]) -> LookupResult {
156 let ref_key = LookupCacheKeyRef {
157 table_id: self.table_id,
158 key,
159 };
160 if let Some(entry) = self.cache.get(&ref_key) {
161 let value = entry.value().clone();
162 self.hits.fetch_add(1, Ordering::Relaxed);
163 LookupResult::Hit(value)
164 } else {
165 self.misses.fetch_add(1, Ordering::Relaxed);
166 LookupResult::NotFound
167 }
168 }
169
170 fn get(&self, key: &[u8]) -> LookupResult {
171 self.get_cached(key)
172 }
173
174 fn insert(&self, key: &[u8], value: RecordBatch) {
175 let cache_key = self.make_key(key);
176 self.cache.insert(cache_key, value);
177 }
178
179 fn invalidate(&self, key: &[u8]) {
180 let ref_key = LookupCacheKeyRef {
181 table_id: self.table_id,
182 key,
183 };
184 self.cache.remove(&ref_key);
185 }
186
187 fn len(&self) -> usize {
188 self.cache.usage()
189 }
190}
191
192impl std::fmt::Debug for FoyerMemoryCache {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.debug_struct("FoyerMemoryCache")
195 .field("table_id", &self.table_id)
196 .field("entries", &self.cache.usage())
197 .field("hits", &self.hits.load(Ordering::Relaxed))
198 .field("misses", &self.misses.load(Ordering::Relaxed))
199 .finish()
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a one-row, one-column Utf8 batch containing `val`.
    fn test_batch(val: &str) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Utf8, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(vec![val]))]).unwrap()
    }

    /// Small 64-capacity, 4-shard cache used by most tests.
    fn small_cache(table_id: u32) -> FoyerMemoryCache {
        FoyerMemoryCache::new(
            table_id,
            FoyerMemoryCacheConfig {
                capacity: 64,
                shards: 4,
            },
        )
    }

    // A lookup before insert misses (and counts a miss); after insert it
    // hits (and counts a hit) and returns the stored batch.
    #[test]
    fn test_foyer_cache_hit_miss() {
        let cache = small_cache(1);

        let result = cache.get_cached(b"key1");
        assert!(result.is_not_found());
        assert_eq!(cache.miss_count(), 1);

        cache.insert(b"key1", test_batch("value1"));
        let result = cache.get_cached(b"key1");
        assert!(result.is_hit());
        let batch = result.into_batch().unwrap();
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(cache.hit_count(), 1);
    }

    // Overfilling a capacity-8 cache must evict: usage stays <= capacity.
    #[test]
    fn test_foyer_cache_eviction() {
        let cache = FoyerMemoryCache::new(
            1,
            FoyerMemoryCacheConfig {
                capacity: 8,
                shards: 1,
            },
        );

        for i in 0..20u8 {
            cache.insert(&[i], test_batch(&format!("v{i}")));
        }

        assert!(cache.len() <= 8, "len {} > capacity 8", cache.len());
    }

    // invalidate() removes an entry so subsequent lookups miss.
    #[test]
    fn test_foyer_cache_invalidation() {
        let cache = small_cache(1);

        cache.insert(b"key1", test_batch("value1"));
        assert!(cache.get_cached(b"key1").is_hit());

        cache.invalidate(b"key1");
        assert!(cache.get_cached(b"key1").is_not_found());
    }

    // The same byte key under different table_ids must not collide, since
    // table_id is part of the cache key.
    #[test]
    fn test_foyer_cache_table_id_isolation() {
        let cache_a = small_cache(1);
        let cache_b = small_cache(2);

        cache_a.insert(b"key1", test_batch("from_a"));
        cache_b.insert(b"key1", test_batch("from_b"));

        let batch_a = cache_a.get_cached(b"key1").into_batch().unwrap();
        let batch_b = cache_b.get_cached(b"key1").into_batch().unwrap();

        assert_eq!(batch_a.num_rows(), 1);
        assert_eq!(batch_b.num_rows(), 1);
        assert_ne!(batch_a, batch_b);
    }

    // The cache is usable through a &dyn LookupTable trait object.
    #[test]
    fn test_foyer_cache_implements_lookup_table() {
        let cache = small_cache(1);
        let table: &dyn LookupTable = &cache;

        table.insert(b"k", test_batch("v"));
        assert!(!table.is_empty());
        assert!(table.get(b"k").is_hit());
    }

    // One hit plus one miss yields a ratio of exactly 0.5.
    #[test]
    fn test_foyer_cache_hit_ratio() {
        let cache = small_cache(1);
        cache.insert(b"k1", test_batch("v1"));

        cache.get_cached(b"k1");
        cache.get_cached(b"k2");

        assert!((cache.hit_ratio() - 0.5).abs() < f64::EPSILON);
    }

    // The manual Debug impl includes the struct name and table_id.
    #[test]
    fn test_foyer_cache_debug() {
        let cache = small_cache(42);
        let debug = format!("{cache:?}");
        assert!(debug.contains("FoyerMemoryCache"));
        assert!(debug.contains("table_id: 42"));
    }

    // Default config matches the documented 256 * 1024 / 16 values.
    #[test]
    fn test_foyer_cache_default_config() {
        let config = FoyerMemoryCacheConfig::default();
        assert_eq!(config.capacity, 256 * 1024);
        assert_eq!(config.shards, 16);
    }

    // Repeated hits return equal batches (clones share Arc-backed columns).
    #[test]
    fn test_foyer_cache_recordbatch_clone_is_cheap() {
        let cache = small_cache(1);
        let batch = test_batch("value");
        cache.insert(b"k", batch.clone());

        let hit1 = cache.get_cached(b"k").into_batch().unwrap();
        let hit2 = cache.get_cached(b"k").into_batch().unwrap();
        assert_eq!(hit1, hit2);
        assert_eq!(hit1.num_rows(), 1);
    }
}