1use std::net::SocketAddr;
2
3use anyhow::Result;
4use redis::io::AsyncDNSResolver;
5use redis::AsyncConnectionConfig;
6use redis::{aio::MultiplexedConnection, AsyncCommands, FromRedisValue, Value};
7use spin_core::wasmtime::component::Resource;
8use spin_factor_otel::OtelFactorState;
9use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
10use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
11use spin_world::v1::{redis as v1, redis_types};
12use spin_world::v2::redis::{
13 self as v2, Connection as RedisConnection, Error, RedisParameter, RedisResult,
14};
15use spin_world::MAX_HOST_BUFFERED_BYTES;
16use tracing::field::Empty;
17use tracing::{instrument, Level};
18
19pub struct InstanceState {
20 pub allowed_hosts: OutboundAllowedHosts,
21 pub blocked_networks: BlockedNetworks,
22 pub connections: spin_resource_table::Table<MultiplexedConnection>,
23 pub otel: OtelFactorState,
24}
25
26impl InstanceState {
27 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
28 self.allowed_hosts.check_url(address, "redis").await
29 }
30
31 async fn establish_connection(
32 &mut self,
33 address: String,
34 ) -> Result<Resource<RedisConnection>, Error> {
35 let config = AsyncConnectionConfig::new()
36 .set_dns_resolver(SpinDnsResolver(self.blocked_networks.clone()));
37 let conn = redis::Client::open(address.as_str())
38 .map_err(|_| Error::InvalidAddress)?
39 .get_multiplexed_async_connection_with_config(&config)
40 .await
41 .map_err(other_error)?;
42 self.connections
43 .push(conn)
44 .map(Resource::new_own)
45 .map_err(|_| Error::TooManyConnections)
46 }
47
48 async fn get_conn(
49 &mut self,
50 connection: Resource<RedisConnection>,
51 ) -> Result<&mut MultiplexedConnection, Error> {
52 self.connections
53 .get_mut(connection.rep())
54 .ok_or(Error::Other(
55 "could not find connection for resource".into(),
56 ))
57 }
58}
59
60impl v2::Host for crate::InstanceState {
61 fn convert_error(&mut self, error: Error) -> Result<Error> {
62 Ok(error)
63 }
64}
65
66impl v2::HostConnection for crate::InstanceState {
67 #[instrument(name = "spin_outbound_redis.open_connection", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", db.address = Empty, server.port = Empty, db.namespace = Empty))]
68 async fn open(&mut self, address: String) -> Result<Resource<RedisConnection>, Error> {
69 self.otel.reparent_tracing_span();
70 if !self
71 .is_address_allowed(&address)
72 .await
73 .map_err(|e| v2::Error::Other(e.to_string()))?
74 {
75 return Err(Error::InvalidAddress);
76 }
77
78 self.establish_connection(address).await
79 }
80
81 #[instrument(name = "spin_outbound_redis.publish", skip(self, connection, payload), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("PUBLISH {}", channel)))]
82 async fn publish(
83 &mut self,
84 connection: Resource<RedisConnection>,
85 channel: String,
86 payload: Vec<u8>,
87 ) -> Result<(), Error> {
88 self.otel.reparent_tracing_span();
89
90 let conn = self.get_conn(connection).await.map_err(other_error)?;
91 let () = conn
94 .publish(&channel, &payload)
95 .await
96 .map_err(other_error)?;
97 Ok(())
98 }
99
100 #[instrument(name = "spin_outbound_redis.get", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("GET {}", key)))]
101 async fn get(
102 &mut self,
103 connection: Resource<RedisConnection>,
104 key: String,
105 ) -> Result<Option<Vec<u8>>, Error> {
106 self.otel.reparent_tracing_span();
107
108 let conn = self.get_conn(connection).await.map_err(other_error)?;
109 let value = conn
110 .get::<_, Option<Vec<u8>>>(&key)
111 .await
112 .map_err(other_error)?;
113
114 if std::mem::size_of::<Option<Vec<u8>>>() + value.as_ref().map(|v| v.len()).unwrap_or(0)
118 > MAX_HOST_BUFFERED_BYTES
119 {
120 Err(Error::Other(format!(
121 "query result exceeds limit of {MAX_HOST_BUFFERED_BYTES} bytes"
122 )))
123 } else {
124 Ok(value)
125 }
126 }
127
128 #[instrument(name = "spin_outbound_redis.set", skip(self, connection, value), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SET {}", key)))]
129 async fn set(
130 &mut self,
131 connection: Resource<RedisConnection>,
132 key: String,
133 value: Vec<u8>,
134 ) -> Result<(), Error> {
135 self.otel.reparent_tracing_span();
136
137 let conn = self.get_conn(connection).await.map_err(other_error)?;
138 let () = conn.set(&key, &value).await.map_err(other_error)?;
141 Ok(())
142 }
143
144 #[instrument(name = "spin_outbound_redis.incr", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("INCRBY {} 1", key)))]
145 async fn incr(
146 &mut self,
147 connection: Resource<RedisConnection>,
148 key: String,
149 ) -> Result<i64, Error> {
150 self.otel.reparent_tracing_span();
151
152 let conn = self.get_conn(connection).await.map_err(other_error)?;
153 let value = conn.incr(&key, 1).await.map_err(other_error)?;
154 Ok(value)
155 }
156
157 #[instrument(name = "spin_outbound_redis.del", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("DEL {}", keys.join(" "))))]
158 async fn del(
159 &mut self,
160 connection: Resource<RedisConnection>,
161 keys: Vec<String>,
162 ) -> Result<u32, Error> {
163 self.otel.reparent_tracing_span();
164
165 let conn = self.get_conn(connection).await.map_err(other_error)?;
166 let value = conn.del(&keys).await.map_err(other_error)?;
167 Ok(value)
168 }
169
170 #[instrument(name = "spin_outbound_redis.sadd", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SADD {} {}", key, values.join(" "))))]
171 async fn sadd(
172 &mut self,
173 connection: Resource<RedisConnection>,
174 key: String,
175 values: Vec<String>,
176 ) -> Result<u32, Error> {
177 self.otel.reparent_tracing_span();
178
179 let conn = self.get_conn(connection).await.map_err(other_error)?;
180 let value = conn.sadd(&key, &values).await.map_err(|e| {
181 if e.kind() == redis::ErrorKind::TypeError {
182 Error::TypeError
183 } else {
184 Error::Other(e.to_string())
185 }
186 })?;
187 Ok(value)
188 }
189
190 #[instrument(name = "spin_outbound_redis.smembers", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SMEMBERS {}", key)))]
191 async fn smembers(
192 &mut self,
193 connection: Resource<RedisConnection>,
194 key: String,
195 ) -> Result<Vec<String>, Error> {
196 self.otel.reparent_tracing_span();
197
198 let conn = self.get_conn(connection).await.map_err(other_error)?;
199 let value = conn.smembers(&key).await.map_err(other_error)?;
200 Ok(value)
201 }
202
203 #[instrument(name = "spin_outbound_redis.srem", skip(self, connection, values), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("SREM {} {}", key, values.join(" "))))]
204 async fn srem(
205 &mut self,
206 connection: Resource<RedisConnection>,
207 key: String,
208 values: Vec<String>,
209 ) -> Result<u32, Error> {
210 self.otel.reparent_tracing_span();
211
212 let conn = self.get_conn(connection).await.map_err(other_error)?;
213 let value = conn.srem(&key, &values).await.map_err(other_error)?;
214 Ok(value)
215 }
216
217 #[instrument(name = "spin_outbound_redis.execute", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "redis", otel.name = format!("{}", command)))]
218 async fn execute(
219 &mut self,
220 connection: Resource<RedisConnection>,
221 command: String,
222 arguments: Vec<RedisParameter>,
223 ) -> Result<Vec<RedisResult>, Error> {
224 self.otel.reparent_tracing_span();
225
226 let conn = self.get_conn(connection).await?;
227 let mut cmd = redis::cmd(&command);
228 arguments.iter().for_each(|value| match value {
229 RedisParameter::Int64(v) => {
230 cmd.arg(v);
231 }
232 RedisParameter::Binary(v) => {
233 cmd.arg(v);
234 }
235 });
236
237 let results = cmd
238 .query_async::<RedisResults>(conn)
239 .await
240 .map(|values| values.0)
241 .map_err(other_error)?;
242
243 if std::mem::size_of::<Vec<RedisResult>>() + results.iter().map(memory_size).sum::<usize>()
247 > MAX_HOST_BUFFERED_BYTES
248 {
249 Err(Error::Other(format!(
250 "query result exceeds limit of {MAX_HOST_BUFFERED_BYTES} bytes"
251 )))
252 } else {
253 Ok(results)
254 }
255 }
256
257 async fn drop(&mut self, connection: Resource<RedisConnection>) -> anyhow::Result<()> {
258 self.connections.remove(connection.rep());
259 Ok(())
260 }
261}
262
263fn other_error(e: impl std::fmt::Display) -> Error {
264 Error::Other(e.to_string())
265}
266
/// Implements a v1 (address-per-call) operation on top of the v2 connection API.
///
/// The macro proceeds in four steps:
/// 1. Checks `$address` against the allowed hosts.
/// 2. Opens a fresh connection.
/// 3. Invokes `v2::HostConnection::$name` with the remaining arguments.
/// 4. Removes the connection from the table again.
///
/// Every failure is reported as `v1::Error::Error`.
macro_rules! delegate {
    ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
        if !$self.is_address_allowed(&$address).await.map_err(|_| v1::Error::Error)? {
            return Err(v1::Error::Error);
        }
        let connection = match $self.establish_connection($address).await {
            Ok(c) => c,
            Err(_) => return Err(v1::Error::Error),
        };
        // Capture the table key before the resource handle is moved into the v2 call.
        let rep = connection.rep();
        let result = <Self as v2::HostConnection>::$name($self, connection, $($arg),*).await;
        // The per-call connection is never reused by v1. Release it on both success and
        // error; otherwise every v1 call leaves an open connection in the table for the
        // rest of the instance's lifetime and eventually exhausts the table.
        $self.connections.remove(rep);
        result.map_err(|_| v1::Error::Error)
    }};
}
282
283impl v1::Host for crate::InstanceState {
284 async fn publish(
285 &mut self,
286 address: String,
287 channel: String,
288 payload: Vec<u8>,
289 ) -> Result<(), v1::Error> {
290 delegate!(self.publish(address, channel, payload))
291 }
292
293 async fn get(&mut self, address: String, key: String) -> Result<Vec<u8>, v1::Error> {
294 delegate!(self.get(address, key)).map(|v| v.unwrap_or_default())
295 }
296
297 async fn set(&mut self, address: String, key: String, value: Vec<u8>) -> Result<(), v1::Error> {
298 delegate!(self.set(address, key, value))
299 }
300
301 async fn incr(&mut self, address: String, key: String) -> Result<i64, v1::Error> {
302 delegate!(self.incr(address, key))
303 }
304
305 async fn del(&mut self, address: String, keys: Vec<String>) -> Result<i64, v1::Error> {
306 delegate!(self.del(address, keys)).map(|v| v as i64)
307 }
308
309 async fn sadd(
310 &mut self,
311 address: String,
312 key: String,
313 values: Vec<String>,
314 ) -> Result<i64, v1::Error> {
315 delegate!(self.sadd(address, key, values)).map(|v| v as i64)
316 }
317
318 async fn smembers(&mut self, address: String, key: String) -> Result<Vec<String>, v1::Error> {
319 delegate!(self.smembers(address, key))
320 }
321
322 async fn srem(
323 &mut self,
324 address: String,
325 key: String,
326 values: Vec<String>,
327 ) -> Result<i64, v1::Error> {
328 delegate!(self.srem(address, key, values)).map(|v| v as i64)
329 }
330
331 async fn execute(
332 &mut self,
333 address: String,
334 command: String,
335 arguments: Vec<v1::RedisParameter>,
336 ) -> Result<Vec<v1::RedisResult>, v1::Error> {
337 delegate!(self.execute(
338 address,
339 command,
340 arguments.into_iter().map(Into::into).collect()
341 ))
342 .map(|v| v.into_iter().map(Into::into).collect())
343 }
344}
345
346impl redis_types::Host for crate::InstanceState {
347 fn convert_error(&mut self, error: redis_types::Error) -> Result<redis_types::Error> {
348 Ok(error)
349 }
350}
351
352struct RedisResults(Vec<RedisResult>);
353
354impl FromRedisValue for RedisResults {
355 fn from_redis_value(value: &Value) -> redis::RedisResult<Self> {
356 fn append(values: &mut Vec<RedisResult>, value: &Value) -> redis::RedisResult<()> {
357 match value {
358 Value::Nil => {
359 values.push(RedisResult::Nil);
360 Ok(())
361 }
362 Value::Int(v) => {
363 values.push(RedisResult::Int64(*v));
364 Ok(())
365 }
366 Value::BulkString(bytes) => {
367 values.push(RedisResult::Binary(bytes.to_owned()));
368 Ok(())
369 }
370 Value::SimpleString(s) => {
371 values.push(RedisResult::Status(s.to_owned()));
372 Ok(())
373 }
374 Value::Okay => {
375 values.push(RedisResult::Status("OK".to_string()));
376 Ok(())
377 }
378 Value::Map(_) => Err(redis::RedisError::from((
379 redis::ErrorKind::TypeError,
380 "Could not convert Redis response",
381 "Redis Map type is not supported".to_string(),
382 ))),
383 Value::Attribute { .. } => Err(redis::RedisError::from((
384 redis::ErrorKind::TypeError,
385 "Could not convert Redis response",
386 "Redis Attribute type is not supported".to_string(),
387 ))),
388 Value::Array(arr) | Value::Set(arr) => {
389 arr.iter().try_for_each(|value| append(values, value))
390 }
391 Value::Double(v) => {
392 values.push(RedisResult::Binary(v.to_string().into_bytes()));
393 Ok(())
394 }
395 Value::VerbatimString { .. } => Err(redis::RedisError::from((
396 redis::ErrorKind::TypeError,
397 "Could not convert Redis response",
398 "Redis string with format attribute is not supported".to_string(),
399 ))),
400 Value::Boolean(v) => {
401 values.push(RedisResult::Int64(if *v { 1 } else { 0 }));
402 Ok(())
403 }
404 Value::BigNumber(v) => {
405 values.push(RedisResult::Binary(v.to_string().as_bytes().to_owned()));
406 Ok(())
407 }
408 Value::Push { .. } => Err(redis::RedisError::from((
409 redis::ErrorKind::TypeError,
410 "Could not convert Redis response",
411 "Redis Pub/Sub types are not supported".to_string(),
412 ))),
413 Value::ServerError(err) => Err(redis::RedisError::from((
414 redis::ErrorKind::ResponseError,
415 "Server error",
416 format!("{err:?}"),
417 ))),
418 }
419 }
420 let mut values = Vec::new();
421 append(&mut values, value)?;
422 Ok(RedisResults(values))
423 }
424}
425
426fn memory_size(value: &RedisResult) -> usize {
427 match value {
428 RedisResult::Nil | RedisResult::Int64(_) => std::mem::size_of::<RedisResult>(),
429 RedisResult::Binary(b) => std::mem::size_of::<RedisResult>() + b.len(),
430 RedisResult::Status(s) => std::mem::size_of::<RedisResult>() + s.len(),
431 }
432}
433
434struct SpinDnsResolver(BlockedNetworks);
436
437impl AsyncDNSResolver for SpinDnsResolver {
438 fn resolve<'a, 'b: 'a>(
439 &'a self,
440 host: &'b str,
441 port: u16,
442 ) -> redis::RedisFuture<'a, Box<dyn Iterator<Item = std::net::SocketAddr> + Send + 'a>> {
443 Box::pin(async move {
444 let mut addrs = tokio::net::lookup_host((host, port))
445 .await?
446 .collect::<Vec<_>>();
447 let blocked_addrs = self.0.remove_blocked(&mut addrs);
449 if addrs.is_empty() && !blocked_addrs.is_empty() {
450 tracing::error!(
451 "error.type" = "destination_ip_prohibited",
452 ?blocked_addrs,
453 "all destination IP(s) prohibited by runtime config"
454 );
455 }
456 Ok(Box::new(addrs.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
457 })
458 }
459}