1use std::net::SocketAddr;
2
3use anyhow::Result;
4use redis::io::AsyncDNSResolver;
5use redis::AsyncConnectionConfig;
6use redis::{aio::MultiplexedConnection, AsyncCommands, FromRedisValue, Value};
7use spin_core::wasmtime::component::Resource;
8use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
9use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
10use spin_world::v1::{redis as v1, redis_types};
11use spin_world::v2::redis::{
12 self as v2, Connection as RedisConnection, Error, RedisParameter, RedisResult,
13};
14use tracing::field::Empty;
15use tracing::{instrument, Level};
16
17pub struct InstanceState {
18 pub allowed_hosts: OutboundAllowedHosts,
19 pub blocked_networks: BlockedNetworks,
20 pub connections: spin_resource_table::Table<MultiplexedConnection>,
21}
22
23impl InstanceState {
24 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
25 self.allowed_hosts.check_url(address, "redis").await
26 }
27
28 async fn establish_connection(
29 &mut self,
30 address: String,
31 ) -> Result<Resource<RedisConnection>, Error> {
32 let config = AsyncConnectionConfig::new()
33 .set_dns_resolver(SpinDnsResolver(self.blocked_networks.clone()));
34 let conn = redis::Client::open(address.as_str())
35 .map_err(|_| Error::InvalidAddress)?
36 .get_multiplexed_async_connection_with_config(&config)
37 .await
38 .map_err(other_error)?;
39 self.connections
40 .push(conn)
41 .map(Resource::new_own)
42 .map_err(|_| Error::TooManyConnections)
43 }
44
45 async fn get_conn(
46 &mut self,
47 connection: Resource<RedisConnection>,
48 ) -> Result<&mut MultiplexedConnection, Error> {
49 self.connections
50 .get_mut(connection.rep())
51 .ok_or(Error::Other(
52 "could not find connection for resource".into(),
53 ))
54 }
55}
56
57impl v2::Host for crate::InstanceState {
58 fn convert_error(&mut self, error: Error) -> Result<Error> {
59 Ok(error)
60 }
61}
62
63impl v2::HostConnection for crate::InstanceState {
64 #[instrument(name = "spin_outbound_redis.open_connection", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", db.address = Empty, server.port = Empty, db.namespace = Empty))]
65 async fn open(&mut self, address: String) -> Result<Resource<RedisConnection>, Error> {
66 spin_factor_outbound_networking::record_address_fields(&address);
67
68 if !self
69 .is_address_allowed(&address)
70 .await
71 .map_err(|e| v2::Error::Other(e.to_string()))?
72 {
73 return Err(Error::InvalidAddress);
74 }
75
76 self.establish_connection(address).await
77 }
78
79 #[instrument(name = "spin_outbound_redis.publish", skip(self, connection, payload), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("PUBLISH {}", channel)))]
80 async fn publish(
81 &mut self,
82 connection: Resource<RedisConnection>,
83 channel: String,
84 payload: Vec<u8>,
85 ) -> Result<(), Error> {
86 let conn = self.get_conn(connection).await.map_err(other_error)?;
87 let () = conn
90 .publish(&channel, &payload)
91 .await
92 .map_err(other_error)?;
93 Ok(())
94 }
95
96 #[instrument(name = "spin_outbound_redis.get", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("GET {}", key)))]
97 async fn get(
98 &mut self,
99 connection: Resource<RedisConnection>,
100 key: String,
101 ) -> Result<Option<Vec<u8>>, Error> {
102 let conn = self.get_conn(connection).await.map_err(other_error)?;
103 let value = conn.get(&key).await.map_err(other_error)?;
104 Ok(value)
105 }
106
107 #[instrument(name = "spin_outbound_redis.set", skip(self, connection, value), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SET {}", key)))]
108 async fn set(
109 &mut self,
110 connection: Resource<RedisConnection>,
111 key: String,
112 value: Vec<u8>,
113 ) -> Result<(), Error> {
114 let conn = self.get_conn(connection).await.map_err(other_error)?;
115 let () = conn.set(&key, &value).await.map_err(other_error)?;
118 Ok(())
119 }
120
121 #[instrument(name = "spin_outbound_redis.incr", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("INCRBY {} 1", key)))]
122 async fn incr(
123 &mut self,
124 connection: Resource<RedisConnection>,
125 key: String,
126 ) -> Result<i64, Error> {
127 let conn = self.get_conn(connection).await.map_err(other_error)?;
128 let value = conn.incr(&key, 1).await.map_err(other_error)?;
129 Ok(value)
130 }
131
132 #[instrument(name = "spin_outbound_redis.del", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("DEL {}", keys.join(" "))))]
133 async fn del(
134 &mut self,
135 connection: Resource<RedisConnection>,
136 keys: Vec<String>,
137 ) -> Result<u32, Error> {
138 let conn = self.get_conn(connection).await.map_err(other_error)?;
139 let value = conn.del(&keys).await.map_err(other_error)?;
140 Ok(value)
141 }
142
143 #[instrument(name = "spin_outbound_redis.sadd", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SADD {} {}", key, values.join(" "))))]
144 async fn sadd(
145 &mut self,
146 connection: Resource<RedisConnection>,
147 key: String,
148 values: Vec<String>,
149 ) -> Result<u32, Error> {
150 let conn = self.get_conn(connection).await.map_err(other_error)?;
151 let value = conn.sadd(&key, &values).await.map_err(|e| {
152 if e.kind() == redis::ErrorKind::TypeError {
153 Error::TypeError
154 } else {
155 Error::Other(e.to_string())
156 }
157 })?;
158 Ok(value)
159 }
160
161 #[instrument(name = "spin_outbound_redis.smembers", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SMEMBERS {}", key)))]
162 async fn smembers(
163 &mut self,
164 connection: Resource<RedisConnection>,
165 key: String,
166 ) -> Result<Vec<String>, Error> {
167 let conn = self.get_conn(connection).await.map_err(other_error)?;
168 let value = conn.smembers(&key).await.map_err(other_error)?;
169 Ok(value)
170 }
171
172 #[instrument(name = "spin_outbound_redis.srem", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SREM {} {}", key, values.join(" "))))]
173 async fn srem(
174 &mut self,
175 connection: Resource<RedisConnection>,
176 key: String,
177 values: Vec<String>,
178 ) -> Result<u32, Error> {
179 let conn = self.get_conn(connection).await.map_err(other_error)?;
180 let value = conn.srem(&key, &values).await.map_err(other_error)?;
181 Ok(value)
182 }
183
184 #[instrument(name = "spin_outbound_redis.execute", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("{}", command)))]
185 async fn execute(
186 &mut self,
187 connection: Resource<RedisConnection>,
188 command: String,
189 arguments: Vec<RedisParameter>,
190 ) -> Result<Vec<RedisResult>, Error> {
191 let conn = self.get_conn(connection).await?;
192 let mut cmd = redis::cmd(&command);
193 arguments.iter().for_each(|value| match value {
194 RedisParameter::Int64(v) => {
195 cmd.arg(v);
196 }
197 RedisParameter::Binary(v) => {
198 cmd.arg(v);
199 }
200 });
201
202 cmd.query_async::<RedisResults>(conn)
203 .await
204 .map(|values| values.0)
205 .map_err(other_error)
206 }
207
208 async fn drop(&mut self, connection: Resource<RedisConnection>) -> anyhow::Result<()> {
209 self.connections.remove(connection.rep());
210 Ok(())
211 }
212}
213
214fn other_error(e: impl std::fmt::Display) -> Error {
215 Error::Other(e.to_string())
216}
217
/// Delegates an address-based v1 call to the matching `v2::HostConnection` method.
///
/// The expansion checks the address against the allow-list, opens a connection,
/// runs the v2 method on it, and then removes that connection from the table.
/// v1 callers never receive a resource handle, so without the removal every v1
/// call would leave one more entry in the table until it refused new connections.
///
/// Every failure is collapsed into `v1::Error::Error`.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|_| v1::Error::Error)? {
            return Err(v1::Error::Error);
        }
        let connection = match $self.establish_connection($address).await {
            Ok(c) => c,
            Err(_) => return Err(v1::Error::Error),
        };
        // Capture the table key before the handle is moved into the v2 call.
        let rep = connection.rep();
        let result = <Self as v2::HostConnection>::$name(&mut *$self, connection, $($arg),*)
            .await
            .map_err(|_| v1::Error::Error);
        // Release the one-shot connection regardless of how the call went.
        $self.connections.remove(rep);
        result
    }};
}
233
234impl v1::Host for crate::InstanceState {
235 async fn publish(
236 &mut self,
237 address: String,
238 channel: String,
239 payload: Vec<u8>,
240 ) -> Result<(), v1::Error> {
241 delegate!(self.publish(address, channel, payload))
242 }
243
244 async fn get(&mut self, address: String, key: String) -> Result<Vec<u8>, v1::Error> {
245 delegate!(self.get(address, key)).map(|v| v.unwrap_or_default())
246 }
247
248 async fn set(&mut self, address: String, key: String, value: Vec<u8>) -> Result<(), v1::Error> {
249 delegate!(self.set(address, key, value))
250 }
251
252 async fn incr(&mut self, address: String, key: String) -> Result<i64, v1::Error> {
253 delegate!(self.incr(address, key))
254 }
255
256 async fn del(&mut self, address: String, keys: Vec<String>) -> Result<i64, v1::Error> {
257 delegate!(self.del(address, keys)).map(|v| v as i64)
258 }
259
260 async fn sadd(
261 &mut self,
262 address: String,
263 key: String,
264 values: Vec<String>,
265 ) -> Result<i64, v1::Error> {
266 delegate!(self.sadd(address, key, values)).map(|v| v as i64)
267 }
268
269 async fn smembers(&mut self, address: String, key: String) -> Result<Vec<String>, v1::Error> {
270 delegate!(self.smembers(address, key))
271 }
272
273 async fn srem(
274 &mut self,
275 address: String,
276 key: String,
277 values: Vec<String>,
278 ) -> Result<i64, v1::Error> {
279 delegate!(self.srem(address, key, values)).map(|v| v as i64)
280 }
281
282 async fn execute(
283 &mut self,
284 address: String,
285 command: String,
286 arguments: Vec<v1::RedisParameter>,
287 ) -> Result<Vec<v1::RedisResult>, v1::Error> {
288 delegate!(self.execute(
289 address,
290 command,
291 arguments.into_iter().map(Into::into).collect()
292 ))
293 .map(|v| v.into_iter().map(Into::into).collect())
294 }
295}
296
297impl redis_types::Host for crate::InstanceState {
298 fn convert_error(&mut self, error: redis_types::Error) -> Result<redis_types::Error> {
299 Ok(error)
300 }
301}
302
303struct RedisResults(Vec<RedisResult>);
304
305impl FromRedisValue for RedisResults {
306 fn from_redis_value(value: &Value) -> redis::RedisResult<Self> {
307 fn append(values: &mut Vec<RedisResult>, value: &Value) -> redis::RedisResult<()> {
308 match value {
309 Value::Nil => {
310 values.push(RedisResult::Nil);
311 Ok(())
312 }
313 Value::Int(v) => {
314 values.push(RedisResult::Int64(*v));
315 Ok(())
316 }
317 Value::BulkString(bytes) => {
318 values.push(RedisResult::Binary(bytes.to_owned()));
319 Ok(())
320 }
321 Value::SimpleString(s) => {
322 values.push(RedisResult::Status(s.to_owned()));
323 Ok(())
324 }
325 Value::Okay => {
326 values.push(RedisResult::Status("OK".to_string()));
327 Ok(())
328 }
329 Value::Map(_) => Err(redis::RedisError::from((
330 redis::ErrorKind::TypeError,
331 "Could not convert Redis response",
332 "Redis Map type is not supported".to_string(),
333 ))),
334 Value::Attribute { .. } => Err(redis::RedisError::from((
335 redis::ErrorKind::TypeError,
336 "Could not convert Redis response",
337 "Redis Attribute type is not supported".to_string(),
338 ))),
339 Value::Array(arr) | Value::Set(arr) => {
340 arr.iter().try_for_each(|value| append(values, value))
341 }
342 Value::Double(v) => {
343 values.push(RedisResult::Binary(v.to_string().into_bytes()));
344 Ok(())
345 }
346 Value::VerbatimString { .. } => Err(redis::RedisError::from((
347 redis::ErrorKind::TypeError,
348 "Could not convert Redis response",
349 "Redis string with format attribute is not supported".to_string(),
350 ))),
351 Value::Boolean(v) => {
352 values.push(RedisResult::Int64(if *v { 1 } else { 0 }));
353 Ok(())
354 }
355 Value::BigNumber(v) => {
356 values.push(RedisResult::Binary(v.to_string().as_bytes().to_owned()));
357 Ok(())
358 }
359 Value::Push { .. } => Err(redis::RedisError::from((
360 redis::ErrorKind::TypeError,
361 "Could not convert Redis response",
362 "Redis Pub/Sub types are not supported".to_string(),
363 ))),
364 Value::ServerError(err) => Err(redis::RedisError::from((
365 redis::ErrorKind::ResponseError,
366 "Server error",
367 format!("{err:?}"),
368 ))),
369 }
370 }
371 let mut values = Vec::new();
372 append(&mut values, value)?;
373 Ok(RedisResults(values))
374 }
375}
376
377struct SpinDnsResolver(BlockedNetworks);
379
380impl AsyncDNSResolver for SpinDnsResolver {
381 fn resolve<'a, 'b: 'a>(
382 &'a self,
383 host: &'b str,
384 port: u16,
385 ) -> redis::RedisFuture<'a, Box<dyn Iterator<Item = std::net::SocketAddr> + Send + 'a>> {
386 Box::pin(async move {
387 let mut addrs = tokio::net::lookup_host((host, port))
388 .await?
389 .collect::<Vec<_>>();
390 let blocked_addrs = self.0.remove_blocked(&mut addrs);
392 if addrs.is_empty() && !blocked_addrs.is_empty() {
393 tracing::error!(
394 "error.type" = "destination_ip_prohibited",
395 ?blocked_addrs,
396 "all destination IP(s) prohibited by runtime config"
397 );
398 }
399 Ok(Box::new(addrs.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
400 })
401 }
402}