spin_factor_outbound_mqtt/
host.rs

1use std::{sync::Arc, time::Duration};
2
3use anyhow::Result;
4use spin_core::{async_trait, wasmtime::component::Resource};
5use spin_factor_outbound_networking::OutboundAllowedHosts;
6use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos};
7use tracing::{instrument, Level};
8
9use crate::ClientCreator;
10
11pub struct InstanceState {
12    allowed_hosts: OutboundAllowedHosts,
13    connections: spin_resource_table::Table<Arc<dyn MqttClient>>,
14    create_client: Arc<dyn ClientCreator>,
15}
16
17impl InstanceState {
18    pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: Arc<dyn ClientCreator>) -> Self {
19        Self {
20            allowed_hosts,
21            create_client,
22            connections: spin_resource_table::Table::new(1024),
23        }
24    }
25}
26
27#[async_trait]
28pub trait MqttClient: Send + Sync {
29    async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error>;
30}
31
32impl InstanceState {
33    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
34        self.allowed_hosts.check_url(address, "mqtt").await
35    }
36
37    async fn establish_connection(
38        &mut self,
39        address: String,
40        username: String,
41        password: String,
42        keep_alive_interval: Duration,
43    ) -> Result<Resource<Connection>, Error> {
44        self.connections
45            .push((self.create_client).create(address, username, password, keep_alive_interval)?)
46            .map(Resource::new_own)
47            .map_err(|_| Error::TooManyConnections)
48    }
49
50    async fn get_conn(&self, connection: Resource<Connection>) -> Result<&dyn MqttClient, Error> {
51        self.connections
52            .get(connection.rep())
53            .ok_or(Error::Other(
54                "could not find connection for resource".into(),
55            ))
56            .map(|c| c.as_ref())
57    }
58}
59
60impl v2::Host for InstanceState {
61    fn convert_error(&mut self, error: Error) -> Result<Error> {
62        Ok(error)
63    }
64}
65
66impl v2::HostConnection for InstanceState {
67    #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
68    async fn open(
69        &mut self,
70        address: String,
71        username: String,
72        password: String,
73        keep_alive_interval: u64,
74    ) -> Result<Resource<Connection>, Error> {
75        if !self
76            .is_address_allowed(&address)
77            .await
78            .map_err(|e| v2::Error::Other(e.to_string()))?
79        {
80            return Err(v2::Error::ConnectionFailed(format!(
81                "address {address} is not permitted"
82            )));
83        }
84        self.establish_connection(
85            address,
86            username,
87            password,
88            Duration::from_secs(keep_alive_interval),
89        )
90        .await
91    }
92
93    /// Publish a message to the MQTT broker.
94    ///
95    /// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the
96    /// current trace context into the payload yourself.
97    /// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation.
98    #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
99        fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
100        messaging.system = "mqtt"))]
101    async fn publish(
102        &mut self,
103        connection: Resource<Connection>,
104        topic: String,
105        payload: Vec<u8>,
106        qos: Qos,
107    ) -> Result<(), Error> {
108        let conn = self.get_conn(connection).await.map_err(other_error)?;
109
110        conn.publish_bytes(topic, qos, payload).await?;
111
112        Ok(())
113    }
114
115    async fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
116        self.connections.remove(connection.rep());
117        Ok(())
118    }
119}
120
121pub fn other_error(e: impl std::fmt::Display) -> Error {
122    Error::Other(e.to_string())
123}