spin_factor_outbound_redis/
host.rs

1use std::net::SocketAddr;
2
3use anyhow::Result;
4use redis::io::AsyncDNSResolver;
5use redis::AsyncConnectionConfig;
6use redis::{aio::MultiplexedConnection, AsyncCommands, FromRedisValue, Value};
7use spin_core::wasmtime::component::Resource;
8use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
9use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
10use spin_world::v1::{redis as v1, redis_types};
11use spin_world::v2::redis::{
12    self as v2, Connection as RedisConnection, Error, RedisParameter, RedisResult,
13};
14use tracing::field::Empty;
15use tracing::{instrument, Level};
16
17pub struct InstanceState {
18    pub allowed_hosts: OutboundAllowedHosts,
19    pub blocked_networks: BlockedNetworks,
20    pub connections: spin_resource_table::Table<MultiplexedConnection>,
21}
22
23impl InstanceState {
24    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
25        self.allowed_hosts.check_url(address, "redis").await
26    }
27
28    async fn establish_connection(
29        &mut self,
30        address: String,
31    ) -> Result<Resource<RedisConnection>, Error> {
32        let config = AsyncConnectionConfig::new()
33            .set_dns_resolver(SpinDnsResolver(self.blocked_networks.clone()));
34        let conn = redis::Client::open(address.as_str())
35            .map_err(|_| Error::InvalidAddress)?
36            .get_multiplexed_async_connection_with_config(&config)
37            .await
38            .map_err(other_error)?;
39        self.connections
40            .push(conn)
41            .map(Resource::new_own)
42            .map_err(|_| Error::TooManyConnections)
43    }
44
45    async fn get_conn(
46        &mut self,
47        connection: Resource<RedisConnection>,
48    ) -> Result<&mut MultiplexedConnection, Error> {
49        self.connections
50            .get_mut(connection.rep())
51            .ok_or(Error::Other(
52                "could not find connection for resource".into(),
53            ))
54    }
55}
56
57impl v2::Host for crate::InstanceState {
58    fn convert_error(&mut self, error: Error) -> Result<Error> {
59        Ok(error)
60    }
61}
62
63impl v2::HostConnection for crate::InstanceState {
64    #[instrument(name = "spin_outbound_redis.open_connection", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", db.address = Empty, server.port = Empty, db.namespace = Empty))]
65    async fn open(&mut self, address: String) -> Result<Resource<RedisConnection>, Error> {
66        spin_factor_outbound_networking::record_address_fields(&address);
67
68        if !self
69            .is_address_allowed(&address)
70            .await
71            .map_err(|e| v2::Error::Other(e.to_string()))?
72        {
73            return Err(Error::InvalidAddress);
74        }
75
76        self.establish_connection(address).await
77    }
78
79    #[instrument(name = "spin_outbound_redis.publish", skip(self, connection, payload), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("PUBLISH {}", channel)))]
80    async fn publish(
81        &mut self,
82        connection: Resource<RedisConnection>,
83        channel: String,
84        payload: Vec<u8>,
85    ) -> Result<(), Error> {
86        let conn = self.get_conn(connection).await.map_err(other_error)?;
87        // The `let () =` syntax is needed to suppress a warning when the result type is inferred.
88        // You can read more about the issue here: <https://github.com/redis-rs/redis-rs/issues/1228>
89        let () = conn
90            .publish(&channel, &payload)
91            .await
92            .map_err(other_error)?;
93        Ok(())
94    }
95
96    #[instrument(name = "spin_outbound_redis.get", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("GET {}", key)))]
97    async fn get(
98        &mut self,
99        connection: Resource<RedisConnection>,
100        key: String,
101    ) -> Result<Option<Vec<u8>>, Error> {
102        let conn = self.get_conn(connection).await.map_err(other_error)?;
103        let value = conn.get(&key).await.map_err(other_error)?;
104        Ok(value)
105    }
106
107    #[instrument(name = "spin_outbound_redis.set", skip(self, connection, value), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SET {}", key)))]
108    async fn set(
109        &mut self,
110        connection: Resource<RedisConnection>,
111        key: String,
112        value: Vec<u8>,
113    ) -> Result<(), Error> {
114        let conn = self.get_conn(connection).await.map_err(other_error)?;
115        // The `let () =` syntax is needed to suppress a warning when the result type is inferred.
116        // You can read more about the issue here: <https://github.com/redis-rs/redis-rs/issues/1228>
117        let () = conn.set(&key, &value).await.map_err(other_error)?;
118        Ok(())
119    }
120
121    #[instrument(name = "spin_outbound_redis.incr", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("INCRBY {} 1", key)))]
122    async fn incr(
123        &mut self,
124        connection: Resource<RedisConnection>,
125        key: String,
126    ) -> Result<i64, Error> {
127        let conn = self.get_conn(connection).await.map_err(other_error)?;
128        let value = conn.incr(&key, 1).await.map_err(other_error)?;
129        Ok(value)
130    }
131
132    #[instrument(name = "spin_outbound_redis.del", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("DEL {}", keys.join(" "))))]
133    async fn del(
134        &mut self,
135        connection: Resource<RedisConnection>,
136        keys: Vec<String>,
137    ) -> Result<u32, Error> {
138        let conn = self.get_conn(connection).await.map_err(other_error)?;
139        let value = conn.del(&keys).await.map_err(other_error)?;
140        Ok(value)
141    }
142
143    #[instrument(name = "spin_outbound_redis.sadd", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SADD {} {}", key, values.join(" "))))]
144    async fn sadd(
145        &mut self,
146        connection: Resource<RedisConnection>,
147        key: String,
148        values: Vec<String>,
149    ) -> Result<u32, Error> {
150        let conn = self.get_conn(connection).await.map_err(other_error)?;
151        let value = conn.sadd(&key, &values).await.map_err(|e| {
152            if e.kind() == redis::ErrorKind::TypeError {
153                Error::TypeError
154            } else {
155                Error::Other(e.to_string())
156            }
157        })?;
158        Ok(value)
159    }
160
161    #[instrument(name = "spin_outbound_redis.smembers", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SMEMBERS {}", key)))]
162    async fn smembers(
163        &mut self,
164        connection: Resource<RedisConnection>,
165        key: String,
166    ) -> Result<Vec<String>, Error> {
167        let conn = self.get_conn(connection).await.map_err(other_error)?;
168        let value = conn.smembers(&key).await.map_err(other_error)?;
169        Ok(value)
170    }
171
172    #[instrument(name = "spin_outbound_redis.srem", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SREM {} {}", key, values.join(" "))))]
173    async fn srem(
174        &mut self,
175        connection: Resource<RedisConnection>,
176        key: String,
177        values: Vec<String>,
178    ) -> Result<u32, Error> {
179        let conn = self.get_conn(connection).await.map_err(other_error)?;
180        let value = conn.srem(&key, &values).await.map_err(other_error)?;
181        Ok(value)
182    }
183
184    #[instrument(name = "spin_outbound_redis.execute", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("{}", command)))]
185    async fn execute(
186        &mut self,
187        connection: Resource<RedisConnection>,
188        command: String,
189        arguments: Vec<RedisParameter>,
190    ) -> Result<Vec<RedisResult>, Error> {
191        let conn = self.get_conn(connection).await?;
192        let mut cmd = redis::cmd(&command);
193        arguments.iter().for_each(|value| match value {
194            RedisParameter::Int64(v) => {
195                cmd.arg(v);
196            }
197            RedisParameter::Binary(v) => {
198                cmd.arg(v);
199            }
200        });
201
202        cmd.query_async::<RedisResults>(conn)
203            .await
204            .map(|values| values.0)
205            .map_err(other_error)
206    }
207
208    async fn drop(&mut self, connection: Resource<RedisConnection>) -> anyhow::Result<()> {
209        self.connections.remove(connection.rep());
210        Ok(())
211    }
212}
213
214fn other_error(e: impl std::fmt::Display) -> Error {
215    Error::Other(e.to_string())
216}
217
/// Delegate a connectionless v1 call to the v2::HostConnection implementation.
///
/// The expansion:
/// 1. fails with `v1::Error::Error` unless `$address` is an allowed outbound host,
/// 2. opens a fresh connection to `$address` (registered in `$self.connections`),
/// 3. invokes the v2 method `$name` with that connection and the remaining arguments,
/// 4. removes the connection from `$self.connections` again, whether the call succeeded or not.
///
/// Step 4 matters because v1 gives the guest no handle to drop. Without it, every v1 call would
/// leave its connection in the table for the lifetime of the instance.
/// Every failure is reported as `v1::Error::Error`.
///
/// `$address` is evaluated twice (borrowed, then moved), so pass a plain binding.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|_| v1::Error::Error)? {
            return Err(v1::Error::Error);
        }
        let connection = match $self.establish_connection($address).await {
            Ok(c) => c,
            Err(_) => return Err(v1::Error::Error),
        };
        // Capture the table key before `connection` is moved into the v2 call so the
        // one-shot connection can be released afterwards.
        let rep = connection.rep();
        let result = <Self as v2::HostConnection>::$name($self, connection, $($arg),*).await;
        $self.connections.remove(rep);
        result.map_err(|_| v1::Error::Error)
    }};
}
233
234impl v1::Host for crate::InstanceState {
235    async fn publish(
236        &mut self,
237        address: String,
238        channel: String,
239        payload: Vec<u8>,
240    ) -> Result<(), v1::Error> {
241        delegate!(self.publish(address, channel, payload))
242    }
243
244    async fn get(&mut self, address: String, key: String) -> Result<Vec<u8>, v1::Error> {
245        delegate!(self.get(address, key)).map(|v| v.unwrap_or_default())
246    }
247
248    async fn set(&mut self, address: String, key: String, value: Vec<u8>) -> Result<(), v1::Error> {
249        delegate!(self.set(address, key, value))
250    }
251
252    async fn incr(&mut self, address: String, key: String) -> Result<i64, v1::Error> {
253        delegate!(self.incr(address, key))
254    }
255
256    async fn del(&mut self, address: String, keys: Vec<String>) -> Result<i64, v1::Error> {
257        delegate!(self.del(address, keys)).map(|v| v as i64)
258    }
259
260    async fn sadd(
261        &mut self,
262        address: String,
263        key: String,
264        values: Vec<String>,
265    ) -> Result<i64, v1::Error> {
266        delegate!(self.sadd(address, key, values)).map(|v| v as i64)
267    }
268
269    async fn smembers(&mut self, address: String, key: String) -> Result<Vec<String>, v1::Error> {
270        delegate!(self.smembers(address, key))
271    }
272
273    async fn srem(
274        &mut self,
275        address: String,
276        key: String,
277        values: Vec<String>,
278    ) -> Result<i64, v1::Error> {
279        delegate!(self.srem(address, key, values)).map(|v| v as i64)
280    }
281
282    async fn execute(
283        &mut self,
284        address: String,
285        command: String,
286        arguments: Vec<v1::RedisParameter>,
287    ) -> Result<Vec<v1::RedisResult>, v1::Error> {
288        delegate!(self.execute(
289            address,
290            command,
291            arguments.into_iter().map(Into::into).collect()
292        ))
293        .map(|v| v.into_iter().map(Into::into).collect())
294    }
295}
296
297impl redis_types::Host for crate::InstanceState {
298    fn convert_error(&mut self, error: redis_types::Error) -> Result<redis_types::Error> {
299        Ok(error)
300    }
301}
302
303struct RedisResults(Vec<RedisResult>);
304
305impl FromRedisValue for RedisResults {
306    fn from_redis_value(value: &Value) -> redis::RedisResult<Self> {
307        fn append(values: &mut Vec<RedisResult>, value: &Value) -> redis::RedisResult<()> {
308            match value {
309                Value::Nil => {
310                    values.push(RedisResult::Nil);
311                    Ok(())
312                }
313                Value::Int(v) => {
314                    values.push(RedisResult::Int64(*v));
315                    Ok(())
316                }
317                Value::BulkString(bytes) => {
318                    values.push(RedisResult::Binary(bytes.to_owned()));
319                    Ok(())
320                }
321                Value::SimpleString(s) => {
322                    values.push(RedisResult::Status(s.to_owned()));
323                    Ok(())
324                }
325                Value::Okay => {
326                    values.push(RedisResult::Status("OK".to_string()));
327                    Ok(())
328                }
329                Value::Map(_) => Err(redis::RedisError::from((
330                    redis::ErrorKind::TypeError,
331                    "Could not convert Redis response",
332                    "Redis Map type is not supported".to_string(),
333                ))),
334                Value::Attribute { .. } => Err(redis::RedisError::from((
335                    redis::ErrorKind::TypeError,
336                    "Could not convert Redis response",
337                    "Redis Attribute type is not supported".to_string(),
338                ))),
339                Value::Array(arr) | Value::Set(arr) => {
340                    arr.iter().try_for_each(|value| append(values, value))
341                }
342                Value::Double(v) => {
343                    values.push(RedisResult::Binary(v.to_string().into_bytes()));
344                    Ok(())
345                }
346                Value::VerbatimString { .. } => Err(redis::RedisError::from((
347                    redis::ErrorKind::TypeError,
348                    "Could not convert Redis response",
349                    "Redis string with format attribute is not supported".to_string(),
350                ))),
351                Value::Boolean(v) => {
352                    values.push(RedisResult::Int64(if *v { 1 } else { 0 }));
353                    Ok(())
354                }
355                Value::BigNumber(v) => {
356                    values.push(RedisResult::Binary(v.to_string().as_bytes().to_owned()));
357                    Ok(())
358                }
359                Value::Push { .. } => Err(redis::RedisError::from((
360                    redis::ErrorKind::TypeError,
361                    "Could not convert Redis response",
362                    "Redis Pub/Sub types are not supported".to_string(),
363                ))),
364                Value::ServerError(err) => Err(redis::RedisError::from((
365                    redis::ErrorKind::ResponseError,
366                    "Server error",
367                    format!("{err:?}"),
368                ))),
369            }
370        }
371        let mut values = Vec::new();
372        append(&mut values, value)?;
373        Ok(RedisResults(values))
374    }
375}
376
377/// Resolves DNS using Tokio's resolver, filtering out blocked IPs.
378struct SpinDnsResolver(BlockedNetworks);
379
380impl AsyncDNSResolver for SpinDnsResolver {
381    fn resolve<'a, 'b: 'a>(
382        &'a self,
383        host: &'b str,
384        port: u16,
385    ) -> redis::RedisFuture<'a, Box<dyn Iterator<Item = std::net::SocketAddr> + Send + 'a>> {
386        Box::pin(async move {
387            let mut addrs = tokio::net::lookup_host((host, port))
388                .await?
389                .collect::<Vec<_>>();
390            // Remove blocked IPs
391            let blocked_addrs = self.0.remove_blocked(&mut addrs);
392            if addrs.is_empty() && !blocked_addrs.is_empty() {
393                tracing::error!(
394                    "error.type" = "destination_ip_prohibited",
395                    ?blocked_addrs,
396                    "all destination IP(s) prohibited by runtime config"
397                );
398            }
399            Ok(Box::new(addrs.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
400        })
401    }
402}