// spin_factor_outbound_mqtt/host.rs

1use std::{sync::Arc, time::Duration};
2
3use anyhow::Result;
4use spin_core::{
5    async_trait,
6    wasmtime::component::{Accessor, Resource},
7};
8use spin_factor_otel::OtelFactorState;
9use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
10use spin_factor_outbound_networking::{ConnectionPermit, ConnectionSemaphore};
11use spin_world::spin::mqtt::mqtt as v3;
12use spin_world::v2::mqtt as v2;
13use tracing::{Level, instrument};
14
15use crate::{ClientCreator, allowed_hosts::AllowedHostChecker};
16
17pub struct InstanceState {
18    allowed_hosts: AllowedHostChecker,
19    connections: spin_resource_table::Table<(Arc<dyn MqttClient>, ConnectionPermit)>,
20    create_client: Arc<dyn ClientCreator>,
21    semaphore: ConnectionSemaphore,
22    otel: OtelFactorState,
23    max_payload_size_bytes: Option<usize>,
24}
25
26impl InstanceState {
27    pub fn new(
28        allowed_hosts: OutboundAllowedHosts,
29        create_client: Arc<dyn ClientCreator>,
30        semaphore: ConnectionSemaphore,
31        otel: OtelFactorState,
32        max_payload_size_bytes: Option<usize>,
33    ) -> Self {
34        Self {
35            allowed_hosts: AllowedHostChecker::new(allowed_hosts),
36            create_client,
37            connections: spin_resource_table::Table::new(1024),
38            semaphore,
39            otel,
40            max_payload_size_bytes,
41        }
42    }
43}
44
45#[async_trait]
46pub trait MqttClient: Send + Sync {
47    async fn publish_bytes(
48        &self,
49        topic: String,
50        qos: v3::Qos,
51        payload: Vec<u8>,
52    ) -> Result<(), v3::Error>;
53}
54
55impl InstanceState {
56    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
57        self.allowed_hosts.is_address_allowed(address).await
58    }
59
60    async fn establish_connection(
61        &mut self,
62        address: String,
63        username: String,
64        password: String,
65        keep_alive_interval: Duration,
66    ) -> Result<Resource<v2::Connection>, v2::Error> {
67        let permit = self
68            .semaphore
69            .acquire()
70            .await
71            .map_err(|_| v2::Error::TooManyConnections)?;
72        let client =
73            (self.create_client).create(address, username, password, keep_alive_interval)?;
74        self.connections
75            .push((client, permit))
76            .map(Resource::new_own)
77            .map_err(|_| v2::Error::TooManyConnections)
78    }
79
80    fn get_conn(&self, connection: Resource<v2::Connection>) -> Result<&dyn MqttClient, v2::Error> {
81        self.connections
82            .get(connection.rep())
83            .ok_or(v2::Error::Other(
84                "could not find connection for resource".into(),
85            ))
86            .map(|(c, _permit)| c.as_ref())
87    }
88
89    fn get_conn_v3(
90        &self,
91        connection: Resource<v3::Connection>,
92    ) -> Result<Arc<dyn MqttClient>, v3::Error> {
93        self.connections
94            .get(connection.rep())
95            .map(|(c, _permit)| c.clone())
96            .ok_or(v3::Error::Other(
97                "could not find connection for resource".into(),
98            ))
99    }
100}
101
102impl v3::Host for InstanceState {
103    fn convert_error(&mut self, err: v3::Error) -> anyhow::Result<v3::Error> {
104        Ok(err)
105    }
106}
107
108impl v3::HostConnection for InstanceState {
109    async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
110        self.connections.remove(connection.rep());
111        Ok(())
112    }
113}
114
115impl v3::HostConnectionWithStore for crate::MqttFactorData {
116    #[instrument(name = "spin_outbound_mqtt.open_connection", skip(accessor, password), err(level = Level::INFO), fields(otel.kind = "client"))]
117    async fn open<T: Send>(
118        accessor: &Accessor<T, Self>,
119        address: String,
120        username: String,
121        password: String,
122        keep_alive_interval_in_secs: u64,
123    ) -> Result<Resource<v3::Connection>, v3::Error> {
124        let (allowed_host_checker, create_client, semaphore) = accessor.with(|mut access| {
125            let host = access.get();
126            host.otel.reparent_tracing_span();
127            (
128                host.allowed_hosts.clone(),
129                host.create_client.clone(),
130                host.semaphore.clone(),
131            )
132        });
133
134        if !allowed_host_checker
135            .is_address_allowed(&address)
136            .await
137            .map_err(|e| v3::Error::Other(e.to_string()))?
138        {
139            return Err(v3::Error::ConnectionFailed(format!(
140                "address {address} is not permitted"
141            )));
142        }
143
144        let permit = semaphore
145            .acquire()
146            .await
147            .map_err(|_| v3::Error::TooManyConnections)?;
148
149        let client = create_client.create(
150            address,
151            username,
152            password,
153            Duration::from_secs(keep_alive_interval_in_secs),
154        )?;
155
156        accessor.with(|mut access| {
157            let host = access.get();
158            host.connections
159                .push((client, permit))
160                .map(Resource::new_own)
161                .map_err(|_| v3::Error::TooManyConnections)
162        })
163    }
164
165    #[instrument(name = "spin_outbound_mqtt.publish", skip(accessor, connection, payload), err(level = Level::INFO),
166        fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
167        messaging.system = "mqtt"))]
168    async fn publish<T: Send>(
169        accessor: &Accessor<T, Self>,
170        connection: Resource<v3::Connection>,
171        topic: String,
172        payload: v3::Payload,
173        qos: v3::Qos,
174    ) -> Result<(), v3::Error> {
175        let (conn, max_payload_size_bytes) = accessor.with(|mut access| {
176            let host = access.get();
177            host.otel.reparent_tracing_span();
178            host.get_conn_v3(connection)
179                .map(|c| (c, host.max_payload_size_bytes))
180        })?;
181
182        if let Some(limit) = max_payload_size_bytes
183            && payload.len() > limit
184        {
185            return Err(v3::Error::Other(format!(
186                "payload size {} exceeds the maximum allowed size of {} bytes",
187                payload.len(),
188                limit
189            )));
190        }
191
192        conn.publish_bytes(topic, qos, payload).await?;
193
194        Ok(())
195    }
196}
197
198impl v2::Host for InstanceState {
199    fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
200        Ok(error)
201    }
202}
203
204impl v2::HostConnection for InstanceState {
205    #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
206    async fn open(
207        &mut self,
208        address: String,
209        username: String,
210        password: String,
211        keep_alive_interval: u64,
212    ) -> Result<Resource<v2::Connection>, v2::Error> {
213        self.otel.reparent_tracing_span();
214
215        if !self
216            .is_address_allowed(&address)
217            .await
218            .map_err(|e| v2::Error::Other(e.to_string()))?
219        {
220            return Err(v2::Error::ConnectionFailed(format!(
221                "address {address} is not permitted"
222            )));
223        }
224        self.establish_connection(
225            address,
226            username,
227            password,
228            Duration::from_secs(keep_alive_interval),
229        )
230        .await
231    }
232
233    /// Publish a message to the MQTT broker.
234    ///
235    /// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the
236    /// current trace context into the payload yourself.
237    /// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation.
238    #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
239        fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
240        messaging.system = "mqtt"))]
241    async fn publish(
242        &mut self,
243        connection: Resource<v2::Connection>,
244        topic: String,
245        payload: Vec<u8>,
246        qos: v2::Qos,
247    ) -> Result<(), v2::Error> {
248        self.otel.reparent_tracing_span();
249
250        if let Some(limit) = self.max_payload_size_bytes
251            && payload.len() > limit
252        {
253            return Err(v2::Error::Other(format!(
254                "payload size {} exceeds the maximum allowed size of {} bytes",
255                payload.len(),
256                limit
257            )));
258        }
259
260        let conn = self.get_conn(connection)?;
261
262        let qos = match qos {
263            v2::Qos::AtMostOnce => v3::Qos::AtMostOnce,
264            v2::Qos::AtLeastOnce => v3::Qos::AtLeastOnce,
265            v2::Qos::ExactlyOnce => v3::Qos::ExactlyOnce,
266        };
267
268        conn.publish_bytes(topic, qos, payload).await?;
269
270        Ok(())
271    }
272
273    async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
274        self.connections.remove(connection.rep());
275        Ok(())
276    }
277}
278
279pub fn other_error_v3(e: impl std::fmt::Display) -> v3::Error {
280    v3::Error::Other(e.to_string())
281}