spin_factor_outbound_mqtt/
host.rs1use std::{sync::Arc, time::Duration};
2
3use anyhow::Result;
4use spin_core::{
5 async_trait,
6 wasmtime::component::{Accessor, Resource},
7};
8use spin_factor_otel::OtelFactorState;
9use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
10use spin_factor_outbound_networking::{ConnectionPermit, ConnectionSemaphore};
11use spin_world::spin::mqtt::mqtt as v3;
12use spin_world::v2::mqtt as v2;
13use tracing::{Level, instrument};
14
15use crate::{ClientCreator, allowed_hosts::AllowedHostChecker};
16
17pub struct InstanceState {
18 allowed_hosts: AllowedHostChecker,
19 connections: spin_resource_table::Table<(Arc<dyn MqttClient>, ConnectionPermit)>,
20 create_client: Arc<dyn ClientCreator>,
21 semaphore: ConnectionSemaphore,
22 otel: OtelFactorState,
23 max_payload_size_bytes: Option<usize>,
24}
25
26impl InstanceState {
27 pub fn new(
28 allowed_hosts: OutboundAllowedHosts,
29 create_client: Arc<dyn ClientCreator>,
30 semaphore: ConnectionSemaphore,
31 otel: OtelFactorState,
32 max_payload_size_bytes: Option<usize>,
33 ) -> Self {
34 Self {
35 allowed_hosts: AllowedHostChecker::new(allowed_hosts),
36 create_client,
37 connections: spin_resource_table::Table::new(1024),
38 semaphore,
39 otel,
40 max_payload_size_bytes,
41 }
42 }
43}
44
45#[async_trait]
46pub trait MqttClient: Send + Sync {
47 async fn publish_bytes(
48 &self,
49 topic: String,
50 qos: v3::Qos,
51 payload: Vec<u8>,
52 ) -> Result<(), v3::Error>;
53}
54
55impl InstanceState {
56 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
57 self.allowed_hosts.is_address_allowed(address).await
58 }
59
60 async fn establish_connection(
61 &mut self,
62 address: String,
63 username: String,
64 password: String,
65 keep_alive_interval: Duration,
66 ) -> Result<Resource<v2::Connection>, v2::Error> {
67 let permit = self
68 .semaphore
69 .acquire()
70 .await
71 .map_err(|_| v2::Error::TooManyConnections)?;
72 let client =
73 (self.create_client).create(address, username, password, keep_alive_interval)?;
74 self.connections
75 .push((client, permit))
76 .map(Resource::new_own)
77 .map_err(|_| v2::Error::TooManyConnections)
78 }
79
80 fn get_conn(&self, connection: Resource<v2::Connection>) -> Result<&dyn MqttClient, v2::Error> {
81 self.connections
82 .get(connection.rep())
83 .ok_or(v2::Error::Other(
84 "could not find connection for resource".into(),
85 ))
86 .map(|(c, _permit)| c.as_ref())
87 }
88
89 fn get_conn_v3(
90 &self,
91 connection: Resource<v3::Connection>,
92 ) -> Result<Arc<dyn MqttClient>, v3::Error> {
93 self.connections
94 .get(connection.rep())
95 .map(|(c, _permit)| c.clone())
96 .ok_or(v3::Error::Other(
97 "could not find connection for resource".into(),
98 ))
99 }
100}
101
102impl v3::Host for InstanceState {
103 fn convert_error(&mut self, err: v3::Error) -> anyhow::Result<v3::Error> {
104 Ok(err)
105 }
106}
107
108impl v3::HostConnection for InstanceState {
109 async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
110 self.connections.remove(connection.rep());
111 Ok(())
112 }
113}
114
115impl v3::HostConnectionWithStore for crate::MqttFactorData {
116 #[instrument(name = "spin_outbound_mqtt.open_connection", skip(accessor, password), err(level = Level::INFO), fields(otel.kind = "client"))]
117 async fn open<T: Send>(
118 accessor: &Accessor<T, Self>,
119 address: String,
120 username: String,
121 password: String,
122 keep_alive_interval_in_secs: u64,
123 ) -> Result<Resource<v3::Connection>, v3::Error> {
124 let (allowed_host_checker, create_client, semaphore) = accessor.with(|mut access| {
125 let host = access.get();
126 host.otel.reparent_tracing_span();
127 (
128 host.allowed_hosts.clone(),
129 host.create_client.clone(),
130 host.semaphore.clone(),
131 )
132 });
133
134 if !allowed_host_checker
135 .is_address_allowed(&address)
136 .await
137 .map_err(|e| v3::Error::Other(e.to_string()))?
138 {
139 return Err(v3::Error::ConnectionFailed(format!(
140 "address {address} is not permitted"
141 )));
142 }
143
144 let permit = semaphore
145 .acquire()
146 .await
147 .map_err(|_| v3::Error::TooManyConnections)?;
148
149 let client = create_client.create(
150 address,
151 username,
152 password,
153 Duration::from_secs(keep_alive_interval_in_secs),
154 )?;
155
156 accessor.with(|mut access| {
157 let host = access.get();
158 host.connections
159 .push((client, permit))
160 .map(Resource::new_own)
161 .map_err(|_| v3::Error::TooManyConnections)
162 })
163 }
164
165 #[instrument(name = "spin_outbound_mqtt.publish", skip(accessor, connection, payload), err(level = Level::INFO),
166 fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
167 messaging.system = "mqtt"))]
168 async fn publish<T: Send>(
169 accessor: &Accessor<T, Self>,
170 connection: Resource<v3::Connection>,
171 topic: String,
172 payload: v3::Payload,
173 qos: v3::Qos,
174 ) -> Result<(), v3::Error> {
175 let (conn, max_payload_size_bytes) = accessor.with(|mut access| {
176 let host = access.get();
177 host.otel.reparent_tracing_span();
178 host.get_conn_v3(connection)
179 .map(|c| (c, host.max_payload_size_bytes))
180 })?;
181
182 if let Some(limit) = max_payload_size_bytes
183 && payload.len() > limit
184 {
185 return Err(v3::Error::Other(format!(
186 "payload size {} exceeds the maximum allowed size of {} bytes",
187 payload.len(),
188 limit
189 )));
190 }
191
192 conn.publish_bytes(topic, qos, payload).await?;
193
194 Ok(())
195 }
196}
197
198impl v2::Host for InstanceState {
199 fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
200 Ok(error)
201 }
202}
203
204impl v2::HostConnection for InstanceState {
205 #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
206 async fn open(
207 &mut self,
208 address: String,
209 username: String,
210 password: String,
211 keep_alive_interval: u64,
212 ) -> Result<Resource<v2::Connection>, v2::Error> {
213 self.otel.reparent_tracing_span();
214
215 if !self
216 .is_address_allowed(&address)
217 .await
218 .map_err(|e| v2::Error::Other(e.to_string()))?
219 {
220 return Err(v2::Error::ConnectionFailed(format!(
221 "address {address} is not permitted"
222 )));
223 }
224 self.establish_connection(
225 address,
226 username,
227 password,
228 Duration::from_secs(keep_alive_interval),
229 )
230 .await
231 }
232
233 #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
239 fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
240 messaging.system = "mqtt"))]
241 async fn publish(
242 &mut self,
243 connection: Resource<v2::Connection>,
244 topic: String,
245 payload: Vec<u8>,
246 qos: v2::Qos,
247 ) -> Result<(), v2::Error> {
248 self.otel.reparent_tracing_span();
249
250 if let Some(limit) = self.max_payload_size_bytes
251 && payload.len() > limit
252 {
253 return Err(v2::Error::Other(format!(
254 "payload size {} exceeds the maximum allowed size of {} bytes",
255 payload.len(),
256 limit
257 )));
258 }
259
260 let conn = self.get_conn(connection)?;
261
262 let qos = match qos {
263 v2::Qos::AtMostOnce => v3::Qos::AtMostOnce,
264 v2::Qos::AtLeastOnce => v3::Qos::AtLeastOnce,
265 v2::Qos::ExactlyOnce => v3::Qos::ExactlyOnce,
266 };
267
268 conn.publish_bytes(topic, qos, payload).await?;
269
270 Ok(())
271 }
272
273 async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
274 self.connections.remove(connection.rep());
275 Ok(())
276 }
277}
278
279pub fn other_error_v3(e: impl std::fmt::Display) -> v3::Error {
280 v3::Error::Other(e.to_string())
281}