// spin_key_value_redis/store.rs

1use anyhow::{Context, Result};
2use redis::{aio::ConnectionManager, parse_redis_url, AsyncCommands, Client, RedisError};
3use spin_core::async_trait;
4use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError};
5use std::sync::Arc;
6use tokio::sync::OnceCell;
7use url::Url;
8
/// A [`StoreManager`] that serves key-value stores from a single Redis server.
pub struct KeyValueRedis {
    /// URL of the Redis server, as parsed by `parse_redis_url` in [`KeyValueRedis::new`].
    database_url: Url,
    /// Connection manager that is created on the first call to `StoreManager::get` and then
    /// cloned into every store handed out afterwards.
    connection: OnceCell<ConnectionManager>,
}
13
14impl KeyValueRedis {
15    pub fn new(address: String) -> Result<Self> {
16        let database_url = parse_redis_url(&address).context("Invalid Redis URL")?;
17
18        Ok(Self {
19            database_url,
20            connection: OnceCell::new(),
21        })
22    }
23}
24
25#[async_trait]
26impl StoreManager for KeyValueRedis {
27    async fn get(&self, _name: &str) -> Result<Arc<dyn Store>, Error> {
28        let connection = self
29            .connection
30            .get_or_try_init(|| async {
31                Client::open(self.database_url.clone())?
32                    .get_connection_manager()
33                    .await
34            })
35            .await
36            .map_err(log_error)?;
37
38        Ok(Arc::new(RedisStore {
39            connection: connection.clone(),
40            database_url: self.database_url.clone(),
41        }))
42    }
43
44    fn is_defined(&self, _store_name: &str) -> bool {
45        true
46    }
47
48    fn summary(&self, _store_name: &str) -> Option<String> {
49        let redis::ConnectionInfo { addr, .. } = self.database_url.as_str().parse().ok()?;
50        Some(format!("Redis at {addr}"))
51    }
52}
53
/// A key-value store handle backed by a Redis connection.
struct RedisStore {
    /// Connection used for all regular key-value commands.
    connection: ConnectionManager,
    /// Server URL, kept so that `new_compare_and_swap` can open a separate connection.
    database_url: Url,
}
58
/// State for one compare-and-swap operation on a single key.
struct CompareAndSwap {
    /// Key that is watched and conditionally overwritten.
    key: String,
    /// Connection on which the WATCH and the subsequent transaction are issued.
    connection: ConnectionManager,
    /// Opaque bucket identifier supplied by the caller and returned unchanged by `bucket_rep`.
    bucket_rep: u32,
}
64
65#[async_trait]
66impl Store for RedisStore {
67    async fn after_open(&self) -> Result<(), Error> {
68        if let Err(_error) = self.connection.clone().ping::<()>().await {
69            // If an IO error happens, ConnectionManager will start reconnection in the background
70            // so we do not take any action and just pray re-connection will be successful.
71        }
72        Ok(())
73    }
74
75    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
76        self.connection.clone().get(key).await.map_err(log_error)
77    }
78
79    async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
80        self.connection
81            .clone()
82            .set(key, value)
83            .await
84            .map_err(log_error)
85    }
86
87    async fn delete(&self, key: &str) -> Result<(), Error> {
88        self.connection.clone().del(key).await.map_err(log_error)
89    }
90
91    async fn exists(&self, key: &str) -> Result<bool, Error> {
92        self.connection.clone().exists(key).await.map_err(log_error)
93    }
94
95    async fn get_keys(&self) -> Result<Vec<String>, Error> {
96        self.connection.clone().keys("*").await.map_err(log_error)
97    }
98
99    async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
100        self.connection.clone().keys(keys).await.map_err(log_error)
101    }
102
103    async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
104        self.connection
105            .clone()
106            .mset(&key_values)
107            .await
108            .map_err(log_error)
109    }
110
111    async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
112        self.connection.clone().del(keys).await.map_err(log_error)
113    }
114
115    async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
116        self.connection
117            .clone()
118            .incr(key, delta)
119            .await
120            .map_err(log_error)
121    }
122
123    /// `new_compare_and_swap` builds a new CAS structure giving it its own connection since Redis
124    /// transactions are scoped to a connection and any WATCH should be dropped upon the drop of
125    /// the connection.
126    async fn new_compare_and_swap(
127        &self,
128        bucket_rep: u32,
129        key: &str,
130    ) -> Result<Arc<dyn Cas>, Error> {
131        let cx = Client::open(self.database_url.clone())
132            .map_err(log_error)?
133            .get_connection_manager()
134            .await
135            .map_err(log_error)?;
136
137        Ok(Arc::new(CompareAndSwap {
138            key: key.to_string(),
139            connection: cx,
140            bucket_rep,
141        }))
142    }
143}
144
145#[async_trait]
146impl Cas for CompareAndSwap {
147    /// current will initiate a transaction by WATCH'ing a key in Redis, and then returning the
148    /// current value for the key.
149    async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
150        redis::cmd("WATCH")
151            .arg(&self.key)
152            .exec_async(&mut self.connection.clone())
153            .await
154            .map_err(log_error)?;
155        self.connection
156            .clone()
157            .get(&self.key)
158            .await
159            .map_err(log_error)
160    }
161
162    /// swap will set the key to the new value only if the key has not changed. Afterward, the
163    /// transaction will be terminated with an UNWATCH
164    async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
165        // Create transaction pipeline
166        let mut transaction = redis::pipe();
167        let res: Result<(), RedisError> = transaction
168            .atomic()
169            .set(&self.key, value)
170            .query_async(&mut self.connection.clone())
171            .await;
172
173        redis::cmd("UNWATCH")
174            .arg(&self.key)
175            .exec_async(&mut self.connection.clone())
176            .await
177            .map_err(|err| SwapError::CasFailed(format!("{err:?}")))?;
178
179        match res {
180            Ok(_) => Ok(()),
181            Err(err) => Err(SwapError::CasFailed(format!("{err:?}"))),
182        }
183    }
184
185    async fn bucket_rep(&self) -> u32 {
186        self.bucket_rep
187    }
188
189    async fn key(&self) -> String {
190        self.key.clone()
191    }
192}