spin_factor_outbound_mqtt/
host.rs1use std::{sync::Arc, time::Duration};
2
3use anyhow::Result;
4use spin_core::{async_trait, wasmtime::component::Resource};
5use spin_factor_outbound_networking::OutboundAllowedHosts;
6use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos};
7use tracing::{instrument, Level};
8
9use crate::ClientCreator;
10
11pub struct InstanceState {
12 allowed_hosts: OutboundAllowedHosts,
13 connections: spin_resource_table::Table<Arc<dyn MqttClient>>,
14 create_client: Arc<dyn ClientCreator>,
15}
16
17impl InstanceState {
18 pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: Arc<dyn ClientCreator>) -> Self {
19 Self {
20 allowed_hosts,
21 create_client,
22 connections: spin_resource_table::Table::new(1024),
23 }
24 }
25}
26
27#[async_trait]
28pub trait MqttClient: Send + Sync {
29 async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error>;
30}
31
32impl InstanceState {
33 async fn is_address_allowed(&self, address: &str) -> Result<bool> {
34 self.allowed_hosts.check_url(address, "mqtt").await
35 }
36
37 async fn establish_connection(
38 &mut self,
39 address: String,
40 username: String,
41 password: String,
42 keep_alive_interval: Duration,
43 ) -> Result<Resource<Connection>, Error> {
44 self.connections
45 .push((self.create_client).create(address, username, password, keep_alive_interval)?)
46 .map(Resource::new_own)
47 .map_err(|_| Error::TooManyConnections)
48 }
49
50 async fn get_conn(&self, connection: Resource<Connection>) -> Result<&dyn MqttClient, Error> {
51 self.connections
52 .get(connection.rep())
53 .ok_or(Error::Other(
54 "could not find connection for resource".into(),
55 ))
56 .map(|c| c.as_ref())
57 }
58}
59
60impl v2::Host for InstanceState {
61 fn convert_error(&mut self, error: Error) -> Result<Error> {
62 Ok(error)
63 }
64}
65
66impl v2::HostConnection for InstanceState {
67 #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
68 async fn open(
69 &mut self,
70 address: String,
71 username: String,
72 password: String,
73 keep_alive_interval: u64,
74 ) -> Result<Resource<Connection>, Error> {
75 if !self
76 .is_address_allowed(&address)
77 .await
78 .map_err(|e| v2::Error::Other(e.to_string()))?
79 {
80 return Err(v2::Error::ConnectionFailed(format!(
81 "address {address} is not permitted"
82 )));
83 }
84 self.establish_connection(
85 address,
86 username,
87 password,
88 Duration::from_secs(keep_alive_interval),
89 )
90 .await
91 }
92
93 #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
99 fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
100 messaging.system = "mqtt"))]
101 async fn publish(
102 &mut self,
103 connection: Resource<Connection>,
104 topic: String,
105 payload: Vec<u8>,
106 qos: Qos,
107 ) -> Result<(), Error> {
108 let conn = self.get_conn(connection).await.map_err(other_error)?;
109
110 conn.publish_bytes(topic, qos, payload).await?;
111
112 Ok(())
113 }
114
115 async fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
116 self.connections.remove(connection.rep());
117 Ok(())
118 }
119}
120
121pub fn other_error(e: impl std::fmt::Display) -> Error {
122 Error::Other(e.to_string())
123}