Skip to main content

spin_factor_outbound_mqtt/
lib.rs

1mod host;
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use host::other_error;
7use host::InstanceState;
8use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS};
9use spin_core::async_trait;
10use spin_factor_otel::OtelFactorState;
11use spin_factor_outbound_networking::OutboundNetworkingFactor;
12use spin_factors::{
13    ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
14};
15use spin_world::v2::mqtt::{self as v2, Error, Qos};
16use tokio::sync::Mutex;
17
18pub use host::MqttClient;
19
20pub struct OutboundMqttFactor {
21    create_client: Arc<dyn ClientCreator>,
22}
23
24impl OutboundMqttFactor {
25    pub fn new(create_client: Arc<dyn ClientCreator>) -> Self {
26        Self { create_client }
27    }
28}
29
30impl Factor for OutboundMqttFactor {
31    type RuntimeConfig = ();
32    type AppState = ();
33    type InstanceBuilder = InstanceState;
34
35    fn init(&mut self, ctx: &mut impl spin_factors::InitContext<Self>) -> anyhow::Result<()> {
36        ctx.link_bindings(spin_world::v2::mqtt::add_to_linker::<_, FactorData<Self>>)?;
37        Ok(())
38    }
39
40    fn configure_app<T: RuntimeFactors>(
41        &self,
42        _ctx: ConfigureAppContext<T, Self>,
43    ) -> anyhow::Result<Self::AppState> {
44        Ok(())
45    }
46
47    fn prepare<T: RuntimeFactors>(
48        &self,
49        mut ctx: PrepareContext<T, Self>,
50    ) -> anyhow::Result<Self::InstanceBuilder> {
51        let allowed_hosts = ctx
52            .instance_builder::<OutboundNetworkingFactor>()?
53            .allowed_hosts();
54        let otel = OtelFactorState::from_prepare_context(&mut ctx)?;
55
56        Ok(InstanceState::new(
57            allowed_hosts,
58            self.create_client.clone(),
59            otel,
60        ))
61    }
62}
63
64impl SelfInstanceBuilder for InstanceState {}
65
66// This is a concrete implementation of the MQTT client using rumqttc.
67pub struct NetworkedMqttClient {
68    inner: rumqttc::AsyncClient,
69    event_loop: Mutex<rumqttc::EventLoop>,
70}
71
/// Capacity handed to `AsyncClient::new`, i.e. rumqttc's bounded request channel
/// between the client handle and its event loop.
const MQTT_CHANNEL_CAP: usize = 1_000;
73
74impl NetworkedMqttClient {
75    /// Create a [`ClientCreator`] that creates a [`NetworkedMqttClient`].
76    pub fn creator() -> Arc<dyn ClientCreator> {
77        Arc::new(|address, username, password, keep_alive_interval| {
78            Ok(Arc::new(NetworkedMqttClient::create(
79                address,
80                username,
81                password,
82                keep_alive_interval,
83            )?) as _)
84        })
85    }
86
87    /// Create a new [`NetworkedMqttClient`] with the given address, username, password, and keep alive interval.
88    pub fn create(
89        address: String,
90        username: String,
91        password: String,
92        keep_alive_interval: Duration,
93    ) -> Result<Self, Error> {
94        let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| {
95            tracing::error!("MQTT URL parse error: {e:?}");
96            Error::InvalidAddress
97        })?;
98        conn_opts.set_credentials(username, password);
99        conn_opts.set_keep_alive(keep_alive_interval);
100        let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP);
101        Ok(Self {
102            inner: client,
103            event_loop: Mutex::new(event_loop),
104        })
105    }
106}
107
108#[async_trait]
109impl MqttClient for NetworkedMqttClient {
110    async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error> {
111        let qos = match qos {
112            Qos::AtMostOnce => rumqttc::QoS::AtMostOnce,
113            Qos::AtLeastOnce => rumqttc::QoS::AtLeastOnce,
114            Qos::ExactlyOnce => rumqttc::QoS::ExactlyOnce,
115        };
116        // Message published to EventLoop (not MQTT Broker)
117        self.inner
118            .publish_bytes(topic, qos, false, payload.into())
119            .await
120            .map_err(other_error)?;
121
122        // Poll event loop until outgoing publish event is iterated over to send the message to MQTT broker or capture/throw error.
123        // We may revisit this later to manage long running connections, high throughput use cases and their issues in the connection pool.
124        let mut lock = self.event_loop.lock().await;
125        loop {
126            let event = lock
127                .poll()
128                .await
129                .map_err(|err| v2::Error::ConnectionFailed(err.to_string()))?;
130
131            match (qos, event) {
132                (QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_)))
133                | (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_)))
134                | (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break,
135
136                (_, _) => continue,
137            }
138        }
139        Ok(())
140    }
141}
142
143/// A trait for creating MQTT client.
144#[async_trait]
145pub trait ClientCreator: Send + Sync {
146    fn create(
147        &self,
148        address: String,
149        username: String,
150        password: String,
151        keep_alive_interval: Duration,
152    ) -> Result<Arc<dyn MqttClient>, Error>;
153}
154
155impl<F> ClientCreator for F
156where
157    F: Fn(String, String, String, Duration) -> Result<Arc<dyn MqttClient>, Error> + Send + Sync,
158{
159    fn create(
160        &self,
161        address: String,
162        username: String,
163        password: String,
164        keep_alive_interval: Duration,
165    ) -> Result<Arc<dyn MqttClient>, Error> {
166        self(address, username, password, keep_alive_interval)
167    }
168}