spin_factor_outbound_redis/
host.rs

1use anyhow::Result;
2use redis::{aio::MultiplexedConnection, AsyncCommands, FromRedisValue, Value};
3use spin_core::wasmtime::component::Resource;
4use spin_factor_outbound_networking::OutboundAllowedHosts;
5use spin_world::v1::{redis as v1, redis_types};
6use spin_world::v2::redis::{
7    self as v2, Connection as RedisConnection, Error, RedisParameter, RedisResult,
8};
9use tracing::field::Empty;
10use tracing::{instrument, Level};
11
12pub struct InstanceState {
13    pub allowed_hosts: OutboundAllowedHosts,
14    pub connections: spin_resource_table::Table<MultiplexedConnection>,
15}
16
17impl InstanceState {
18    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
19        self.allowed_hosts.check_url(address, "redis").await
20    }
21
22    async fn establish_connection(
23        &mut self,
24        address: String,
25    ) -> Result<Resource<RedisConnection>, Error> {
26        let conn = redis::Client::open(address.as_str())
27            .map_err(|_| Error::InvalidAddress)?
28            .get_multiplexed_async_connection()
29            .await
30            .map_err(other_error)?;
31        self.connections
32            .push(conn)
33            .map(Resource::new_own)
34            .map_err(|_| Error::TooManyConnections)
35    }
36
37    async fn get_conn(
38        &mut self,
39        connection: Resource<RedisConnection>,
40    ) -> Result<&mut MultiplexedConnection, Error> {
41        self.connections
42            .get_mut(connection.rep())
43            .ok_or(Error::Other(
44                "could not find connection for resource".into(),
45            ))
46    }
47}
48
49impl v2::Host for crate::InstanceState {
50    fn convert_error(&mut self, error: Error) -> Result<Error> {
51        Ok(error)
52    }
53}
54
55impl v2::HostConnection for crate::InstanceState {
56    #[instrument(name = "spin_outbound_redis.open_connection", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", db.address = Empty, server.port = Empty, db.namespace = Empty))]
57    async fn open(&mut self, address: String) -> Result<Resource<RedisConnection>, Error> {
58        spin_factor_outbound_networking::record_address_fields(&address);
59
60        if !self
61            .is_address_allowed(&address)
62            .await
63            .map_err(|e| v2::Error::Other(e.to_string()))?
64        {
65            return Err(Error::InvalidAddress);
66        }
67
68        self.establish_connection(address).await
69    }
70
71    #[instrument(name = "spin_outbound_redis.publish", skip(self, connection, payload), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("PUBLISH {}", channel)))]
72    async fn publish(
73        &mut self,
74        connection: Resource<RedisConnection>,
75        channel: String,
76        payload: Vec<u8>,
77    ) -> Result<(), Error> {
78        let conn = self.get_conn(connection).await.map_err(other_error)?;
79        // The `let () =` syntax is needed to suppress a warning when the result type is inferred.
80        // You can read more about the issue here: <https://github.com/redis-rs/redis-rs/issues/1228>
81        let () = conn
82            .publish(&channel, &payload)
83            .await
84            .map_err(other_error)?;
85        Ok(())
86    }
87
88    #[instrument(name = "spin_outbound_redis.get", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("GET {}", key)))]
89    async fn get(
90        &mut self,
91        connection: Resource<RedisConnection>,
92        key: String,
93    ) -> Result<Option<Vec<u8>>, Error> {
94        let conn = self.get_conn(connection).await.map_err(other_error)?;
95        let value = conn.get(&key).await.map_err(other_error)?;
96        Ok(value)
97    }
98
99    #[instrument(name = "spin_outbound_redis.set", skip(self, connection, value), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SET {}", key)))]
100    async fn set(
101        &mut self,
102        connection: Resource<RedisConnection>,
103        key: String,
104        value: Vec<u8>,
105    ) -> Result<(), Error> {
106        let conn = self.get_conn(connection).await.map_err(other_error)?;
107        // The `let () =` syntax is needed to suppress a warning when the result type is inferred.
108        // You can read more about the issue here: <https://github.com/redis-rs/redis-rs/issues/1228>
109        let () = conn.set(&key, &value).await.map_err(other_error)?;
110        Ok(())
111    }
112
113    #[instrument(name = "spin_outbound_redis.incr", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("INCRBY {} 1", key)))]
114    async fn incr(
115        &mut self,
116        connection: Resource<RedisConnection>,
117        key: String,
118    ) -> Result<i64, Error> {
119        let conn = self.get_conn(connection).await.map_err(other_error)?;
120        let value = conn.incr(&key, 1).await.map_err(other_error)?;
121        Ok(value)
122    }
123
124    #[instrument(name = "spin_outbound_redis.del", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("DEL {}", keys.join(" "))))]
125    async fn del(
126        &mut self,
127        connection: Resource<RedisConnection>,
128        keys: Vec<String>,
129    ) -> Result<u32, Error> {
130        let conn = self.get_conn(connection).await.map_err(other_error)?;
131        let value = conn.del(&keys).await.map_err(other_error)?;
132        Ok(value)
133    }
134
135    #[instrument(name = "spin_outbound_redis.sadd", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SADD {} {}", key, values.join(" "))))]
136    async fn sadd(
137        &mut self,
138        connection: Resource<RedisConnection>,
139        key: String,
140        values: Vec<String>,
141    ) -> Result<u32, Error> {
142        let conn = self.get_conn(connection).await.map_err(other_error)?;
143        let value = conn.sadd(&key, &values).await.map_err(|e| {
144            if e.kind() == redis::ErrorKind::TypeError {
145                Error::TypeError
146            } else {
147                Error::Other(e.to_string())
148            }
149        })?;
150        Ok(value)
151    }
152
153    #[instrument(name = "spin_outbound_redis.smembers", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SMEMBERS {}", key)))]
154    async fn smembers(
155        &mut self,
156        connection: Resource<RedisConnection>,
157        key: String,
158    ) -> Result<Vec<String>, Error> {
159        let conn = self.get_conn(connection).await.map_err(other_error)?;
160        let value = conn.smembers(&key).await.map_err(other_error)?;
161        Ok(value)
162    }
163
164    #[instrument(name = "spin_outbound_redis.srem", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SREM {} {}", key, values.join(" "))))]
165    async fn srem(
166        &mut self,
167        connection: Resource<RedisConnection>,
168        key: String,
169        values: Vec<String>,
170    ) -> Result<u32, Error> {
171        let conn = self.get_conn(connection).await.map_err(other_error)?;
172        let value = conn.srem(&key, &values).await.map_err(other_error)?;
173        Ok(value)
174    }
175
176    #[instrument(name = "spin_outbound_redis.execute", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("{}", command)))]
177    async fn execute(
178        &mut self,
179        connection: Resource<RedisConnection>,
180        command: String,
181        arguments: Vec<RedisParameter>,
182    ) -> Result<Vec<RedisResult>, Error> {
183        let conn = self.get_conn(connection).await?;
184        let mut cmd = redis::cmd(&command);
185        arguments.iter().for_each(|value| match value {
186            RedisParameter::Int64(v) => {
187                cmd.arg(v);
188            }
189            RedisParameter::Binary(v) => {
190                cmd.arg(v);
191            }
192        });
193
194        cmd.query_async::<_, RedisResults>(conn)
195            .await
196            .map(|values| values.0)
197            .map_err(other_error)
198    }
199
200    async fn drop(&mut self, connection: Resource<RedisConnection>) -> anyhow::Result<()> {
201        self.connections.remove(connection.rep());
202        Ok(())
203    }
204}
205
206fn other_error(e: impl std::fmt::Display) -> Error {
207    Error::Other(e.to_string())
208}
209
/// Delegate a function call to the v2::HostConnection implementation.
///
/// Expands, inside an `async fn` returning `Result<_, v1::Error>`, to:
///
/// 1. an allow-list check on `$address` (returns `v1::Error::Error` from the
///    enclosing function if the check fails or the address is not allowed),
/// 2. opening a connection for this single call,
/// 3. invoking the v2 method `$name` with that connection and `$arg`s,
/// 4. removing the connection from the table again.
///
/// The expression evaluates to the v2 result with any error collapsed into
/// `v1::Error::Error`.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|_| v1::Error::Error)?  {
            return Err(v1::Error::Error);
        }
        let connection = match $self.establish_connection($address).await {
            Ok(c) => c,
            Err(_) => return Err(v1::Error::Error),
        };
        // Record the table slot before the handle is moved into the v2 call.
        let rep = connection.rep();
        let result = <Self as v2::HostConnection>::$name($self, connection, $($arg),*)
            .await
            .map_err(|_| v1::Error::Error);
        // The v1 API never hands the guest a handle it could drop, so release
        // the slot here; otherwise every v1 call would leave a connection in
        // the table.
        $self.connections.remove(rep);
        result
    }};
}
225
226impl v1::Host for crate::InstanceState {
227    async fn publish(
228        &mut self,
229        address: String,
230        channel: String,
231        payload: Vec<u8>,
232    ) -> Result<(), v1::Error> {
233        delegate!(self.publish(address, channel, payload))
234    }
235
236    async fn get(&mut self, address: String, key: String) -> Result<Vec<u8>, v1::Error> {
237        delegate!(self.get(address, key)).map(|v| v.unwrap_or_default())
238    }
239
240    async fn set(&mut self, address: String, key: String, value: Vec<u8>) -> Result<(), v1::Error> {
241        delegate!(self.set(address, key, value))
242    }
243
244    async fn incr(&mut self, address: String, key: String) -> Result<i64, v1::Error> {
245        delegate!(self.incr(address, key))
246    }
247
248    async fn del(&mut self, address: String, keys: Vec<String>) -> Result<i64, v1::Error> {
249        delegate!(self.del(address, keys)).map(|v| v as i64)
250    }
251
252    async fn sadd(
253        &mut self,
254        address: String,
255        key: String,
256        values: Vec<String>,
257    ) -> Result<i64, v1::Error> {
258        delegate!(self.sadd(address, key, values)).map(|v| v as i64)
259    }
260
261    async fn smembers(&mut self, address: String, key: String) -> Result<Vec<String>, v1::Error> {
262        delegate!(self.smembers(address, key))
263    }
264
265    async fn srem(
266        &mut self,
267        address: String,
268        key: String,
269        values: Vec<String>,
270    ) -> Result<i64, v1::Error> {
271        delegate!(self.srem(address, key, values)).map(|v| v as i64)
272    }
273
274    async fn execute(
275        &mut self,
276        address: String,
277        command: String,
278        arguments: Vec<v1::RedisParameter>,
279    ) -> Result<Vec<v1::RedisResult>, v1::Error> {
280        delegate!(self.execute(
281            address,
282            command,
283            arguments.into_iter().map(Into::into).collect()
284        ))
285        .map(|v| v.into_iter().map(Into::into).collect())
286    }
287}
288
289impl redis_types::Host for crate::InstanceState {
290    fn convert_error(&mut self, error: redis_types::Error) -> Result<redis_types::Error> {
291        Ok(error)
292    }
293}
294
295struct RedisResults(Vec<RedisResult>);
296
297impl FromRedisValue for RedisResults {
298    fn from_redis_value(value: &Value) -> redis::RedisResult<Self> {
299        fn append(values: &mut Vec<RedisResult>, value: &Value) {
300            match value {
301                Value::Nil | Value::Okay => (),
302                Value::Int(v) => values.push(RedisResult::Int64(*v)),
303                Value::Data(bytes) => values.push(RedisResult::Binary(bytes.to_owned())),
304                Value::Bulk(bulk) => bulk.iter().for_each(|value| append(values, value)),
305                Value::Status(message) => values.push(RedisResult::Status(message.to_owned())),
306            }
307        }
308
309        let mut values = Vec::new();
310        append(&mut values, value);
311        Ok(RedisResults(values))
312    }
313}