spin_factor_outbound_mqtt/
lib.rs1mod host;
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use host::other_error;
7use host::InstanceState;
8use rumqttc::{AsyncClient, Event, Incoming, Outgoing, QoS};
9use spin_core::async_trait;
10use spin_factor_outbound_networking::OutboundNetworkingFactor;
11use spin_factors::{
12 ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
13};
14use spin_world::v2::mqtt::{self as v2, Error, Qos};
15use tokio::sync::Mutex;
16
17pub use host::MqttClient;
18
19pub struct OutboundMqttFactor {
20 create_client: Arc<dyn ClientCreator>,
21}
22
23impl OutboundMqttFactor {
24 pub fn new(create_client: Arc<dyn ClientCreator>) -> Self {
25 Self { create_client }
26 }
27}
28
29impl Factor for OutboundMqttFactor {
30 type RuntimeConfig = ();
31 type AppState = ();
32 type InstanceBuilder = InstanceState;
33
34 fn init<T: Send + 'static>(
35 &mut self,
36 mut ctx: spin_factors::InitContext<T, Self>,
37 ) -> anyhow::Result<()> {
38 ctx.link_bindings(spin_world::v2::mqtt::add_to_linker)?;
39 Ok(())
40 }
41
42 fn configure_app<T: RuntimeFactors>(
43 &self,
44 _ctx: ConfigureAppContext<T, Self>,
45 ) -> anyhow::Result<Self::AppState> {
46 Ok(())
47 }
48
49 fn prepare<T: RuntimeFactors>(
50 &self,
51 mut ctx: PrepareContext<T, Self>,
52 ) -> anyhow::Result<Self::InstanceBuilder> {
53 let allowed_hosts = ctx
54 .instance_builder::<OutboundNetworkingFactor>()?
55 .allowed_hosts();
56 Ok(InstanceState::new(
57 allowed_hosts,
58 self.create_client.clone(),
59 ))
60 }
61}
62
63impl SelfInstanceBuilder for InstanceState {}
64
65pub struct NetworkedMqttClient {
67 inner: rumqttc::AsyncClient,
68 event_loop: Mutex<rumqttc::EventLoop>,
69}
70
71const MQTT_CHANNEL_CAP: usize = 1000;
72
73impl NetworkedMqttClient {
74 pub fn creator() -> Arc<dyn ClientCreator> {
76 Arc::new(|address, username, password, keep_alive_interval| {
77 Ok(Arc::new(NetworkedMqttClient::create(
78 address,
79 username,
80 password,
81 keep_alive_interval,
82 )?) as _)
83 })
84 }
85
86 pub fn create(
88 address: String,
89 username: String,
90 password: String,
91 keep_alive_interval: Duration,
92 ) -> Result<Self, Error> {
93 let mut conn_opts = rumqttc::MqttOptions::parse_url(address).map_err(|e| {
94 tracing::error!("MQTT URL parse error: {e:?}");
95 Error::InvalidAddress
96 })?;
97 conn_opts.set_credentials(username, password);
98 conn_opts.set_keep_alive(keep_alive_interval);
99 let (client, event_loop) = AsyncClient::new(conn_opts, MQTT_CHANNEL_CAP);
100 Ok(Self {
101 inner: client,
102 event_loop: Mutex::new(event_loop),
103 })
104 }
105}
106
107#[async_trait]
108impl MqttClient for NetworkedMqttClient {
109 async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error> {
110 let qos = match qos {
111 Qos::AtMostOnce => rumqttc::QoS::AtMostOnce,
112 Qos::AtLeastOnce => rumqttc::QoS::AtLeastOnce,
113 Qos::ExactlyOnce => rumqttc::QoS::ExactlyOnce,
114 };
115 self.inner
117 .publish_bytes(topic, qos, false, payload.into())
118 .await
119 .map_err(other_error)?;
120
121 let mut lock = self.event_loop.lock().await;
124 loop {
125 let event = lock
126 .poll()
127 .await
128 .map_err(|err| v2::Error::ConnectionFailed(err.to_string()))?;
129
130 match (qos, event) {
131 (QoS::AtMostOnce, Event::Outgoing(Outgoing::Publish(_)))
132 | (QoS::AtLeastOnce, Event::Incoming(Incoming::PubAck(_)))
133 | (QoS::ExactlyOnce, Event::Incoming(Incoming::PubComp(_))) => break,
134
135 (_, _) => continue,
136 }
137 }
138 Ok(())
139 }
140}
141
142#[async_trait]
144pub trait ClientCreator: Send + Sync {
145 fn create(
146 &self,
147 address: String,
148 username: String,
149 password: String,
150 keep_alive_interval: Duration,
151 ) -> Result<Arc<dyn MqttClient>, Error>;
152}
153
154impl<F> ClientCreator for F
155where
156 F: Fn(String, String, String, Duration) -> Result<Arc<dyn MqttClient>, Error> + Send + Sync,
157{
158 fn create(
159 &self,
160 address: String,
161 username: String,
162 password: String,
163 keep_alive_interval: Duration,
164 ) -> Result<Arc<dyn MqttClient>, Error> {
165 self(address, username, password, keep_alive_interval)
166 }
167}