You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

193 lines
5.6 KiB
Rust

use std::fmt;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Shared in-memory, direct-mapped cache with a fixed number of cache lines.
///
/// Every id is hashed to exactly one line. Storing an entry replaces whatever
/// that line held before, even if it belonged to a different id. There is no
/// least-recently-used bookkeeping and no collision chaining, so a lookup for
/// an id whose line has since been taken over by another id simply misses.
///
/// Basically a [`std::collections::HashMap`] that is [`Send`] + [`Sync`]
/// (whenever `I` and `T` are both `Send + Sync`), and without a mechanism to
/// handle hash collisions.
pub struct MemCache<I: HashableId, T: Indexable<I>> {
    lines: Vec<RwLock<Option<Arc<T>>>>,
    _phantom: PhantomData<I>,
}
/// Everything a type needs in order to act as a cache key.
///
/// Thanks to the blanket impl below, any type meeting these bounds is a
/// `HashableId` automatically; nobody has to implement it by hand.
pub trait HashableId: Clone + fmt::Debug + Hash + PartialEq + Send {}

impl<K: Send + PartialEq + Hash + fmt::Debug + Clone> HashableId for K {}
/// A value that can live in a [`MemCache`], keyed by the id it reports.
pub trait Indexable<I>: Clone + Send
where
    I: HashableId,
{
    /// The id this value is cached under.
    ///
    /// The cache assumes this does not change while the value is stored.
    fn get_id(&self) -> I;
}
impl<I: HashableId, T: Indexable<I>> MemCache<I, T> {
    /// Create a cache with the default capacity of 1024 lines.
    pub fn new() -> MemCache<I, T> {
        MemCache::with_capacity(NonZeroUsize::new(1024).unwrap())
    }

    /// Create a cache with exactly `capacity` lines, all initially empty.
    ///
    /// `NonZeroUsize` guarantees at least one line, which `compute_index`
    /// relies on so that its modulo never divides by zero.
    pub fn with_capacity(capacity: NonZeroUsize) -> MemCache<I, T> {
        MemCache {
            lines: (0..capacity.get()).map(|_| RwLock::new(None)).collect(),
            _phantom: PhantomData,
        }
    }

    /// Look up the entry with `id`.
    ///
    /// Returns `None` if the line `id` maps to is empty, or if it currently
    /// holds an entry with a different id.
    pub async fn get(&self, id: I) -> Option<Arc<T>> {
        let index = self.compute_index(&id);
        let guard = self.lines[index].read().await;
        let entry = Arc::clone(guard.as_ref()?);
        // The lock is released before comparing ids; the cloned `Arc` keeps
        // the entry alive on its own.
        drop(guard);
        if entry.get_id() == id {
            Some(entry)
        } else {
            None
        }
    }

    /// Store `entry`, replacing whatever occupied its line (possibly an
    /// entry with a different id).
    pub async fn put(&self, entry: T) {
        let index = self.compute_index(&entry.get_id());
        let entry = Arc::new(entry);
        let mut guard = self.lines[index].write().await;
        let old_entry = guard.replace(entry);
        // explicitly drop the lock before the old entry
        // so we never deallocate while holding the lock
        // (unless rustc is trying to be smarter than us)
        drop(guard);
        drop(old_entry);
    }

    /// Remove and return the entry with `id`.
    ///
    /// Returns `None` and leaves the line untouched if it is empty or holds
    /// an entry with a different id.
    pub async fn del(&self, id: I) -> Option<Arc<T>> {
        let index = self.compute_index(&id);
        let mut guard = self.lines[index].write().await;
        let entry = guard.as_ref()?;
        if entry.get_id() == id {
            guard.take()
        } else {
            None
        }
    }

