// spin_key_value_redis/store.rs

1use anyhow::{Context, Result};
2use redis::{aio::ConnectionManager, parse_redis_url, AsyncCommands, Client, RedisError};
3use spin_core::async_trait;
4use spin_factor_key_value::{
5    log_error, log_error_v3, v3, Cas, Error, Store, StoreManager, SwapError,
6};
7use std::sync::Arc;
8use tokio::sync::OnceCell;
9use url::Url;
10
11pub struct KeyValueRedis {
12    database_url: Url,
13    connection: OnceCell<ConnectionManager>,
14}
15
16impl KeyValueRedis {
17    pub fn new(address: String) -> Result<Self> {
18        let database_url = parse_redis_url(&address).context("Invalid Redis URL")?;
19
20        Ok(Self {
21            database_url,
22            connection: OnceCell::new(),
23        })
24    }
25}
26
27#[async_trait]
28impl StoreManager for KeyValueRedis {
29    async fn get(&self, _name: &str) -> Result<Arc<dyn Store>, Error> {
30        let connection = self
31            .connection
32            .get_or_try_init(|| async {
33                Client::open(self.database_url.clone())?
34                    .get_connection_manager()
35                    .await
36            })
37            .await
38            .map_err(log_error)?;
39
40        Ok(Arc::new(RedisStore {
41            connection: connection.clone(),
42            database_url: self.database_url.clone(),
43        }))
44    }
45
46    fn is_defined(&self, _store_name: &str) -> bool {
47        true
48    }
49
50    fn summary(&self, _store_name: &str) -> Option<String> {
51        let redis::ConnectionInfo { addr, .. } = self.database_url.as_str().parse().ok()?;
52        Some(format!("Redis at {addr}"))
53    }
54}
55
56struct RedisStore {
57    connection: ConnectionManager,
58    database_url: Url,
59}
60
61struct CompareAndSwap {
62    key: String,
63    connection: ConnectionManager,
64    bucket_rep: u32,
65}
66
67#[async_trait]
68impl Store for RedisStore {
69    async fn after_open(&self) -> Result<(), Error> {
70        if let Err(_error) = self.connection.clone().ping::<()>().await {
71            // If an IO error happens, ConnectionManager will start reconnection in the background
72            // so we do not take any action and just pray re-connection will be successful.
73        }
74        Ok(())
75    }
76
77    async fn get(&self, key: &str, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
78        let value = self
79            .connection
80            .clone()
81            .get::<_, Option<Vec<u8>>>(key)
82            .await
83            .map_err(log_error)?;
84
85        // Currently there's no way to stream a `GET` result using the `redis`
86        // crate without buffering, so the damage (in terms of host memory
87        // usage) is already done, but we can still enforce the limit:
88        if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
89            > max_result_bytes
90        {
91            Err(Error::Other(format!(
92                "query result exceeds limit of {max_result_bytes} bytes"
93            )))
94        } else {
95            Ok(value)
96        }
97    }
98
99    async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
100        self.connection
101            .clone()
102            .set(key, value)
103            .await
104            .map_err(log_error)
105    }
106
107    async fn delete(&self, key: &str) -> Result<(), Error> {
108        self.connection.clone().del(key).await.map_err(log_error)
109    }
110
111    async fn exists(&self, key: &str) -> Result<bool, Error> {
112        self.connection.clone().exists(key).await.map_err(log_error)
113    }
114
115    async fn get_keys(&self, max_result_bytes: usize) -> Result<Vec<String>, Error> {
116        // There's currently no way to limit buffering for `KEYS` commands using
117        // the `redis` crate, so we can only ignore this:
118        _ = max_result_bytes;
119
120        let keys = self
121            .connection
122            .clone()
123            .keys::<_, Vec<String>>("*")
124            .await
125            .map_err(log_error)?;
126
127        // Currently there's no way to stream a `KEYS` result using the `redis`
128        // crate without buffering, so the damage (in terms of host memory
129        // usage) is already done, but we can still enforce the limit:
130        if std::mem::size_of::<Vec<String>>()
131            + keys
132                .iter()
133                .map(|v| std::mem::size_of::<String>() + v.len())
134                .sum::<usize>()
135            > max_result_bytes
136        {
137            Err(Error::Other(format!(
138                "query result exceeds limit of {max_result_bytes} bytes"
139            )))
140        } else {
141            Ok(keys)
142        }
143    }
144
145    async fn get_keys_async(
146        &self,
147        max_result_bytes: usize,
148    ) -> (
149        tokio::sync::mpsc::Receiver<String>,
150        tokio::sync::oneshot::Receiver<Result<(), v3::Error>>,
151    ) {
152        let (keys_tx, keys_rx) = tokio::sync::mpsc::channel(4);
153        let (err_tx, err_rx) = tokio::sync::oneshot::channel();
154
155        let mut conn = self.connection.clone();
156
157        let the_work = async move {
158            let mut scan = conn.scan::<String>().await.map_err(log_error_v3)?;
159            loop {
160                match scan.next_item().await {
161                    None => break,
162                    Some(k) => {
163                        if k.len() > max_result_bytes {
164                            return Err(v3::Error::Other(format!(
165                                "query result exceeds limit of {max_result_bytes} bytes"
166                            )));
167                        }
168                        keys_tx.send(k).await.map_err(log_error_v3)?;
169                    }
170                }
171            }
172            Ok(())
173        };
174        tokio::spawn(async move {
175            let res = the_work.await;
176            _ = err_tx.send(res);
177        });
178
179        (keys_rx, err_rx)
180    }
181
182    async fn get_many(
183        &self,
184        keys: Vec<String>,
185        max_result_bytes: usize,
186    ) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
187        let values = self
188            .connection
189            .clone()
190            .keys::<_, Vec<(String, Option<Vec<u8>>)>>(keys)
191            .await
192            .map_err(log_error)?;
193
194        // Currently there's no way to stream a `GET` result using the `redis`
195        // crate without buffering, so the damage (in terms of host memory
196        // usage) is already done, but we can still enforce the limit:
197        if std::mem::size_of::<Vec<(String, Option<Vec<u8>>)>>()
198            + values
199                .iter()
200                .map(|(k, v)| {
201                    std::mem::size_of::<(String, Option<Vec<u8>>)>()
202                        + k.len()
203                        + v.as_ref().map(|v| v.len()).unwrap_or(0)
204                })
205                .sum::<usize>()
206            > max_result_bytes
207        {
208            Err(Error::Other(format!(
209                "query result exceeds limit of {max_result_bytes} bytes"
210            )))
211        } else {
212            Ok(values)
213        }
214    }
215
216    async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
217        self.connection
218            .clone()
219            .mset(&key_values)
220            .await
221            .map_err(log_error)
222    }
223
224    async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
225        self.connection.clone().del(keys).await.map_err(log_error)
226    }
227
228    async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
229        self.connection
230            .clone()
231            .incr(key, delta)
232            .await
233            .map_err(log_error)
234    }
235
236    /// `new_compare_and_swap` builds a new CAS structure giving it its own connection since Redis
237    /// transactions are scoped to a connection and any WATCH should be dropped upon the drop of
238    /// the connection.
239    async fn new_compare_and_swap(
240        &self,
241        bucket_rep: u32,
242        key: &str,
243    ) -> Result<Arc<dyn Cas>, Error> {
244        let cx = Client::open(self.database_url.clone())
245            .map_err(log_error)?
246            .get_connection_manager()
247            .await
248            .map_err(log_error)?;
249
250        Ok(Arc::new(CompareAndSwap {
251            key: key.to_string(),
252            connection: cx,
253            bucket_rep,
254        }))
255    }
256}
257
258#[async_trait]
259impl Cas for CompareAndSwap {
260    /// current will initiate a transaction by WATCH'ing a key in Redis, and then returning the
261    /// current value for the key.
262    async fn current(&self, max_result_bytes: usize) -> Result<Option<Vec<u8>>, Error> {
263        redis::cmd("WATCH")
264            .arg(&self.key)
265            .exec_async(&mut self.connection.clone())
266            .await
267            .map_err(log_error)?;
268        let value = self
269            .connection
270            .clone()
271            .get::<_, Option<Vec<u8>>>(&self.key)
272            .await
273            .map_err(log_error)?;
274
275        // Currently there's no way to stream a `WATCH` result using the `redis`
276        // crate without buffering, so the damage (in terms of host memory
277        // usage) is already done, but we can still enforce the limit:
278        if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
279            > max_result_bytes
280        {
281            Err(Error::Other(format!(
282                "query result exceeds limit of {max_result_bytes} bytes"
283            )))
284        } else {
285            Ok(value)
286        }
287    }
288
289    /// swap will set the key to the new value only if the key has not changed. Afterward, the
290    /// transaction will be terminated with an UNWATCH
291    async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
292        // Create transaction pipeline
293        let mut transaction = redis::pipe();
294        let res: Result<(), RedisError> = transaction
295            .atomic()
296            .set(&self.key, value)
297            .query_async(&mut self.connection.clone())
298            .await;
299
300        redis::cmd("UNWATCH")
301            .arg(&self.key)
302            .exec_async(&mut self.connection.clone())
303            .await
304            .map_err(|err| SwapError::CasFailed(format!("{err:?}")))?;
305
306        match res {
307            Ok(_) => Ok(()),
308            Err(err) => Err(SwapError::CasFailed(format!("{err:?}"))),
309        }
310    }
311
312    async fn bucket_rep(&self) -> u32 {
313        self.bucket_rep
314    }
315
316    async fn key(&self) -> String {
317        self.key.clone()
318    }
319}