spin_factor_outbound_mqtt/
host.rs1use std::{sync::Arc, time::Duration};
2
3use anyhow::Result;
4use spin_core::{
5 async_trait,
6 wasmtime::component::{Accessor, Resource},
7};
8use spin_factor_otel::OtelFactorState;
9use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
10use spin_world::spin::mqtt::mqtt as v3;
11use spin_world::v2::mqtt as v2;
12use tracing::{Level, instrument};
13
14use crate::{ClientCreator, allowed_hosts::AllowedHostChecker};
15
16pub struct InstanceState {
17 allowed_hosts: AllowedHostChecker,
18 connections: spin_resource_table::Table<Arc<dyn MqttClient>>,
19 create_client: Arc<dyn ClientCreator>,
20 otel: OtelFactorState,
21}
22
23impl InstanceState {
24 pub fn new(
25 allowed_hosts: OutboundAllowedHosts,
26 create_client: Arc<dyn ClientCreator>,
27 otel: OtelFactorState,
28 ) -> Self {
29 Self {
30 allowed_hosts: AllowedHostChecker::new(allowed_hosts),
31 create_client,
32 connections: spin_resource_table::Table::new(1024),
33 otel,
34 }
35 }
36}
37
38#[async_trait]
39pub trait MqttClient: Send + Sync {
40 async fn publish_bytes(
41 &self,
42 topic: String,
43 qos: v3::Qos,
44 payload: Vec<u8>,
45 ) -> Result<(), v3::Error>;
46}
47
48impl InstanceState {
49 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
50 self.allowed_hosts.is_address_allowed(address).await
51 }
52
53 async fn establish_connection(
54 &mut self,
55 address: String,
56 username: String,
57 password: String,
58 keep_alive_interval: Duration,
59 ) -> Result<Resource<v2::Connection>, v2::Error> {
60 self.connections
61 .push((self.create_client).create(address, username, password, keep_alive_interval)?)
62 .map(Resource::new_own)
63 .map_err(|_| v2::Error::TooManyConnections)
64 }
65
66 fn get_conn(&self, connection: Resource<v2::Connection>) -> Result<&dyn MqttClient, v2::Error> {
67 self.connections
68 .get(connection.rep())
69 .ok_or(v2::Error::Other(
70 "could not find connection for resource".into(),
71 ))
72 .map(|c| c.as_ref())
73 }
74
75 fn get_conn_v3(
76 &self,
77 connection: Resource<v3::Connection>,
78 ) -> Result<Arc<dyn MqttClient>, v3::Error> {
79 self.connections
80 .get(connection.rep())
81 .cloned()
82 .ok_or(v3::Error::Other(
83 "could not find connection for resource".into(),
84 ))
85 }
86}
87
88impl v3::Host for InstanceState {
89 fn convert_error(&mut self, err: v3::Error) -> anyhow::Result<v3::Error> {
90 Ok(err)
91 }
92}
93
94impl v3::HostConnection for InstanceState {
95 async fn drop(&mut self, connection: Resource<v3::Connection>) -> anyhow::Result<()> {
96 self.connections.remove(connection.rep());
97 Ok(())
98 }
99}
100
101impl v3::HostConnectionWithStore for crate::MqttFactorData {
102 #[instrument(name = "spin_outbound_mqtt.open_connection", skip(accessor, password), err(level = Level::INFO), fields(otel.kind = "client"))]
103 async fn open<T: Send>(
104 accessor: &Accessor<T, Self>,
105 address: String,
106 username: String,
107 password: String,
108 keep_alive_interval_in_secs: u64,
109 ) -> Result<Resource<v3::Connection>, v3::Error> {
110 let (allowed_host_checker, create_client) = accessor.with(|mut access| {
111 let host = access.get();
112 host.otel.reparent_tracing_span();
113 (host.allowed_hosts.clone(), host.create_client.clone())
114 });
115
116 if !allowed_host_checker
117 .is_address_allowed(&address)
118 .await
119 .map_err(|e| v3::Error::Other(e.to_string()))?
120 {
121 return Err(v3::Error::ConnectionFailed(format!(
122 "address {address} is not permitted"
123 )));
124 }
125
126 let client = create_client
127 .create(
128 address,
129 username,
130 password,
131 Duration::from_secs(keep_alive_interval_in_secs),
132 )
133 .unwrap();
134
135 accessor.with(|mut access| {
136 let host = access.get();
137 host.connections
138 .push(client)
139 .map(Resource::new_own)
140 .map_err(|_| v3::Error::TooManyConnections)
141 })
142 }
143
144 #[instrument(name = "spin_outbound_mqtt.publish", skip(accessor, connection, payload), err(level = Level::INFO),
145 fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
146 messaging.system = "mqtt"))]
147 async fn publish<T: Send>(
148 accessor: &Accessor<T, Self>,
149 connection: Resource<v3::Connection>,
150 topic: String,
151 payload: v3::Payload,
152 qos: v3::Qos,
153 ) -> Result<(), v3::Error> {
154 let conn = accessor.with(|mut access| {
155 let host = access.get();
156 host.otel.reparent_tracing_span();
157 host.get_conn_v3(connection)
158 })?;
159
160 conn.publish_bytes(topic, qos, payload).await?;
161
162 Ok(())
163 }
164}
165
166impl v2::Host for InstanceState {
167 fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
168 Ok(error)
169 }
170}
171
172impl v2::HostConnection for InstanceState {
173 #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
174 async fn open(
175 &mut self,
176 address: String,
177 username: String,
178 password: String,
179 keep_alive_interval: u64,
180 ) -> Result<Resource<v2::Connection>, v2::Error> {
181 self.otel.reparent_tracing_span();
182
183 if !self
184 .is_address_allowed(&address)
185 .await
186 .map_err(|e| v2::Error::Other(e.to_string()))?
187 {
188 return Err(v2::Error::ConnectionFailed(format!(
189 "address {address} is not permitted"
190 )));
191 }
192 self.establish_connection(
193 address,
194 username,
195 password,
196 Duration::from_secs(keep_alive_interval),
197 )
198 .await
199 }
200
201 #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
207 fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
208 messaging.system = "mqtt"))]
209 async fn publish(
210 &mut self,
211 connection: Resource<v2::Connection>,
212 topic: String,
213 payload: Vec<u8>,
214 qos: v2::Qos,
215 ) -> Result<(), v2::Error> {
216 self.otel.reparent_tracing_span();
217
218 let conn = self.get_conn(connection)?;
219
220 let qos = match qos {
221 v2::Qos::AtMostOnce => v3::Qos::AtMostOnce,
222 v2::Qos::AtLeastOnce => v3::Qos::AtLeastOnce,
223 v2::Qos::ExactlyOnce => v3::Qos::ExactlyOnce,
224 };
225
226 conn.publish_bytes(topic, qos, payload).await?;
227
228 Ok(())
229 }
230
231 async fn drop(&mut self, connection: Resource<v2::Connection>) -> anyhow::Result<()> {
232 self.connections.remove(connection.rep());
233 Ok(())
234 }
235}
236
237pub fn other_error_v3(e: impl std::fmt::Display) -> v3::Error {
238 v3::Error::Other(e.to_string())
239}