use crate::{Cas, Error, Store, StoreManager, SwapError};
use lru::LruCache;
use spin_core::async_trait;
use std::{
collections::{HashMap, HashSet},
future::Future,
num::NonZeroUsize,
sync::Arc,
};
use tokio::{
sync::Mutex as AsyncMutex,
task::{self, JoinHandle},
};
use tracing::Instrument;
/// A [`StoreManager`] that routes each store name to its own delegate manager.
pub struct DelegatingStoreManager {
    /// Delegate manager for every store name this manager knows about.
    delegates: HashMap<String, Arc<dyn StoreManager>>,
}

impl DelegatingStoreManager {
    /// Builds a manager from `(store name, delegate)` pairs.
    ///
    /// If a name appears more than once, the last pair for that name wins.
    pub fn new(delegates: impl IntoIterator<Item = (String, Arc<dyn StoreManager>)>) -> Self {
        Self {
            delegates: delegates.into_iter().collect::<HashMap<_, _>>(),
        }
    }
}
#[async_trait]
impl StoreManager for DelegatingStoreManager {
async fn get(&self, name: &str) -> Result<Arc<dyn Store>, Error> {
match self.delegates.get(name) {
Some(store) => store.get(name).await,
None => Err(Error::NoSuchStore),
}
}
fn is_defined(&self, store_name: &str) -> bool {
self.delegates.contains_key(store_name)
}
fn summary(&self, store_name: &str) -> Option<String> {
if let Some(store) = self.delegates.get(store_name) {
return store.summary(store_name);
}
None
}
}
/// Wraps an inner store manager so that every store it opens gets an in-memory
/// LRU cache in front of it.
pub struct CachingStoreManager<T> {
    /// Maximum number of entries in each opened store's cache.
    capacity: NonZeroUsize,
    /// The wrapped manager that actually opens stores.
    inner: T,
}

/// Cache capacity used by `CachingStoreManager::new`.
const DEFAULT_CACHE_SIZE: usize = 256;

impl<T> CachingStoreManager<T> {
    /// Wraps `inner` using a per-store cache capacity of `DEFAULT_CACHE_SIZE` entries.
    pub fn new(inner: T) -> Self {
        // DEFAULT_CACHE_SIZE is a non-zero constant, so this unwrap cannot fail.
        let capacity = NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap();
        Self::new_with_capacity(capacity, inner)
    }

    /// Wraps `inner` using a per-store cache capacity of `capacity` entries.
    pub fn new_with_capacity(capacity: NonZeroUsize, inner: T) -> Self {
        CachingStoreManager { capacity, inner }
    }
}
#[async_trait]
impl<T: StoreManager> StoreManager for CachingStoreManager<T> {
    /// Opens `name` via the inner manager and wraps it in a `CachingStore`.
    ///
    /// Each call creates a fresh, empty cache of `self.capacity` entries, so caches
    /// are not shared between separately opened handles.
    async fn get(&self, name: &str) -> Result<Arc<dyn Store>, Error> {
        let backing = self.inner.get(name).await?;
        let fresh_state = CachingStoreState {
            cache: LruCache::new(self.capacity),
            previous_task: None,
        };
        let store = CachingStore {
            inner: backing,
            state: Arc::new(AsyncMutex::new(fresh_state)),
        };
        Ok(Arc::new(store))
    }

    /// Defers to the inner manager.
    fn is_defined(&self, store_name: &str) -> bool {
        self.inner.is_defined(store_name)
    }

