spin_factor_outbound_mqtt/
host.rs1use std::{sync::Arc, time::Duration};
2
3use anyhow::Result;
4use spin_core::{async_trait, wasmtime::component::Resource};
5use spin_factor_otel::OtelFactorState;
6use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
7use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos};
8use tracing::{instrument, Level};
9
10use crate::ClientCreator;
11
12pub struct InstanceState {
13 allowed_hosts: OutboundAllowedHosts,
14 connections: spin_resource_table::Table<Arc<dyn MqttClient>>,
15 create_client: Arc<dyn ClientCreator>,
16 otel: OtelFactorState,
17}
18
19impl InstanceState {
20 pub fn new(
21 allowed_hosts: OutboundAllowedHosts,
22 create_client: Arc<dyn ClientCreator>,
23 otel: OtelFactorState,
24 ) -> Self {
25 Self {
26 allowed_hosts,
27 create_client,
28 connections: spin_resource_table::Table::new(1024),
29 otel,
30 }
31 }
32}
33
34#[async_trait]
35pub trait MqttClient: Send + Sync {
36 async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error>;
37}
38
39impl InstanceState {
40 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
41 self.allowed_hosts.check_url(address, "mqtt").await
42 }
43
44 async fn establish_connection(
45 &mut self,
46 address: String,
47 username: String,
48 password: String,
49 keep_alive_interval: Duration,
50 ) -> Result<Resource<Connection>, Error> {
51 self.connections
52 .push((self.create_client).create(address, username, password, keep_alive_interval)?)
53 .map(Resource::new_own)
54 .map_err(|_| Error::TooManyConnections)
55 }
56
57 async fn get_conn(&self, connection: Resource<Connection>) -> Result<&dyn MqttClient, Error> {
58 self.connections
59 .get(connection.rep())
60 .ok_or(Error::Other(
61 "could not find connection for resource".into(),
62 ))
63 .map(|c| c.as_ref())
64 }
65}
66
67impl v2::Host for InstanceState {
68 fn convert_error(&mut self, error: Error) -> Result<Error> {
69 Ok(error)
70 }
71}
72
73impl v2::HostConnection for InstanceState {
74 #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
75 async fn open(
76 &mut self,
77 address: String,
78 username: String,
79 password: String,
80 keep_alive_interval: u64,
81 ) -> Result<Resource<Connection>, Error> {
82 self.otel.reparent_tracing_span();
83
84 if !self
85 .is_address_allowed(&address)
86 .await
87 .map_err(|e| v2::Error::Other(e.to_string()))?
88 {
89 return Err(v2::Error::ConnectionFailed(format!(
90 "address {address} is not permitted"
91 )));
92 }
93 self.establish_connection(
94 address,
95 username,
96 password,
97 Duration::from_secs(keep_alive_interval),
98 )
99 .await
100 }
101
102 #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
108 fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
109 messaging.system = "mqtt"))]
110 async fn publish(
111 &mut self,
112 connection: Resource<Connection>,
113 topic: String,
114 payload: Vec<u8>,
115 qos: Qos,
116 ) -> Result<(), Error> {
117 self.otel.reparent_tracing_span();
118
119 let conn = self.get_conn(connection).await.map_err(other_error)?;
120
121 conn.publish_bytes(topic, qos, payload).await?;
122
123 Ok(())
124 }
125
126 async fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
127 self.connections.remove(connection.rep());
128 Ok(())
129 }
130}
131
132pub fn other_error(e: impl std::fmt::Display) -> Error {
133 Error::Other(e.to_string())
134}