spin_factor_outbound_mqtt/
lib.rs1mod allowed_hosts;
2mod host;
3pub mod runtime_config;
4
5use std::sync::Arc;
6use std::time::Duration;
7
8use host::InstanceState;
9use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS};
10use spin_core::async_trait;
11use spin_factor_otel::OtelFactorState;
12use spin_factor_outbound_networking::{
13 ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore,
14};
15use spin_factors::{
16 ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
17 anyhow,
18};
19use spin_world::spin::mqtt::mqtt as v3;
20use spin_world::v2::mqtt as v2;
21use tokio::sync::Mutex;
22
23pub use host::MqttClient;
24
25use crate::host::other_error_v3;
26use crate::runtime_config::RuntimeConfig;
27
28pub struct OutboundMqttFactor {
29 create_client: Arc<dyn ClientCreator>,
30}
31
32impl OutboundMqttFactor {
33 pub fn new(create_client: Arc<dyn ClientCreator>) -> Self {
34 Self { create_client }
35 }
36}
37
38pub struct AppState {
39 max_payload_size_bytes: Option<usize>,
41 pub semaphore: ConnectionSemaphore,
43}
44
45impl Factor for OutboundMqttFactor {
46 type RuntimeConfig = RuntimeConfig;
47 type AppState = AppState;
48 type InstanceBuilder = InstanceState;
49
50 fn init(&mut self, ctx: &mut impl spin_factors::InitContext<Self>) -> anyhow::Result<()> {
51 ctx.link_bindings(v2::add_to_linker::<_, FactorData<Self>>)?;
52 ctx.link_bindings(v3::add_to_linker::<_, MqttFactorData>)?;
53 Ok(())
54 }
55
56 fn configure_app<T: RuntimeFactors>(
57 &self,
58 mut ctx: ConfigureAppContext<T, Self>,
59 ) -> anyhow::Result<Self::AppState> {
60 let config = ctx.take_runtime_config().unwrap_or_default();
61
62 Ok(AppState {
63 semaphore: build_connection_semaphore(
64 ctx.app_state::<OutboundNetworkingFactor>().ok(),
65 "mqtt",
66 config.max_connections,
67 config.wait_timeout,
68 ),
69 max_payload_size_bytes: config.max_payload_size_bytes,
70 })
71 }
72
73 fn prepare<T: RuntimeFactors>(
74 &self,
75 mut ctx: PrepareContext<T, Self>,
76 ) -> anyhow::Result<Self::InstanceBuilder> {
77 let allowed_hosts = ctx
78 .instance_builder::<OutboundNetworkingFactor>()?
79 .allowed_hosts();
80 let otel = OtelFactorState::from_prepare_context(&mut ctx)?;
81
82 Ok(InstanceState::new(
83 allowed_hosts,
84 self.create_client.clone(),
85 ctx.app_state().semaphore.clone(),
86 otel,
87 ctx.app_state().max_payload_size_bytes,
88 ))
89 }
90}
91
92impl SelfInstanceBuilder for InstanceState {}
93
94struct MqttFactorData;
95
96impl spin_core::wasmtime::component::HasData for MqttFactorData {
97 type Data<'a> = &'a mut InstanceState;
98}
99
100pub struct NetworkedMqttClient {
102 inner: rumqttc::AsyncClient,
103 event_loop: Mutex<rumqttc::EventLoop>,
104}
105
/// Capacity argument passed to `rumqttc::AsyncClient::new` for every networked client.
const MQTT_CHANNEL_CAP: usize = 1_000;
107
108impl NetworkedMqttClient {
109 pub fn creator() -> Arc<dyn ClientCreator> {
111 Arc::new(|address, username, password, keep_alive_interval| {
112 Ok(Arc::new(NetworkedMqttClient::create(
113 address,
114 username,
115 password,
116 keep_alive_interval,
117 )?) as _)
118 })
119 }
120
121 pub fn create(
123 address: String,
124 username: String,
125 password: String,
126 keep_alive_interval: Duration,
127 ) -> Result<Self, v3::Error> {
128 let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| {
129 tracing::error!("MQTT URL parse error: {e:?}");
130 v3::Error::InvalidAddress
131 })?;
132 conn_opts.set_credentials(username, password);
133 conn_opts.set_keep_alive(keep_alive_interval);
134 let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP);
135 Ok(Self {
136 inner: client,
137 event_loop: Mutex::new(event_loop),
138 })
139 }
140}
141
142#[async_trait]
143impl MqttClient for NetworkedMqttClient {
144 async fn publish_bytes(
145 &self,
146 topic: String,
147 qos: v3::Qos,
148 payload: Vec<u8>,
149 ) -> Result<(), v3::Error> {
150 let qos = match qos {
151 v3::Qos::AtMostOnce => rumqttc::QoS::AtMostOnce,
152 v3::Qos::AtLeastOnce => rumqttc::QoS::AtLeastOnce,
153 v3::Qos::ExactlyOnce => rumqttc::QoS::ExactlyOnce,
154 };
155 self.inner
157 .publish_bytes(topic, qos, false, payload.into())
158 .await
159 .map_err(other_error_v3)?;
160
161 let mut lock = self.event_loop.lock().await;
164 loop {
165 let event = lock
166 .poll()
167 .await
168 .map_err(|err| v3::Error::ConnectionFailed(err.to_string()))?;
169
170 match (qos, event) {
171 (QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_)))
172 | (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_)))
173 | (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break,
174
175 (_, _) => continue,
176 }
177 }
178 Ok(())
179 }
180}
181
182#[async_trait]
184pub trait ClientCreator: Send + Sync {
185 fn create(
186 &self,
187 address: String,
188 username: String,
189 password: String,
190 keep_alive_interval: Duration,
191 ) -> Result<Arc<dyn MqttClient>, v3::Error>;
192}
193
194impl<F> ClientCreator for F
195where
196 F: Fn(String, String, String, Duration) -> Result<Arc<dyn MqttClient>, v3::Error> + Send + Sync,
197{
198 fn create(
199 &self,
200 address: String,
201 username: String,
202 password: String,
203 keep_alive_interval: Duration,
204 ) -> Result<Arc<dyn MqttClient>, v3::Error> {
205 self(address, username, password, keep_alive_interval)
206 }
207}