Skip to main content

spin_factor_outbound_mqtt/
host.rs

1use std::{sync::Arc, time::Duration};
2
3use anyhow::Result;
4use spin_core::{async_trait, wasmtime::component::Resource};
5use spin_factor_otel::OtelFactorState;
6use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
7use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos};
8use tracing::{instrument, Level};
9
10use crate::ClientCreator;
11
12pub struct InstanceState {
13    allowed_hosts: OutboundAllowedHosts,
14    connections: spin_resource_table::Table<Arc<dyn MqttClient>>,
15    create_client: Arc<dyn ClientCreator>,
16    otel: OtelFactorState,
17}
18
19impl InstanceState {
20    pub fn new(
21        allowed_hosts: OutboundAllowedHosts,
22        create_client: Arc<dyn ClientCreator>,
23        otel: OtelFactorState,
24    ) -> Self {
25        Self {
26            allowed_hosts,
27            create_client,
28            connections: spin_resource_table::Table::new(1024),
29            otel,
30        }
31    }
32}
33
34#[async_trait]
35pub trait MqttClient: Send + Sync {
36    async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error>;
37}
38
39impl InstanceState {
40    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
41        self.allowed_hosts.check_url(address, "mqtt").await
42    }
43
44    async fn establish_connection(
45        &mut self,
46        address: String,
47        username: String,
48        password: String,
49        keep_alive_interval: Duration,
50    ) -> Result<Resource<Connection>, Error> {
51        self.connections
52            .push((self.create_client).create(address, username, password, keep_alive_interval)?)
53            .map(Resource::new_own)
54            .map_err(|_| Error::TooManyConnections)
55    }
56
57    async fn get_conn(&self, connection: Resource<Connection>) -> Result<&dyn MqttClient, Error> {
58        self.connections
59            .get(connection.rep())
60            .ok_or(Error::Other(
61                "could not find connection for resource".into(),
62            ))
63            .map(|c| c.as_ref())
64    }
65}
66
67impl v2::Host for InstanceState {
68    fn convert_error(&mut self, error: Error) -> Result<Error> {
69        Ok(error)
70    }
71}
72
73impl v2::HostConnection for InstanceState {
74    #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
75    async fn open(
76        &mut self,
77        address: String,
78        username: String,
79        password: String,
80        keep_alive_interval: u64,
81    ) -> Result<Resource<Connection>, Error> {
82        self.otel.reparent_tracing_span();
83
84        if !self
85            .is_address_allowed(&address)
86            .await
87            .map_err(|e| v2::Error::Other(e.to_string()))?
88        {
89            return Err(v2::Error::ConnectionFailed(format!(
90                "address {address} is not permitted"
91            )));
92        }
93        self.establish_connection(
94            address,
95            username,
96            password,
97            Duration::from_secs(keep_alive_interval),
98        )
99        .await
100    }
101
102    /// Publish a message to the MQTT broker.
103    ///
104    /// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the
105    /// current trace context into the payload yourself.
106    /// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation.
107    #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
108        fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
109        messaging.system = "mqtt"))]
110    async fn publish(
111        &mut self,
112        connection: Resource<Connection>,
113        topic: String,
114        payload: Vec<u8>,
115        qos: Qos,
116    ) -> Result<(), Error> {
117        self.otel.reparent_tracing_span();
118
119        let conn = self.get_conn(connection).await.map_err(other_error)?;
120
121        conn.publish_bytes(topic, qos, payload).await?;
122
123        Ok(())
124    }
125
126    async fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
127        self.connections.remove(connection.rep());
128        Ok(())
129    }
130}
131
132pub fn other_error(e: impl std::fmt::Display) -> Error {
133    Error::Other(e.to_string())
134}