spin_key_value_redis/
store.rs1use anyhow::{Context, Result};
2use redis::{aio::ConnectionManager, parse_redis_url, AsyncCommands, Client, RedisError};
3use spin_core::async_trait;
4use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError};
5use std::sync::Arc;
6use tokio::sync::OnceCell;
7use url::Url;
8
9pub struct KeyValueRedis {
10 database_url: Url,
11 connection: OnceCell<ConnectionManager>,
12}
13
14impl KeyValueRedis {
15 pub fn new(address: String) -> Result<Self> {
16 let database_url = parse_redis_url(&address).context("Invalid Redis URL")?;
17
18 Ok(Self {
19 database_url,
20 connection: OnceCell::new(),
21 })
22 }
23}
24
25#[async_trait]
26impl StoreManager for KeyValueRedis {
27 async fn get(&self, _name: &str) -> Result<Arc<dyn Store>, Error> {
28 let connection = self
29 .connection
30 .get_or_try_init(|| async {
31 Client::open(self.database_url.clone())?
32 .get_connection_manager()
33 .await
34 })
35 .await
36 .map_err(log_error)?;
37
38 Ok(Arc::new(RedisStore {
39 connection: connection.clone(),
40 database_url: self.database_url.clone(),
41 }))
42 }
43
44 fn is_defined(&self, _store_name: &str) -> bool {
45 true
46 }
47
48 fn summary(&self, _store_name: &str) -> Option<String> {
49 let redis::ConnectionInfo { addr, .. } = self.database_url.as_str().parse().ok()?;
50 Some(format!("Redis at {addr}"))
51 }
52}
53
54struct RedisStore {
55 connection: ConnectionManager,
56 database_url: Url,
57}
58
59struct CompareAndSwap {
60 key: String,
61 connection: ConnectionManager,
62 bucket_rep: u32,
63}
64
65#[async_trait]
66impl Store for RedisStore {
67 async fn after_open(&self) -> Result<(), Error> {
68 if let Err(_error) = self.connection.clone().ping::<()>().await {
69 }
72 Ok(())
73 }
74
75 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
76 self.connection.clone().get(key).await.map_err(log_error)
77 }
78
79 async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
80 self.connection
81 .clone()
82 .set(key, value)
83 .await
84 .map_err(log_error)
85 }
86
87 async fn delete(&self, key: &str) -> Result<(), Error> {
88 self.connection.clone().del(key).await.map_err(log_error)
89 }
90
91 async fn exists(&self, key: &str) -> Result<bool, Error> {
92 self.connection.clone().exists(key).await.map_err(log_error)
93 }
94
95 async fn get_keys(&self) -> Result<Vec<String>, Error> {
96 self.connection.clone().keys("*").await.map_err(log_error)
97 }
98
99 async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
100 self.connection.clone().keys(keys).await.map_err(log_error)
101 }
102
103 async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
104 self.connection
105 .clone()
106 .mset(&key_values)
107 .await
108 .map_err(log_error)
109 }
110
111 async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
112 self.connection.clone().del(keys).await.map_err(log_error)
113 }
114
115 async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
116 self.connection
117 .clone()
118 .incr(key, delta)
119 .await
120 .map_err(log_error)
121 }
122
123 async fn new_compare_and_swap(
127 &self,
128 bucket_rep: u32,
129 key: &str,
130 ) -> Result<Arc<dyn Cas>, Error> {
131 let cx = Client::open(self.database_url.clone())
132 .map_err(log_error)?
133 .get_connection_manager()
134 .await
135 .map_err(log_error)?;
136
137 Ok(Arc::new(CompareAndSwap {
138 key: key.to_string(),
139 connection: cx,
140 bucket_rep,
141 }))
142 }
143}
144
145#[async_trait]
146impl Cas for CompareAndSwap {
147 async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
150 redis::cmd("WATCH")
151 .arg(&self.key)
152 .exec_async(&mut self.connection.clone())
153 .await
154 .map_err(log_error)?;
155 self.connection
156 .clone()
157 .get(&self.key)
158 .await
159 .map_err(log_error)
160 }
161
162 async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
165 let mut transaction = redis::pipe();
167 let res: Result<(), RedisError> = transaction
168 .atomic()
169 .set(&self.key, value)
170 .query_async(&mut self.connection.clone())
171 .await;
172
173 redis::cmd("UNWATCH")
174 .arg(&self.key)
175 .exec_async(&mut self.connection.clone())
176 .await
177 .map_err(|err| SwapError::CasFailed(format!("{err:?}")))?;
178
179 match res {
180 Ok(_) => Ok(()),
181 Err(err) => Err(SwapError::CasFailed(format!("{err:?}"))),
182 }
183 }
184
185 async fn bucket_rep(&self) -> u32 {
186 self.bucket_rep
187 }
188
189 async fn key(&self) -> String {
190 self.key.clone()
191 }
192}