    /// Update the entry with `id`, if it exists in the cache.
    ///
    /// `callback` runs on a private copy of the entry while the line's write
    /// lock is held, and the copy is then swapped in. Callers that already
    /// hold an `Arc` to the previous version keep seeing the old value.
    ///
    /// You MUST NOT manipulate the id or things will break.
    ///
    /// # Panics
    ///
    /// Panics if `callback` changes the entry's id. The cached entry is left
    /// unmodified in that case, because the swap happens only after the
    /// check.
    pub async fn update<F>(&self, id: I, callback: F)
    where
        F: FnOnce(&mut T),
    {
        // FIXME: Find a way to avoid the deep clone. `Arc::make_mut` would
        // mutate in place when the Arc is unique, but a panicking or
        // id-changing callback would then leave a corrupted entry behind.
        let index = self.compute_index(&id);
        let mut guard = self.lines[index].write().await;
        let mut updated = match guard.as_ref() {
            Some(current) if current.get_id() == id => current.as_ref().clone(),
            _ => return,
        };
        callback(&mut updated);
        assert_eq!(updated.get_id(), id);
        let old_entry = guard.replace(Arc::new(updated));
        // Same invariant as `put`: release the lock before the old entry is
        // dropped, so it is never deallocated while the lock is held.
        drop(guard);
        drop(old_entry);
    }

    /// Map `id` to the index of its cache line using djb2.
    fn compute_index(&self, id: &I) -> usize {
        let mut hasher = Djb2::new();
        id.hash(&mut hasher);
        // Reduce in u64 space so that all 64 hash bits contribute even where
        // `usize` is narrower. The remainder is < `lines.len()`, so casting
        // it back to `usize` cannot truncate.
        let line_count = self.lines.len() as u64;
        (hasher.finish() % line_count) as usize
    }
}

impl<I: HashableId, T: Indexable<I>> Default for MemCache<I, T> {
    /// Same as [`MemCache::new`]: 1024 empty lines.
    fn default() -> Self {
        Self::new()
    }
}
/// The djb2 hash, first posted by Daniel J. Bernstein to `comp.lang.c` in
/// the previous millennium.
/// See <http://www.cse.yorku.ca/~oz/hash.html#djb2> for details.
///
/// Whether it is a good fit for this use case (and in 2023) is debatable.
/// It was picked because it is simple to implement and appears cheap to
/// compute; there are surely better alternatives these days.
struct Djb2 {
    hash: u64,
}

impl Djb2 {
    /// Starting value of the running hash.
    const SEED: u64 = 5381;

    pub fn new() -> Djb2 {
        Self { hash: Self::SEED }
    }
}

impl Hasher for Djb2 {
    fn finish(&self) -> u64 {
        self.hash
    }

    fn write(&mut self, bytes: &[u8]) {
        // The page linked above says the original algorithm used addition;
        // this is the xor variant, which Bernstein is said to prefer.
        for &byte in bytes {
            self.hash = self.hash.wrapping_mul(33) ^ u64::from(byte);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::*;

    type TestId = Id<TestEntry>;

    /// Minimal cacheable value: an id plus an integer payload.
    #[derive(Clone, Debug, PartialEq)]
    struct TestEntry {
        id: TestId,
        data: i32,
    }

    impl Indexable<TestId> for TestEntry {
        fn get_id(&self) -> TestId {
            self.id
        }
    }

    /// Round-trips 1024 entries through `put`/`get`, then deletes every id
    /// and checks that at least one deletion actually removed something.
    #[actix_web::test]
    async fn store_stuff() {
        let cache = MemCache::new();

        for raw in 0i64..1024i64 {
            let id = raw.into();
            let expected = TestEntry { id, data: 0 };
            cache.put(expected.clone()).await;
            let fetched = cache.get(id).await.unwrap();
            assert_eq!(fetched.as_ref(), &expected);
        }

        let mut removed_any = false;
        for raw in 0..1024 {
            let id = raw.into();
            let removed = cache.del(id).await;
            assert_eq!(cache.get(id).await, None);
            if let Some(removed) = removed {
                assert_eq!(removed.id, id);
                removed_any = true;
            }
        }
        assert!(removed_any);
    }

    /// `update` must make the callback's changes visible to later `get`s.
    #[actix_web::test]
    async fn update_stuff() {
        let key = 420.into();
        let cache = MemCache::new();
        cache.put(TestEntry { id: key, data: 69 }).await;

        cache.update(key, |entry| entry.data = 1312).await;

        let fetched = cache.get(key).await.unwrap();
        assert_eq!(fetched.as_ref(), &TestEntry { id: key, data: 1312 });
    }

    /// Changing the id inside an `update` callback must panic.
    #[actix_web::test]
    #[should_panic]
    async fn mutate_id() {
        let key = 420.into();
        let cache = MemCache::new();
        cache.put(TestEntry { id: key, data: 69 }).await;

        cache.update(key, |entry| entry.id = 1312.into()).await;
    }
}