Skip to main content

spin_factor_outbound_mqtt/
lib.rs

1mod allowed_hosts;
2mod host;
3pub mod runtime_config;
4
5use std::sync::Arc;
6use std::time::Duration;
7
8use host::InstanceState;
9use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS};
10use spin_core::async_trait;
11use spin_factor_otel::OtelFactorState;
12use spin_factor_outbound_networking::{
13    ConnectionSemaphore, OutboundNetworkingFactor, build_connection_semaphore,
14};
15use spin_factors::{
16    ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
17    anyhow,
18};
19use spin_world::spin::mqtt::mqtt as v3;
20use spin_world::v2::mqtt as v2;
21use tokio::sync::Mutex;
22
23pub use host::MqttClient;
24
25use crate::host::other_error_v3;
26use crate::runtime_config::RuntimeConfig;
27
28pub struct OutboundMqttFactor {
29    create_client: Arc<dyn ClientCreator>,
30}
31
32impl OutboundMqttFactor {
33    pub fn new(create_client: Arc<dyn ClientCreator>) -> Self {
34        Self { create_client }
35    }
36}
37
38pub struct AppState {
39    /// Optional maximum payload size in bytes for MQTT messages. If `None`, no limit is enforced.
40    max_payload_size_bytes: Option<usize>,
41    /// Semaphore to limit concurrent outbound MQTT connections.
42    pub semaphore: ConnectionSemaphore,
43}
44
45impl Factor for OutboundMqttFactor {
46    type RuntimeConfig = RuntimeConfig;
47    type AppState = AppState;
48    type InstanceBuilder = InstanceState;
49
50    fn init(&mut self, ctx: &mut impl spin_factors::InitContext<Self>) -> anyhow::Result<()> {
51        ctx.link_bindings(v2::add_to_linker::<_, FactorData<Self>>)?;
52        ctx.link_bindings(v3::add_to_linker::<_, MqttFactorData>)?;
53        Ok(())
54    }
55
56    fn configure_app<T: RuntimeFactors>(
57        &self,
58        mut ctx: ConfigureAppContext<T, Self>,
59    ) -> anyhow::Result<Self::AppState> {
60        let config = ctx.take_runtime_config().unwrap_or_default();
61
62        Ok(AppState {
63            semaphore: build_connection_semaphore(
64                ctx.app_state::<OutboundNetworkingFactor>().ok(),
65                "mqtt",
66                config.max_connections,
67                config.wait_timeout,
68            ),
69            max_payload_size_bytes: config.max_payload_size_bytes,
70        })
71    }
72
73    fn prepare<T: RuntimeFactors>(
74        &self,
75        mut ctx: PrepareContext<T, Self>,
76    ) -> anyhow::Result<Self::InstanceBuilder> {
77        let allowed_hosts = ctx
78            .instance_builder::<OutboundNetworkingFactor>()?
79            .allowed_hosts();
80        let otel = OtelFactorState::from_prepare_context(&mut ctx)?;
81
82        Ok(InstanceState::new(
83            allowed_hosts,
84            self.create_client.clone(),
85            ctx.app_state().semaphore.clone(),
86            otel,
87            ctx.app_state().max_payload_size_bytes,
88        ))
89    }
90}
91
92impl SelfInstanceBuilder for InstanceState {}
93
94struct MqttFactorData;
95
96impl spin_core::wasmtime::component::HasData for MqttFactorData {
97    type Data<'a> = &'a mut InstanceState;
98}
99
100// This is a concrete implementation of the MQTT client using rumqttc.
101pub struct NetworkedMqttClient {
102    inner: rumqttc::AsyncClient,
103    event_loop: Mutex<rumqttc::EventLoop>,
104}
105
106const MQTT_CHANNEL_CAP: usize = 1000;
107
108impl NetworkedMqttClient {
109    /// Create a [`ClientCreator`] that creates a [`NetworkedMqttClient`].
110    pub fn creator() -> Arc<dyn ClientCreator> {
111        Arc::new(|address, username, password, keep_alive_interval| {
112            Ok(Arc::new(NetworkedMqttClient::create(
113                address,
114                username,
115                password,
116                keep_alive_interval,
117            )?) as _)
118        })
119    }
120
121    /// Create a new [`NetworkedMqttClient`] with the given address, username, password, and keep alive interval.
122    pub fn create(
123        address: String,
124        username: String,
125        password: String,
126        keep_alive_interval: Duration,
127    ) -> Result<Self, v3::Error> {
128        let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| {
129            tracing::error!("MQTT URL parse error: {e:?}");
130            v3::Error::InvalidAddress
131        })?;
132        conn_opts.set_credentials(username, password);
133        conn_opts.set_keep_alive(keep_alive_interval);
134        let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP);
135        Ok(Self {
136            inner: client,
137            event_loop: Mutex::new(event_loop),
138        })
139    }
140}
141
142#[async_trait]
143impl MqttClient for NetworkedMqttClient {
144    async fn publish_bytes(
145        &self,
146        topic: String,
147        qos: v3::Qos,
148        payload: Vec<u8>,
149    ) -> Result<(), v3::Error> {
150        let qos = match qos {
151            v3::Qos::AtMostOnce => rumqttc::QoS::AtMostOnce,
152            v3::Qos::AtLeastOnce => rumqttc::QoS::AtLeastOnce,
153            v3::Qos::ExactlyOnce => rumqttc::QoS::ExactlyOnce,
154        };
155        // Message published to EventLoop (not MQTT Broker)
156        self.inner
157            .publish_bytes(topic, qos, false, payload.into())
158            .await
159            .map_err(other_error_v3)?;
160
161        // Poll event loop until outgoing publish event is iterated over to send the message to MQTT broker or capture/throw error.
162        // We may revisit this later to manage long running connections, high throughput use cases and their issues in the connection pool.
163        let mut lock = self.event_loop.lock().await;
164        loop {
165            let event = lock
166                .poll()
167                .await
168                .map_err(|err| v3::Error::ConnectionFailed(err.to_string()))?;
169
170            match (qos, event) {
171                (QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_)))
172                | (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_)))
173                | (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break,
174
175                (_, _) => continue,
176            }
177        }
178        Ok(())
179    }
180}
181
182/// A trait for creating MQTT client.
183#[async_trait]
184pub trait ClientCreator: Send + Sync {
185    fn create(
186        &self,
187        address: String,
188        username: String,
189        password: String,
190        keep_alive_interval: Duration,
191    ) -> Result<Arc<dyn MqttClient>, v3::Error>;
192}
193
194impl<F> ClientCreator for F
195where
196    F: Fn(String, String, String, Duration) -> Result<Arc<dyn MqttClient>, v3::Error> + Send + Sync,
197{
198    fn create(
199        &self,
200        address: String,
201        username: String,
202        password: String,
203        keep_alive_interval: Duration,
204    ) -> Result<Arc<dyn MqttClient>, v3::Error> {
205        self(address, username, password, keep_alive_interval)
206    }
207}