1use anyhow::Result;
2use redis::{aio::MultiplexedConnection, AsyncCommands, FromRedisValue, Value};
3use spin_core::wasmtime::component::Resource;
4use spin_factor_outbound_networking::OutboundAllowedHosts;
5use spin_world::v1::{redis as v1, redis_types};
6use spin_world::v2::redis::{
7 self as v2, Connection as RedisConnection, Error, RedisParameter, RedisResult,
8};
9use tracing::field::Empty;
10use tracing::{instrument, Level};
11
/// Per-instance host state backing the outbound Redis interfaces.
pub struct InstanceState {
    /// Allow-list consulted (with the `"redis"` scheme) before any address is
    /// connected to.
    pub allowed_hosts: OutboundAllowedHosts,
    /// Open multiplexed connections, keyed by the `rep` of the guest-visible
    /// connection resource.
    pub connections: spin_resource_table::Table<MultiplexedConnection>,
}
16
17impl InstanceState {
18 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
19 self.allowed_hosts.check_url(address, "redis").await
20 }
21
22 async fn establish_connection(
23 &mut self,
24 address: String,
25 ) -> Result<Resource<RedisConnection>, Error> {
26 let conn = redis::Client::open(address.as_str())
27 .map_err(|_| Error::InvalidAddress)?
28 .get_multiplexed_async_connection()
29 .await
30 .map_err(other_error)?;
31 self.connections
32 .push(conn)
33 .map(Resource::new_own)
34 .map_err(|_| Error::TooManyConnections)
35 }
36
37 async fn get_conn(
38 &mut self,
39 connection: Resource<RedisConnection>,
40 ) -> Result<&mut MultiplexedConnection, Error> {
41 self.connections
42 .get_mut(connection.rep())
43 .ok_or(Error::Other(
44 "could not find connection for resource".into(),
45 ))
46 }
47}
48
impl v2::Host for crate::InstanceState {
    /// Passes the v2 Redis error through unchanged.
    fn convert_error(&mut self, error: Error) -> Result<Error> {
        Ok(error)
    }
}
54
55impl v2::HostConnection for crate::InstanceState {
56 #[instrument(name = "spin_outbound_redis.open_connection", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", db.address = Empty, server.port = Empty, db.namespace = Empty))]
57 async fn open(&mut self, address: String) -> Result<Resource<RedisConnection>, Error> {
58 spin_factor_outbound_networking::record_address_fields(&address);
59
60 if !self
61 .is_address_allowed(&address)
62 .await
63 .map_err(|e| v2::Error::Other(e.to_string()))?
64 {
65 return Err(Error::InvalidAddress);
66 }
67
68 self.establish_connection(address).await
69 }
70
71 #[instrument(name = "spin_outbound_redis.publish", skip(self, connection, payload), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("PUBLISH {}", channel)))]
72 async fn publish(
73 &mut self,
74 connection: Resource<RedisConnection>,
75 channel: String,
76 payload: Vec<u8>,
77 ) -> Result<(), Error> {
78 let conn = self.get_conn(connection).await.map_err(other_error)?;
79 let () = conn
82 .publish(&channel, &payload)
83 .await
84 .map_err(other_error)?;
85 Ok(())
86 }
87
88 #[instrument(name = "spin_outbound_redis.get", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("GET {}", key)))]
89 async fn get(
90 &mut self,
91 connection: Resource<RedisConnection>,
92 key: String,
93 ) -> Result<Option<Vec<u8>>, Error> {
94 let conn = self.get_conn(connection).await.map_err(other_error)?;
95 let value = conn.get(&key).await.map_err(other_error)?;
96 Ok(value)
97 }
98
99 #[instrument(name = "spin_outbound_redis.set", skip(self, connection, value), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SET {}", key)))]
100 async fn set(
101 &mut self,
102 connection: Resource<RedisConnection>,
103 key: String,
104 value: Vec<u8>,
105 ) -> Result<(), Error> {
106 let conn = self.get_conn(connection).await.map_err(other_error)?;
107 let () = conn.set(&key, &value).await.map_err(other_error)?;
110 Ok(())
111 }
112
113 #[instrument(name = "spin_outbound_redis.incr", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("INCRBY {} 1", key)))]
114 async fn incr(
115 &mut self,
116 connection: Resource<RedisConnection>,
117 key: String,
118 ) -> Result<i64, Error> {
119 let conn = self.get_conn(connection).await.map_err(other_error)?;
120 let value = conn.incr(&key, 1).await.map_err(other_error)?;
121 Ok(value)
122 }
123
124 #[instrument(name = "spin_outbound_redis.del", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("DEL {}", keys.join(" "))))]
125 async fn del(
126 &mut self,
127 connection: Resource<RedisConnection>,
128 keys: Vec<String>,
129 ) -> Result<u32, Error> {
130 let conn = self.get_conn(connection).await.map_err(other_error)?;
131 let value = conn.del(&keys).await.map_err(other_error)?;
132 Ok(value)
133 }
134
135 #[instrument(name = "spin_outbound_redis.sadd", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SADD {} {}", key, values.join(" "))))]
136 async fn sadd(
137 &mut self,
138 connection: Resource<RedisConnection>,
139 key: String,
140 values: Vec<String>,
141 ) -> Result<u32, Error> {
142 let conn = self.get_conn(connection).await.map_err(other_error)?;
143 let value = conn.sadd(&key, &values).await.map_err(|e| {
144 if e.kind() == redis::ErrorKind::TypeError {
145 Error::TypeError
146 } else {
147 Error::Other(e.to_string())
148 }
149 })?;
150 Ok(value)
151 }
152
153 #[instrument(name = "spin_outbound_redis.smembers", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SMEMBERS {}", key)))]
154 async fn smembers(
155 &mut self,
156 connection: Resource<RedisConnection>,
157 key: String,
158 ) -> Result<Vec<String>, Error> {
159 let conn = self.get_conn(connection).await.map_err(other_error)?;
160 let value = conn.smembers(&key).await.map_err(other_error)?;
161 Ok(value)
162 }
163
164 #[instrument(name = "spin_outbound_redis.srem", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SREM {} {}", key, values.join(" "))))]
165 async fn srem(
166 &mut self,
167 connection: Resource<RedisConnection>,
168 key: String,
169 values: Vec<String>,
170 ) -> Result<u32, Error> {
171 let conn = self.get_conn(connection).await.map_err(other_error)?;
172 let value = conn.srem(&key, &values).await.map_err(other_error)?;
173 Ok(value)
174 }
175
176 #[instrument(name = "spin_outbound_redis.execute", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("{}", command)))]
177 async fn execute(
178 &mut self,
179 connection: Resource<RedisConnection>,
180 command: String,
181 arguments: Vec<RedisParameter>,
182 ) -> Result<Vec<RedisResult>, Error> {
183 let conn = self.get_conn(connection).await?;
184 let mut cmd = redis::cmd(&command);
185 arguments.iter().for_each(|value| match value {
186 RedisParameter::Int64(v) => {
187 cmd.arg(v);
188 }
189 RedisParameter::Binary(v) => {
190 cmd.arg(v);
191 }
192 });
193
194 cmd.query_async::<_, RedisResults>(conn)
195 .await
196 .map(|values| values.0)
197 .map_err(other_error)
198 }
199
200 async fn drop(&mut self, connection: Resource<RedisConnection>) -> anyhow::Result<()> {
201 self.connections.remove(connection.rep());
202 Ok(())
203 }
204}
205
/// Wraps any displayable error into `Error::Other`, carrying its rendered
/// message.
fn other_error(e: impl std::fmt::Display) -> Error {
    Error::Other(e.to_string())
}
209
/// Implements a v1 (address-per-call) operation by delegating to the
/// `v2::HostConnection` method of the same name.
///
/// The expansion:
/// 1. returns `v1::Error::Error` unless `$address` passes the allowed-hosts
///    check,
/// 2. establishes a temporary connection to `$address`,
/// 3. invokes the v2 method with that connection and the remaining arguments,
/// 4. removes the temporary connection from the table again, so repeated v1
///    calls do not accumulate entries until the table is full,
/// 5. evaluates to the v2 result with any error collapsed to
///    `v1::Error::Error`.
///
/// Must be used inside an `async fn` returning `Result<_, v1::Error>`, because
/// the expansion contains `.await`, `?` and early `return`s.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|_| v1::Error::Error)? {
            return Err(v1::Error::Error);
        }
        let connection = match $self.establish_connection($address).await {
            Ok(c) => c,
            Err(_) => return Err(v1::Error::Error),
        };
        // Remember the table key before the handle is moved into the v2 call.
        let rep = connection.rep();
        let result = <Self as v2::HostConnection>::$name($self, connection, $($arg),*)
            .await
            .map_err(|_| v1::Error::Error);
        // The connection only lives for this one call; release it on both the
        // success and the failure path.
        $self.connections.remove(rep);
        result
    }};
}
225
226impl v1::Host for crate::InstanceState {
227 async fn publish(
228 &mut self,
229 address: String,
230 channel: String,
231 payload: Vec<u8>,
232 ) -> Result<(), v1::Error> {
233 delegate!(self.publish(address, channel, payload))
234 }
235
236 async fn get(&mut self, address: String, key: String) -> Result<Vec<u8>, v1::Error> {
237 delegate!(self.get(address, key)).map(|v| v.unwrap_or_default())
238 }
239
240 async fn set(&mut self, address: String, key: String, value: Vec<u8>) -> Result<(), v1::Error> {
241 delegate!(self.set(address, key, value))
242 }
243
244 async fn incr(&mut self, address: String, key: String) -> Result<i64, v1::Error> {
245 delegate!(self.incr(address, key))
246 }
247
248 async fn del(&mut self, address: String, keys: Vec<String>) -> Result<i64, v1::Error> {
249 delegate!(self.del(address, keys)).map(|v| v as i64)
250 }
251
252 async fn sadd(
253 &mut self,
254 address: String,
255 key: String,
256 values: Vec<String>,
257 ) -> Result<i64, v1::Error> {
258 delegate!(self.sadd(address, key, values)).map(|v| v as i64)
259 }
260
261 async fn smembers(&mut self, address: String, key: String) -> Result<Vec<String>, v1::Error> {
262 delegate!(self.smembers(address, key))
263 }
264
265 async fn srem(
266 &mut self,
267 address: String,
268 key: String,
269 values: Vec<String>,
270 ) -> Result<i64, v1::Error> {
271 delegate!(self.srem(address, key, values)).map(|v| v as i64)
272 }
273
274 async fn execute(
275 &mut self,
276 address: String,
277 command: String,
278 arguments: Vec<v1::RedisParameter>,
279 ) -> Result<Vec<v1::RedisResult>, v1::Error> {
280 delegate!(self.execute(
281 address,
282 command,
283 arguments.into_iter().map(Into::into).collect()
284 ))
285 .map(|v| v.into_iter().map(Into::into).collect())
286 }
287}
288
impl redis_types::Host for crate::InstanceState {
    /// Passes the `redis_types` error through unchanged.
    fn convert_error(&mut self, error: redis_types::Error) -> Result<redis_types::Error> {
        Ok(error)
    }
}
294
/// Newtype over the flattened list of reply values, so that a
/// `FromRedisValue` decoder can be implemented for it in this file.
struct RedisResults(Vec<RedisResult>);
296
297impl FromRedisValue for RedisResults {
298 fn from_redis_value(value: &Value) -> redis::RedisResult<Self> {
299 fn append(values: &mut Vec<RedisResult>, value: &Value) {
300 match value {
301 Value::Nil | Value::Okay => (),
302 Value::Int(v) => values.push(RedisResult::Int64(*v)),
303 Value::Data(bytes) => values.push(RedisResult::Binary(bytes.to_owned())),
304 Value::Bulk(bulk) => bulk.iter().for_each(|value| append(values, value)),
305 Value::Status(message) => values.push(RedisResult::Status(message.to_owned())),
306 }
307 }
308
309 let mut values = Vec::new();
310 append(&mut values, value);
311 Ok(RedisResults(values))
312 }
313}