Skip to main content

spin_factor_outbound_mqtt/
host.rs

1use std::{sync::Arc, time::Duration};
2
3use anyhow::Result;
4use spin_core::{
5    async_trait,
6    wasmtime::component::{Accessor, Resource},
7};
8use spin_factor_otel::OtelFactorState;
9use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
10use spin_world::spin::mqtt::mqtt as v3;
11use spin_world::v2::mqtt as v2;
12use tracing::{Level, instrument};
13
14use crate::{ClientCreator, allowed_hosts::AllowedHostChecker};
15
16pub struct InstanceState {
17    allowed_hosts: AllowedHostChecker,
18    connections: spin_resource_table::Table<Arc<dyn MqttClient>>,
19    create_client: Arc<dyn ClientCreator>,
20    otel: OtelFactorState,
21}
22
23impl InstanceState {
24    pub fn new(
25        allowed_hosts: OutboundAllowedHosts,
26        create_client: Arc<dyn ClientCreator>,
27        otel: OtelFactorState,
28    ) -> Self {
29        Self {
30            allowed_hosts: AllowedHostChecker::new(allowed_hosts),
31            create_client,
32            connections: spin_resource_table::Table::new(1024),
33            otel,
34        }
35    }
36}
37
38#[async_trait]
39pub trait MqttClient: Send + Sync {
40    async fn publish_bytes(
41        &self,
42        topic: String,
43        qos: v3::Qos,
44        payload: Vec<u8>,
45    ) -> Result<(), v3::Error>;
46}
47
48impl InstanceState {
49    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
50        self.allowed_hosts.is_address_allowed(address).await
51    }
52
53    async fn establish_connection(
54        &mut self,
55        address: String,
56        username: String,
57        password: String,
58        keep_alive_interval: Duration,
59    ) -> Result<Resource<v2::Connection>, v2::Error> {
60        self.connections
61            .push((self.create_client).create(address, username, password, keep_alive_interval)?)
62            .map(Resource::new_own)
63            .map_err(|_| v2::Error::TooManyConnections)
64    }
65
66    fn get_conn(&self, connection: Resource<v2::Connection>) -> Result<&dyn MqttClient, v2::Error> {
67        self.connections
68            .get(connection.rep())
69            .ok_or(v2::Error::Other(
70                "could not find connection for resource".into(),
71            ))
72            .map(|c| c.as_ref())
73    }
74
75    fn get_conn_v3(
76        &self,
77        connection: Resource<v3::Connection>,
78    ) -> Result<Arc<dyn MqttClient>, v3::Error> {
79        self.connections
80            .get(connection.rep())
81            .cloned()
82            .ok_or(v3::Error::Other(
83                "could not find connection for resource".into(),
84            ))
85    }
86}
87
88impl v3::Host for InstanceState {
89    fn convert_error(&mut self, err: v3::Error) -> anyhow::Result<v3::Error> {
90        Ok(err)
91    }
92}
93
94impl v3::HostConnection for InstanceState {
95    async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
96        self.connections.remove(connection.rep());
97        Ok(())
98    }
99}
100
101impl v3::HostConnectionWithStore for crate::MqttFactorData {
102    #[instrument(name = "spin_outbound_mqtt.open_connection", skip(accessor, password), err(level = Level::INFO), fields(otel.kind = "client"))]
103    async fn open<T: Send>(
104        accessor: &Accessor<T, Self>,
105        address: String,
106        username: String,
107        password: String,
108        keep_alive_interval_in_secs: u64,
109    ) -> Result<Resource<v3::Connection>, v3::Error> {
110        let (allowed_host_checker, create_client) = accessor.with(|mut access| {
111            let host = access.get();
112            host.otel.reparent_tracing_span();
113            (host.allowed_hosts.clone(), host.create_client.clone())
114        });
115
116        if !allowed_host_checker
117            .is_address_allowed(&address)
118            .await
119            .map_err(|e| v3::Error::Other(e.to_string()))?
120        {
121            return Err(v3::Error::ConnectionFailed(format!(
122                "address {address} is not permitted"
123            )));
124        }
125
126        let client = create_client
127            .create(
128                address,
129                username,
130                password,
131                Duration::from_secs(keep_alive_interval_in_secs),
132            )
133            .unwrap();
134
135        accessor.with(|mut access| {
136            let host = access.get();
137            host.connections
138                .push(client)
139                .map(Resource::new_own)
140                .map_err(|_| v3::Error::TooManyConnections)
141        })
142    }
143
144    #[instrument(name = "spin_outbound_mqtt.publish", skip(accessor, connection, payload), err(level = Level::INFO),
145        fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
146        messaging.system = "mqtt"))]
147    async fn publish<T: Send>(
148        accessor: &Accessor<T, Self>,
149        connection: Resource<v3::Connection>,
150        topic: String,
151        payload: v3::Payload,
152        qos: v3::Qos,
153    ) -> Result<(), v3::Error> {
154        let conn = accessor.with(|mut access| {
155            let host = access.get();
156            host.otel.reparent_tracing_span();
157            host.get_conn_v3(connection)
158        })?;
159
160        conn.publish_bytes(topic, qos, payload).await?;
161
162        Ok(())
163    }
164}
165
166impl v2::Host for InstanceState {
167    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
168        Ok(error)
169    }
170}
171
172impl v2::HostConnection for InstanceState {
173    #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
174    async fn open(
175        &mut self,
176        address: String,
177        username: String,
178        password: String,
179        keep_alive_interval: u64,
180    ) -> Result<Resource<v2::Connection>, v2::Error> {
181        self.otel.reparent_tracing_span();
182
183        if !self
184            .is_address_allowed(&address)
185            .await
186            .map_err(|e| v2::Error::Other(e.to_string()))?
187        {
188            return Err(v2::Error::ConnectionFailed(format!(
189                "address {address} is not permitted"
190            )));
191        }
192        self.establish_connection(
193            address,
194            username,
195            password,
196            Duration::from_secs(keep_alive_interval),
197        )
198        .await
199    }
200
201    /// Publish a message to the MQTT broker.
202    ///
203    /// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the
204    /// current trace context into the payload yourself.
205    /// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation.
206    #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
207        fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
208        messaging.system = "mqtt"))]
209    async fn publish(
210        &mut self,
211        connection: Resource<v2::Connection>,
212        topic: String,
213        payload: Vec<u8>,
214        qos: v2::Qos,
215    ) -> Result<(), v2::Error> {
216        self.otel.reparent_tracing_span();
217
218        let conn = self.get_conn(connection)?;
219
220        let qos = match qos {
221            v2::Qos::AtMostOnce => v3::Qos::AtMostOnce,
222            v2::Qos::AtLeastOnce => v3::Qos::AtLeastOnce,
223            v2::Qos::ExactlyOnce => v3::Qos::ExactlyOnce,
224        };
225
226        conn.publish_bytes(topic, qos, payload).await?;
227
228        Ok(())
229    }
230
231    async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
232        self.connections.remove(connection.rep());
233        Ok(())
234    }
235}
236
237pub fn other_error_v3(e: impl std::fmt::Display) -> v3::Error {
238    v3::Error::Other(e.to_string())
239}