Skip to main content

spin_factor_outbound_redis/
host.rs

1use std::net::SocketAddr;
2
3use anyhow::Result;
4use redis::io::AsyncDNSResolver;
5use redis::AsyncConnectionConfig;
6use redis::{aio::MultiplexedConnection, AsyncCommands, FromRedisValue, Value};
7use spin_core::wasmtime::component::Resource;
8use spin_factor_otel::OtelFactorState;
9use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
10use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
11use spin_world::v1::{redis as v1, redis_types};
12use spin_world::v2::redis::{
13    self as v2, Connection as RedisConnection, Error, RedisParameter, RedisResult,
14};
15use spin_world::MAX_HOST_BUFFERED_BYTES;
16use tracing::field::Empty;
17use tracing::{instrument, Level};
18
19pub struct InstanceState {
20    pub allowed_hosts: OutboundAllowedHosts,
21    pub blocked_networks: BlockedNetworks,
22    pub connections: spin_resource_table::Table<MultiplexedConnection>,
23    pub otel: OtelFactorState,
24}
25
26impl InstanceState {
27    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
28        self.allowed_hosts.check_url(address, "redis").await
29    }
30
31    async fn establish_connection(
32        &mut self,
33        address: String,
34    ) -> Result<Resource<RedisConnection>, Error> {
35        let config = AsyncConnectionConfig::new()
36            .set_dns_resolver(SpinDnsResolver(self.blocked_networks.clone()));
37        let conn = redis::Client::open(address.as_str())
38            .map_err(|_| Error::InvalidAddress)?
39            .get_multiplexed_async_connection_with_config(&config)
40            .await
41            .map_err(other_error)?;
42        self.connections
43            .push(conn)
44            .map(Resource::new_own)
45            .map_err(|_| Error::TooManyConnections)
46    }
47
48    async fn get_conn(
49        &mut self,
50        connection: Resource<RedisConnection>,
51    ) -> Result<&mut MultiplexedConnection, Error> {
52        self.connections
53            .get_mut(connection.rep())
54            .ok_or(Error::Other(
55                "could not find connection for resource".into(),
56            ))
57    }
58}
59
60impl v2::Host for crate::InstanceState {
61    fn convert_error(&mut self, error: Error) -> Result<Error> {
62        Ok(error)
63    }
64}
65
66impl v2::HostConnection for crate::InstanceState {
67    #[instrument(name = "spin_outbound_redis.open_connection", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", db.address = Empty, server.port = Empty, db.namespace = Empty))]
68    async fn open(&mut self, address: String) -> Result<Resource<RedisConnection>, Error> {
69        self.otel.reparent_tracing_span();
70        if !self
71            .is_address_allowed(&address)
72            .await
73            .map_err(|e| v2::Error::Other(e.to_string()))?
74        {
75            return Err(Error::InvalidAddress);
76        }
77
78        self.establish_connection(address).await
79    }
80
81    #[instrument(name = "spin_outbound_redis.publish", skip(self, connection, payload), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("PUBLISH {}", channel)))]
82    async fn publish(
83        &mut self,
84        connection: Resource<RedisConnection>,
85        channel: String,
86        payload: Vec<u8>,
87    ) -> Result<(), Error> {
88        self.otel.reparent_tracing_span();
89
90        let conn = self.get_conn(connection).await.map_err(other_error)?;
91        // The `let () =` syntax is needed to suppress a warning when the result type is inferred.
92        // You can read more about the issue here: <https://github.com/redis-rs/redis-rs/issues/1228>
93        let () = conn
94            .publish(&channel, &payload)
95            .await
96            .map_err(other_error)?;
97        Ok(())
98    }
99
100    #[instrument(name = "spin_outbound_redis.get", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("GET {}", key)))]
101    async fn get(
102        &mut self,
103        connection: Resource<RedisConnection>,
104        key: String,
105    ) -> Result<Option<Vec<u8>>, Error> {
106        self.otel.reparent_tracing_span();
107
108        let conn = self.get_conn(connection).await.map_err(other_error)?;
109        let value = conn
110            .get::<_, Option<Vec<u8>>>(&key)
111            .await
112            .map_err(other_error)?;
113
114        // Currently there's no way to stream a `GET` result using the `redis`
115        // crate without buffering, so the damage (in terms of host memory
116        // usage) is already done, but we can still enforce the limit:
117        if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
118            > MAX_HOST_BUFFERED_BYTES
119        {
120            Err(Error::Other(format!(
121                "query result exceeds limit of {MAX_HOST_BUFFERED_BYTES} bytes"
122            )))
123        } else {
124            Ok(value)
125        }
126    }
127
128    #[instrument(name = "spin_outbound_redis.set", skip(self, connection, value), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SET {}", key)))]
129    async fn set(
130        &mut self,
131        connection: Resource<RedisConnection>,
132        key: String,
133        value: Vec<u8>,
134    ) -> Result<(), Error> {
135        self.otel.reparent_tracing_span();
136
137        let conn = self.get_conn(connection).await.map_err(other_error)?;
138        // The `let () =` syntax is needed to suppress a warning when the result type is inferred.
139        // You can read more about the issue here: <https://github.com/redis-rs/redis-rs/issues/1228>
140        let () = conn.set(&key, &value).await.map_err(other_error)?;
141        Ok(())
142    }
143
144    #[instrument(name = "spin_outbound_redis.incr", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("INCRBY {} 1", key)))]
145    async fn incr(
146        &mut self,
147        connection: Resource<RedisConnection>,
148        key: String,
149    ) -> Result<i64, Error> {
150        self.otel.reparent_tracing_span();
151
152        let conn = self.get_conn(connection).await.map_err(other_error)?;
153        let value = conn.incr(&key, 1).await.map_err(other_error)?;
154        Ok(value)
155    }
156
157    #[instrument(name = "spin_outbound_redis.del", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("DEL {}", keys.join(" "))))]
158    async fn del(
159        &mut self,
160        connection: Resource<RedisConnection>,
161        keys: Vec<String>,
162    ) -> Result<u32, Error> {
163        self.otel.reparent_tracing_span();
164
165        let conn = self.get_conn(connection).await.map_err(other_error)?;
166        let value = conn.del(&keys).await.map_err(other_error)?;
167        Ok(value)
168    }
169
170    #[instrument(name = "spin_outbound_redis.sadd", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SADD {} {}", key, values.join(" "))))]
171    async fn sadd(
172        &mut self,
173        connection: Resource<RedisConnection>,
174        key: String,
175        values: Vec<String>,
176    ) -> Result<u32, Error> {
177        self.otel.reparent_tracing_span();
178
179        let conn = self.get_conn(connection).await.map_err(other_error)?;
180        let value = conn.sadd(&key, &values).await.map_err(|e| {
181            if e.kind() == redis::ErrorKind::TypeError {
182                Error::TypeError
183            } else {
184                Error::Other(e.to_string())
185            }
186        })?;
187        Ok(value)
188    }
189
190    #[instrument(name = "spin_outbound_redis.smembers", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SMEMBERS {}", key)))]
191    async fn smembers(
192        &mut self,
193        connection: Resource<RedisConnection>,
194        key: String,
195    ) -> Result<Vec<String>, Error> {
196        self.otel.reparent_tracing_span();
197
198        let conn = self.get_conn(connection).await.map_err(other_error)?;
199        let value = conn.smembers(&key).await.map_err(other_error)?;
200        Ok(value)
201    }
202
203    #[instrument(name = "spin_outbound_redis.srem", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SREM {} {}", key, values.join(" "))))]
204    async fn srem(
205        &mut self,
206        connection: Resource<RedisConnection>,
207        key: String,
208        values: Vec<String>,
209    ) -> Result<u32, Error> {
210        self.otel.reparent_tracing_span();
211
212        let conn = self.get_conn(connection).await.map_err(other_error)?;
213        let value = conn.srem(&key, &values).await.map_err(other_error)?;
214        Ok(value)
215    }
216
217    #[instrument(name = "spin_outbound_redis.execute", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("{}", command)))]
218    async fn execute(
219        &mut self,
220        connection: Resource<RedisConnection>,
221        command: String,
222        arguments: Vec<RedisParameter>,
223    ) -> Result<Vec<RedisResult>, Error> {
224        self.otel.reparent_tracing_span();
225
226        let conn = self.get_conn(connection).await?;
227        let mut cmd = redis::cmd(&command);
228        arguments.iter().for_each(|value| match value {
229            RedisParameter::Int64(v) => {
230                cmd.arg(v);
231            }
232            RedisParameter::Binary(v) => {
233                cmd.arg(v);
234            }
235        });
236
237        let results = cmd
238            .query_async::<RedisResults>(conn)
239            .await
240            .map(|values| values.0)
241            .map_err(other_error)?;
242
243        // Currently there's no way to stream results using the `redis`
244        // crate without buffering, so the damage (in terms of host memory
245        // usage) is already done, but we can still enforce the limit:
246        if std::mem::size_of::<Vec<RedisResult>>() + results.iter().map(memory_size).sum::<usize>()
247            > MAX_HOST_BUFFERED_BYTES
248        {
249            Err(Error::Other(format!(
250                "query result exceeds limit of {MAX_HOST_BUFFERED_BYTES} bytes"
251            )))
252        } else {
253            Ok(results)
254        }
255    }
256
257    async fn drop(&mut self, connection: Resource<RedisConnection>) -> anyhow::Result<()> {
258        self.connections.remove(connection.rep());
259        Ok(())
260    }
261}
262
263fn other_error(e: impl std::fmt::Display) -> Error {
264    Error::Other(e.to_string())
265}
266
/// Delegate a v1 (address-per-call) function call to the v2::HostConnection
/// implementation.
///
/// Each expansion does four things:
/// 1. Checks `$address` against the allowed hosts.
/// 2. Opens a fresh connection to it.
/// 3. Runs the matching v2 method on that connection.
/// 4. Removes the connection from the `connections` table again.
///
/// Every failure is collapsed into `v1::Error::Error`.
///
/// The v2 methods only look the connection up and never remove it. Without the
/// explicit removal, every v1 call would leave its `MultiplexedConnection` in
/// the table, eventually causing `TooManyConnections`.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|_| v1::Error::Error)? {
            return Err(v1::Error::Error);
        }
        let connection = $self
            .establish_connection($address)
            .await
            .map_err(|_| v1::Error::Error)?;
        // Remember the table slot before `connection` is moved into the call.
        let rep = connection.rep();
        let result = <Self as v2::HostConnection>::$name($self, connection, $($arg),*).await;
        // v1 connections are single-use: release the table entry (dropping the
        // stored connection handle) whether or not the call succeeded.
        $self.connections.remove(rep);
        result.map_err(|_| v1::Error::Error)
    }};
}
282
283impl v1::Host for crate::InstanceState {
284    async fn publish(
285        &mut self,
286        address: String,
287        channel: String,
288        payload: Vec<u8>,
289    ) -> Result<(), v1::Error> {
290        delegate!(self.publish(address, channel, payload))
291    }
292
293    async fn get(&mut self, address: String, key: String) -> Result<Vec<u8>, v1::Error> {
294        delegate!(self.get(address, key)).map(|v| v.unwrap_or_default())
295    }
296
297    async fn set(&mut self, address: String, key: String, value: Vec<u8>) -> Result<(), v1::Error> {
298        delegate!(self.set(address, key, value))
299    }
300
301    async fn incr(&mut self, address: String, key: String) -> Result<i64, v1::Error> {
302        delegate!(self.incr(address, key))
303    }
304
305    async fn del(&mut self, address: String, keys: Vec<String>) -> Result<i64, v1::Error> {
306        delegate!(self.del(address, keys)).map(|v| v as i64)
307    }
308
309    async fn sadd(
310        &mut self,
311        address: String,
312        key: String,
313        values: Vec<String>,
314    ) -> Result<i64, v1::Error> {
315        delegate!(self.sadd(address, key, values)).map(|v| v as i64)
316    }
317
318    async fn smembers(&mut self, address: String, key: String) -> Result<Vec<String>, v1::Error> {
319        delegate!(self.smembers(address, key))
320    }
321
322    async fn srem(
323        &mut self,
324        address: String,
325        key: String,
326        values: Vec<String>,
327    ) -> Result<i64, v1::Error> {
328        delegate!(self.srem(address, key, values)).map(|v| v as i64)
329    }
330
331    async fn execute(
332        &mut self,
333        address: String,
334        command: String,
335        arguments: Vec<v1::RedisParameter>,
336    ) -> Result<Vec<v1::RedisResult>, v1::Error> {
337        delegate!(self.execute(
338            address,
339            command,
340            arguments.into_iter().map(Into::into).collect()
341        ))
342        .map(|v| v.into_iter().map(Into::into).collect())
343    }
344}
345
346impl redis_types::Host for crate::InstanceState {
347    fn convert_error(&mut self, error: redis_types::Error) -> Result<redis_types::Error> {
348        Ok(error)
349    }
350}
351
352struct RedisResults(Vec<RedisResult>);
353
354impl FromRedisValue for RedisResults {
355    fn from_redis_value(value: &Value) -> redis::RedisResult<Self> {
356        fn append(values: &mut Vec<RedisResult>, value: &Value) -> redis::RedisResult<()> {
357            match value {
358                Value::Nil => {
359                    values.push(RedisResult::Nil);
360                    Ok(())
361                }
362                Value::Int(v) => {
363                    values.push(RedisResult::Int64(*v));
364                    Ok(())
365                }
366                Value::BulkString(bytes) => {
367                    values.push(RedisResult::Binary(bytes.to_owned()));
368                    Ok(())
369                }
370                Value::SimpleString(s) => {
371                    values.push(RedisResult::Status(s.to_owned()));
372                    Ok(())
373                }
374                Value::Okay => {
375                    values.push(RedisResult::Status("OK".to_string()));
376                    Ok(())
377                }
378                Value::Map(_) => Err(redis::RedisError::from((
379                    redis::ErrorKind::TypeError,
380                    "Could not convert Redis response",
381                    "Redis Map type is not supported".to_string(),
382                ))),
383                Value::Attribute { .. } => Err(redis::RedisError::from((
384                    redis::ErrorKind::TypeError,
385                    "Could not convert Redis response",
386                    "Redis Attribute type is not supported".to_string(),
387                ))),
388                Value::Array(arr) | Value::Set(arr) => {
389                    arr.iter().try_for_each(|value| append(values, value))
390                }
391                Value::Double(v) => {
392                    values.push(RedisResult::Binary(v.to_string().into_bytes()));
393                    Ok(())
394                }
395                Value::VerbatimString { .. } => Err(redis::RedisError::from((
396                    redis::ErrorKind::TypeError,
397                    "Could not convert Redis response",
398                    "Redis string with format attribute is not supported".to_string(),
399                ))),
400                Value::Boolean(v) => {
401                    values.push(RedisResult::Int64(if *v { 1 } else { 0 }));
402                    Ok(())
403                }
404                Value::BigNumber(v) => {
405                    values.push(RedisResult::Binary(v.to_string().as_bytes().to_owned()));
406                    Ok(())
407                }
408                Value::Push { .. } => Err(redis::RedisError::from((
409                    redis::ErrorKind::TypeError,
410                    "Could not convert Redis response",
411                    "Redis Pub/Sub types are not supported".to_string(),
412                ))),
413                Value::ServerError(err) => Err(redis::RedisError::from((
414                    redis::ErrorKind::ResponseError,
415                    "Server error",
416                    format!("{err:?}"),
417                ))),
418            }
419        }
420        let mut values = Vec::new();
421        append(&mut values, value)?;
422        Ok(RedisResults(values))
423    }
424}
425
426fn memory_size(value: &RedisResult) -> usize {
427    match value {
428        RedisResult::Nil | RedisResult::Int64(_) => std::mem::size_of::<RedisResult>(),
429        RedisResult::Binary(b) => std::mem::size_of::<RedisResult>() + b.len(),
430        RedisResult::Status(s) => std::mem::size_of::<RedisResult>() + s.len(),
431    }
432}
433
434/// Resolves DNS using Tokio's resolver, filtering out blocked IPs.
435struct SpinDnsResolver(BlockedNetworks);
436
437impl AsyncDNSResolver for SpinDnsResolver {
438    fn resolve<'a, 'b: 'a>(
439        &'a self,
440        host: &'b str,
441        port: u16,
442    ) -> redis::RedisFuture<'a, Box<dyn Iterator<Item = std::net::SocketAddr> + Send + 'a>> {
443        Box::pin(async move {
444            let mut addrs = tokio::net::lookup_host((host, port))
445                .await?
446                .collect::<Vec<_>>();
447            // Remove blocked IPs
448            let blocked_addrs = self.0.remove_blocked(&mut addrs);
449            if addrs.is_empty() && !blocked_addrs.is_empty() {
450                tracing::error!(
451                    "error.type" = "destination_ip_prohibited",
452                    ?blocked_addrs,
453                    "all destination IP(s) prohibited by runtime config"
454                );
455            }
456            Ok(Box::new(addrs.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
457        })
458    }
459}