    /// Defers to the inner manager.
    fn summary(&self, store_name: &str) -> Option<String> {
        self.inner.summary(store_name)
    }
}
/// Mutable state shared by a `CachingStore` and the `CompareAndSwap` handles it creates.
struct CachingStoreState {
    /// LRU cache of values by key. A `None` value records that the key is known to be
    /// absent (it was deleted, or the last read from the backing store found nothing).
    cache: LruCache<String, Option<Vec<u8>>>,
    /// Handle to the most recently spawned background write. Each spawned write awaits
    /// its predecessor first, so awaiting this handle waits for every queued write.
    previous_task: Option<JoinHandle<Result<(), Error>>>,
}
impl CachingStoreState {
fn spawn(&mut self, task: impl Future<Output = Result<(), Error>> + Send + 'static) {
let previous_task = self.previous_task.take();
let task = async move {
if let Some(previous_task) = previous_task {
previous_task
.await
.map_err(|e| Error::Other(format!("{e:?}")))??
}
task.await
};
self.previous_task = Some(task::spawn(task.in_current_span()))
}
async fn flush(&mut self) -> Result<(), Error> {
if let Some(previous_task) = self.previous_task.take() {
previous_task
.await
.map_err(|e| Error::Other(format!("{e:?}")))??
}
Ok(())
}
}
/// A [`Store`] that answers reads from an in-memory cache where possible and applies
/// single-key writes (`set`/`delete`) to `inner` in the background.
struct CachingStore {
    /// The backing store: cache misses are read from it and writes are applied to it.
    inner: Arc<dyn Store>,
    /// Cache plus pending-write queue, shared with any CAS handles created from this store.
    state: Arc<AsyncMutex<CachingStoreState>>,
}
#[async_trait]
impl Store for CachingStore {
    /// Returns the value stored under `key`.
    ///
    /// Cache hits, including cached "absent" entries, are answered without touching
    /// `inner`. On a miss, queued background writes are flushed first so the read
    /// observes them, and the result is cached.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
        let mut state = self.state.lock().await;
        if let Some(value) = state.cache.get(key).cloned() {
            return Ok(value);
        }
        state.flush().await?;
        let value = self.inner.get(key).await?;
        state.cache.put(key.to_owned(), value.clone());
        Ok(value)
    }

    /// Records `value` in the cache immediately and queues the write to `inner`.
    ///
    /// Always returns `Ok(())`; a failure of the queued write is reported by a later
    /// operation that flushes the queue.
    async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
        let mut state = self.state.lock().await;
        state.cache.put(key.to_owned(), Some(value.to_owned()));
        let inner = self.inner.clone();
        let key = key.to_owned();
        let value = value.to_owned();
        state.spawn(async move { inner.set(&key, &value).await });
        Ok(())
    }

    /// Records `key` as absent in the cache and queues the delete on `inner`.
    ///
    /// As with `set`, a failure of the queued delete surfaces on a later flush.
    async fn delete(&self, key: &str) -> Result<(), Error> {
        let mut state = self.state.lock().await;
        state.cache.put(key.to_owned(), None);
        let inner = self.inner.clone();
        let key = key.to_owned();
        state.spawn(async move { inner.delete(&key).await });
        Ok(())
    }

    /// Reports whether `key` currently has a value, going through the same cache as `get`.
    async fn exists(&self, key: &str) -> Result<bool, Error> {
        Ok(self.get(key).await?.is_some())
    }

    /// Lists the keys that currently have a value.
    ///
    /// Flushes queued writes, then merges `inner`'s keys with the cache. Keys the cache
    /// marks absent are dropped, and keys the cache holds a value for are included.
    async fn get_keys(&self) -> Result<Vec<String>, Error> {
        let mut state = self.state.lock().await;
        state.flush().await?;
        Ok(self
            .inner
            .get_keys()
            .await?
            .into_iter()
            .filter(|k| {
                state
                    .cache
                    .peek(k)
                    .map(|v| v.as_ref().is_some())
                    .unwrap_or(true)
            })
            .chain(
                state
                    .cache
                    .iter()
                    .filter_map(|(k, v)| v.as_ref().map(|_| k.to_owned())),
            )
            .collect::<HashSet<_>>()
            .into_iter()
            .collect())
    }

    /// Fetches several keys at once.
    ///
    /// Keys with a cached value are answered from the cache. All other keys (uncached,
    /// or cached as absent) are looked up in `inner` after flushing queued writes, and
    /// whatever `inner` returns is cached.
    async fn get_many(
        &self,
        keys: Vec<String>,
    ) -> anyhow::Result<Vec<(String, Option<Vec<u8>>)>, Error> {
        let mut state = self.state.lock().await;
        let mut found: Vec<(String, Option<Vec<u8>>)> = Vec::with_capacity(keys.len());
        let mut not_found: Vec<String> = Vec::new();
        for key in keys {
            match state.cache.get(key.as_str()) {
                Some(Some(value)) => found.push((key, Some(value.clone()))),
                _ => not_found.push(key),
            }
        }
        if not_found.is_empty() {
            return Ok(found);
        }
        // Queued `set`/`delete` tasks must land before reading `inner`. Otherwise the
        // read can observe (and then cache) a value that a pending write is about to
        // replace, overwriting the newer entry held in the cache.
        state.flush().await?;
        let keys_and_values = self.inner.get_many(not_found).await?;
        for (key, value) in keys_and_values {
            state.cache.put(key.clone(), value.clone());
            found.push((key, value));
        }
        Ok(found)
    }

    /// Writes all pairs to `inner` synchronously and mirrors them into the cache.
    ///
    /// # Errors
    /// Returns any error from flushing queued writes or from `inner.set_many`. If the
    /// inner write fails, the affected keys are evicted from the cache so later reads
    /// consult `inner` rather than trusting values that may not have been stored.
    async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> anyhow::Result<(), Error> {
        let mut state = self.state.lock().await;
        // Without this flush, an earlier queued `set`/`delete` of one of these keys
        // could be applied after this batch and silently overwrite it.
        state.flush().await?;
        let result = self.inner.set_many(key_values.clone()).await;
        if result.is_ok() {
            for (key, value) in key_values {
                state.cache.put(key, Some(value));
            }
        } else {
            for (key, _) in &key_values {
                state.cache.pop(key.as_str());
            }
        }
        result
    }

    /// Deletes all `keys` from `inner` synchronously and marks them absent in the cache.
    ///
    /// # Errors
    /// Returns any error from flushing queued writes or from `inner.delete_many`. On an
    /// inner failure, the keys are evicted from the cache instead of being marked absent.
    async fn delete_many(&self, keys: Vec<String>) -> anyhow::Result<(), Error> {
        let mut state = self.state.lock().await;
        // See `set_many`: queued writes must not be reordered after this batch.
        state.flush().await?;
        let result = self.inner.delete_many(keys.clone()).await;
        if result.is_ok() {
            for key in keys {
                state.cache.put(key, None);
            }
        } else {
            for key in &keys {
                state.cache.pop(key.as_str());
            }
        }
        result
    }

    /// Increments the counter at `key` by `delta` in `inner` and caches the new value.
    async fn increment(&self, key: String, delta: i64) -> anyhow::Result<i64, Error> {
        let mut state = self.state.lock().await;
        // A queued `set`/`delete` of this key must be applied before the increment
        // rather than racing with it.
        state.flush().await?;
        let counter = self.inner.increment(key.clone(), delta).await?;
        // NOTE(review): assumes `inner` stores counters as little-endian i64 bytes, so
        // a cached `get` returns the same bytes an uncached one would. Confirm this
        // against the backing stores.
        state
            .cache
            .put(key, Some(i64::to_le_bytes(counter).to_vec()));
        Ok(counter)
    }

    /// Creates a CAS handle for `key` that shares this store's cache and write queue.
    async fn new_compare_and_swap(
        &self,
        bucket_rep: u32,
        key: &str,
    ) -> anyhow::Result<Arc<dyn Cas>, Error> {
        let inner = self.inner.new_compare_and_swap(bucket_rep, key).await?;
        Ok(Arc::new(CompareAndSwap {
            bucket_rep,
            state: self.state.clone(),
            key: key.to_string(),
            inner_cas: inner,
        }))
    }
}
/// A [`Cas`] handle for a single key. It wraps the backing store's CAS handle and
/// updates the creating store's shared cache with what it reads and writes.
struct CompareAndSwap {
    /// Bucket representation given to `new_compare_and_swap`; returned unchanged by `bucket_rep`.
    bucket_rep: u32,
    /// The key this handle operates on.
    key: String,
    /// Cache and write queue shared with the store that created this handle.
    state: Arc<AsyncMutex<CachingStoreState>>,
    /// CAS handle obtained from the backing store.
    inner_cas: Arc<dyn Cas>,
}
#[async_trait]
impl Cas for CompareAndSwap {
    /// Returns the key's current value from the backing CAS handle and caches it.
    ///
    /// Queued background writes are flushed first so the backing store reflects them.
    async fn current(&self) -> anyhow::Result<Option<Vec<u8>>, Error> {
        let mut state = self.state.lock().await;
        state.flush().await?;
        let value = self.inner_cas.current().await?;
        // No second flush is needed: the lock is held, so nothing has been queued since.
        state.cache.put(self.key.clone(), value.clone());
        Ok(value)
    }

    /// Attempts to swap in `value` through the backing CAS handle.
    ///
    /// On success the new value is cached. On failure the key is evicted from the
    /// cache, because a failed swap means the cached value (e.g. the one recorded by
    /// `current`) may be stale. The swap error is then returned unchanged.
    async fn swap(&self, value: Vec<u8>) -> anyhow::Result<(), SwapError> {
        let mut state = self.state.lock().await;
        state
            .flush()
            .await
            .map_err(|_e| SwapError::Other("failed flushing".to_string()))?;
        match self.inner_cas.swap(value.clone()).await {
            Ok(()) => {
                state.cache.put(self.key.clone(), Some(value));
                Ok(())
            }
            Err(err) => {
                state.cache.pop(self.key.as_str());
                Err(err)
            }
        }
    }

    /// The bucket representation this handle was created with.
    async fn bucket_rep(&self) -> u32 {
        self.bucket_rep
    }

    /// The key this handle operates on.
    async fn key(&self) -> String {
        self.key.clone()
    }
}