1use anyhow::{Context, Result};
2use redis::{aio::ConnectionManager, parse_redis_url, AsyncCommands, Client, RedisError};
3use spin_core::async_trait;
4use spin_factor_key_value::{
5 log_error, log_error_v3, v3, Cas, Error, Store, StoreManager, SwapError,
6};
7use std::sync::Arc;
8use tokio::sync::OnceCell;
9use url::Url;
10
11pub struct KeyValueRedis {
12 database_url: Url,
13 connection: OnceCell<ConnectionManager>,
14}
15
16impl KeyValueRedis {
17 pub fn new(address: String) -> Result<Self> {
18 let database_url = parse_redis_url(&address).context("Invalid Redis URL")?;
19
20 Ok(Self {
21 database_url,
22 connection: OnceCell::new(),
23 })
24 }
25}
26
27#[async_trait]
28impl StoreManager for KeyValueRedis {
29 async fn get(&self, _name: &str) -> Result<Arc<dyn Store>, Error> {
30 let connection = self
31 .connection
32 .get_or_try_init(|| async {
33 Client::open(self.database_url.clone())?
34 .get_connection_manager()
35 .await
36 })
37 .await
38 .map_err(log_error)?;
39
40 Ok(Arc::new(RedisStore {
41 connection: connection.clone(),
42 database_url: self.database_url.clone(),
43 }))
44 }
45
46 fn is_defined(&self, _store_name: &str) -> bool {
47 true
48 }
49
50 fn summary(&self, _store_name: &str) -> Option<String> {
51 let redis::ConnectionInfo { addr, .. } = self.database_url.as_str().parse().ok()?;
52 Some(format!("Redis at {addr}"))
53 }
54}
55
56struct RedisStore {
57 connection: ConnectionManager,
58 database_url: Url,
59}
60
61struct CompareAndSwap {
62 key: String,
63 connection: ConnectionManager,
64 bucket_rep: u32,
65}
66
67#[async_trait]
68impl Store for RedisStore {
69 async fn after_open(&self) -> Result<(), Error> {
70 if let Err(_error) = self.connection.clone().ping::<()>().await {
71 }
74 Ok(())
75 }
76
77 async fn get(&self, key: &str, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
78 let value = self
79 .connection
80 .clone()
81 .get::<_, Option<Vec<u8>>>(key)
82 .await
83 .map_err(log_error)?;
84
85 if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
89 > max_result_bytes
90 {
91 Err(Error::Other(format!(
92 "query result exceeds limit of {max_result_bytes} bytes"
93 )))
94 } else {
95 Ok(value)
96 }
97 }
98
99 async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
100 self.connection
101 .clone()
102 .set(key, value)
103 .await
104 .map_err(log_error)
105 }
106
107 async fn delete(&self, key: &str) -> Result<(), Error> {
108 self.connection.clone().del(key).await.map_err(log_error)
109 }
110
111 async fn exists(&self, key: &str) -> Result<bool, Error> {
112 self.connection.clone().exists(key).await.map_err(log_error)
113 }
114
115 async fn get_keys(&self, max_result_bytes: usize) -> Result<Vec<String>, Error> {
116 _ = max_result_bytes;
119
120 let keys = self
121 .connection
122 .clone()
123 .keys::<_, Vec<String>>("*")
124 .await
125 .map_err(log_error)?;
126
127 if std::mem::size_of::<Vec<String>>()
131 + keys
132 .iter()
133 .map(|v| std::mem::size_of::<String>() + v.len())
134 .sum::<usize>()
135 > max_result_bytes
136 {
137 Err(Error::Other(format!(
138 "query result exceeds limit of {max_result_bytes} bytes"
139 )))
140 } else {
141 Ok(keys)
142 }
143 }
144
145 async fn get_keys_async(
146 &self,
147 max_result_bytes: usize,
148 ) -> (
149 tokio::sync::mpsc::Receiver<String>,
150 tokio::sync::oneshot::Receiver<Result<(), v3::Error>>,
151 ) {
152 let (keys_tx, keys_rx) = tokio::sync::mpsc::channel(4);
153 let (err_tx, err_rx) = tokio::sync::oneshot::channel();
154
155 let mut conn = self.connection.clone();
156
157 let the_work = async move {
158 let mut scan = conn.scan::<String>().await.map_err(log_error_v3)?;
159 loop {
160 match scan.next_item().await {
161 None => break,
162 Some(k) => {
163 if k.len() > max_result_bytes {
164 return Err(v3::Error::Other(format!(
165 "query result exceeds limit of {max_result_bytes} bytes"
166 )));
167 }
168 keys_tx.send(k).await.map_err(log_error_v3)?;
169 }
170 }
171 }
172 Ok(())
173 };
174 tokio::spawn(async move {
175 let res = the_work.await;
176 _ = err_tx.send(res);
177 });
178
179 (keys_rx, err_rx)
180 }
181
182 async fn get_many(
183 &self,
184 keys: Vec<String>,
185 max_result_bytes: usize,
186 ) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
187 let values = self
188 .connection
189 .clone()
190 .keys::<_, Vec<(String, Option<Vec<u8>>)>>(keys)
191 .await
192 .map_err(log_error)?;
193
194 if std::mem::size_of::<Vec<(String, Option<Vec<u8>>)>>()
198 + values
199 .iter()
200 .map(|(k, v)| {
201 std::mem::size_of::<(String, Option<Vec<u8>>)>()
202 + k.len()
203 + v.as_ref().map(|v| v.len()).unwrap_or(0)
204 })
205 .sum::<usize>()
206 > max_result_bytes
207 {
208 Err(Error::Other(format!(
209 "query result exceeds limit of {max_result_bytes} bytes"
210 )))
211 } else {
212 Ok(values)
213 }
214 }
215
216 async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
217 self.connection
218 .clone()
219 .mset(&key_values)
220 .await
221 .map_err(log_error)
222 }
223
224 async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
225 self.connection.clone().del(keys).await.map_err(log_error)
226 }
227
228 async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
229 self.connection
230 .clone()
231 .incr(key, delta)
232 .await
233 .map_err(log_error)
234 }
235
236 async fn new_compare_and_swap(
240 &self,
241 bucket_rep: u32,
242 key: &str,
243 ) -> Result<Arc<dyn Cas>, Error> {
244 let cx = Client::open(self.database_url.clone())
245 .map_err(log_error)?
246 .get_connection_manager()
247 .await
248 .map_err(log_error)?;
249
250 Ok(Arc::new(CompareAndSwap {
251 key: key.to_string(),
252 connection: cx,
253 bucket_rep,
254 }))
255 }
256}
257
258#[async_trait]
259impl Cas for CompareAndSwap {
260 async fn current(&self, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
263 redis::cmd("WATCH")
264 .arg(&self.key)
265 .exec_async(&mut self.connection.clone())
266 .await
267 .map_err(log_error)?;
268 let value = self
269 .connection
270 .clone()
271 .get::<_, Option<Vec<u8>>>(&self.key)
272 .await
273 .map_err(log_error)?;
274
275 if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
279 > max_result_bytes
280 {
281 Err(Error::Other(format!(
282 "query result exceeds limit of {max_result_bytes} bytes"
283 )))
284 } else {
285 Ok(value)
286 }
287 }
288
289 async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
292 let mut transaction = redis::pipe();
294 let res: Result<(), RedisError> = transaction
295 .atomic()
296 .set(&self.key, value)
297 .query_async(&mut self.connection.clone())
298 .await;
299
300 redis::cmd("UNWATCH")
301 .arg(&self.key)
302 .exec_async(&mut self.connection.clone())
303 .await
304 .map_err(|err| SwapError::CasFailed(format!("{err:?}")))?;
305
306 match res {
307 Ok(_) => Ok(()),
308 Err(err) => Err(SwapError::CasFailed(format!("{err:?}"))),
309 }
310 }
311
312 async fn bucket_rep(&self) -> u32 {
313 self.bucket_rep
314 }
315
316 async fn key(&self) -> String {
317 self.key.clone()
318 }
319}