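//! `spin_trigger_redis` (lib.rs): a Spin trigger that subscribes to Redis
//! pub/sub channels and invokes the configured components for each message.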

use std::{collections::HashMap, sync::Arc};

use anyhow::Context;
use futures::{StreamExt, TryFutureExt};
use redis::{Client, Msg};
use serde::Deserialize;
use spin_factor_variables::VariablesFactor;
use spin_factors::RuntimeFactors;
use spin_trigger::{App, Trigger, TriggerApp, cli::NoCliArgs};
use spin_world::exports::fermyon::spin::inbound_redis as v1;
use spin_world::exports::spin::redis::inbound_redis as v3;
use tracing::{Level, instrument};

/// The Redis trigger: runs Spin components in response to messages published
/// on the Redis channels they subscribe to.
pub struct RedisTrigger;

16/// Redis trigger metadata.
17#[derive(Clone, Debug, Default, Deserialize)]
18#[serde(deny_unknown_fields)]
19struct TriggerMetadata {
20    address: String,
21}
22
23/// Redis trigger configuration.
24#[derive(Clone, Debug, Default, Deserialize)]
25#[serde(deny_unknown_fields)]
26struct TriggerConfig {
27    /// Component ID to invoke
28    component: String,
29    /// Channel to subscribe to
30    channel: String,
31    /// Optionally override address for trigger
32    address: Option<String>,
33}
34
35impl<F: RuntimeFactors> Trigger<F> for RedisTrigger {
36    const TYPE: &'static str = "redis";
37
38    type CliArgs = NoCliArgs;
39
40    type InstanceState = ();
41
42    fn new(_cli_args: Self::CliArgs, _app: &App) -> anyhow::Result<Self> {
43        Ok(Self)
44    }
45
46    async fn run(self, trigger_app: spin_trigger::TriggerApp<Self, F>) -> anyhow::Result<()> {
47        let app_variables = trigger_app
48            .configured_app()
49            .app_state::<VariablesFactor>()
50            .context("RedisTrigger depends on VariablesFactor")?;
51
52        let app = trigger_app.app();
53        let trigger_type = <Self as Trigger<F>>::TYPE;
54        let metadata = app
55            .get_trigger_metadata::<TriggerMetadata>(trigger_type)?
56            .unwrap_or_default();
57        let default_address_expr = &metadata.address;
58        let default_address = app_variables
59            .resolve_expression(default_address_expr.clone())
60            .await
61            .with_context(|| {
62                format!("failed to resolve redis trigger default address {default_address_expr:?}")
63            })?;
64
65        // Maps <server address> -> <channel> -> <component IDs>
66        let mut server_channel_components: HashMap<String, ChannelComponents> = HashMap::new();
67
68        // Resolve trigger configs before starting any subscribers
69        for (_, config) in app
70            .trigger_configs::<TriggerConfig>(trigger_type)?
71            .into_iter()
72            .collect::<Vec<_>>()
73        {
74            let component_id = config.component;
75
76            let address_expr = config.address.as_ref().unwrap_or(&default_address);
77            let address = app_variables
78                .resolve_expression(address_expr.clone())
79                .await
80                .with_context(|| {
81                    format!(
82                        "failed to resolve redis trigger address {address_expr:?} for component {component_id}"
83                    )
84                })?;
85
86            let channel_expr = &config.channel;
87            let channel = app_variables
88                .resolve_expression(channel_expr.clone())
89                .await
90                .with_context(|| {
91                    format!(
92                        "failed to resolve redis trigger channel {channel_expr:?} for component {component_id}"
93                    )
94                })?;
95
96            server_channel_components
97                .entry(address)
98                .or_default()
99                .entry(channel)
100                .or_default()
101                .push(component_id);
102        }
103
104        // Start subscriber(s)
105        let trigger_app = Arc::new(trigger_app);
106        let mut subscriber_tasks = Vec::new();
107        for (address, channel_components) in server_channel_components {
108            let subscriber = Subscriber::new(address, trigger_app.clone(), channel_components)?;
109            let task = tokio::spawn(subscriber.run_listener());
110            subscriber_tasks.push(task);
111        }
112
113        // Wait for any task to complete
114        let (res, _, _) = futures::future::select_all(subscriber_tasks).await;
115        res?
116    }
117}
118
/// Channel name -> IDs of the components subscribed to that channel.
type ChannelComponents = HashMap<String, Vec<String>>;

122/// Subscribes to channels from a single Redis server.
123struct Subscriber<F: RuntimeFactors> {
124    client: Client,
125    trigger_app: Arc<TriggerApp<RedisTrigger, F>>,
126    channel_components: ChannelComponents,
127}
128
129impl<F: RuntimeFactors> Subscriber<F> {
130    fn new(
131        address: String,
132        trigger_app: Arc<TriggerApp<RedisTrigger, F>>,
133        channel_components: ChannelComponents,
134    ) -> anyhow::Result<Self> {
135        let client = Client::open(address)?;
136        Ok(Self {
137            client,
138            trigger_app,
139            channel_components,
140        })
141    }
142
143    async fn run_listener(self) -> anyhow::Result<()> {
144        let server_addr = &self.client.get_connection_info().addr;
145
146        tracing::info!("Connecting to Redis server at {server_addr}");
147        let mut pubsub = self
148            .client
149            .get_async_pubsub()
150            .await
151            .with_context(|| format!("Redis trigger failed to connect to {server_addr}"))?;
152
153        println!("Active Channels on {server_addr}:");
154
155        // Subscribe to channels
156        for (channel, components) in &self.channel_components {
157            tracing::info!("Subscribing to {channel:?} on {server_addr}");
158            pubsub.subscribe(channel).await.with_context(|| {
159                format!("Redis trigger failed to subscribe to channel {channel:?} on {server_addr}")
160            })?;
161            println!("\t{server_addr}/{channel}: [{}]", components.join(","));
162        }
163
164        let mut message_stream = pubsub.on_message();
165        while let Some(msg) = message_stream.next().await {
166            if let Err(err) = self.handle_message(msg).await {
167                tracing::error!("Error handling message from {server_addr}: {err}");
168            }
169        }
170        Err(anyhow::anyhow!("disconnected from {server_addr}"))
171    }
172
173    #[instrument(name = "spin_trigger_redis.handle_message", skip_all, err(level = Level::INFO), fields(
174        otel.name = format!("{} receive", msg.get_channel_name()),
175        otel.kind = "consumer",
176        messaging.operation = "receive",
177        messaging.system = "redis"
178    ))]
179    async fn handle_message(&self, msg: Msg) -> anyhow::Result<()> {
180        let server_addr = &self.client.get_connection_info().addr;
181        let channel = msg.get_channel_name();
182        tracing::trace!(%server_addr, %channel, "Received message");
183
184        let Some(component_ids) = self.channel_components.get(channel) else {
185            anyhow::bail!("message from unexpected channel {channel:?}");
186        };
187
188        let dispatch_futures = component_ids.iter().map(|component_id| {
189            tracing::trace!("Executing Redis component {component_id}");
190            self.dispatch_handler(&msg, component_id)
191                .inspect_err(move |err| {
192                    tracing::info!("Component {component_id} handler failed: {err}");
193                })
194        });
195        futures::future::join_all(dispatch_futures).await;
196
197        Ok(())
198    }
199
200    async fn dispatch_handler(&self, msg: &Msg, component_id: &str) -> anyhow::Result<()> {
201        spin_telemetry::metrics::monotonic_counter!(
202            spin.request_count = 1,
203            trigger_type = "redis",
204            app_id = self.trigger_app.app().id(),
205            component_id = component_id
206        );
207
208        let (instance, mut store) = self
209            .trigger_app
210            .prepare(component_id)?
211            .instantiate(())
212            .await?;
213
214        let pre = instance.instance_pre(&store);
215
216        match HandlerType::from_instance_pre(&pre)? {
217            HandlerType::V1(guest_indices) => {
218                let guest = guest_indices.load(&mut store, &instance)?;
219
220                let payload = msg.get_payload_bytes().to_vec();
221
222                guest
223                    .call_handle_message(&mut store, &payload)
224                    .await?
225                    .context("Redis handler returned an error")
226            }
227            HandlerType::V3(guest_indices) => {
228                let guest = guest_indices.load(&mut store, &instance)?;
229
230                let payload = msg.get_payload_bytes().to_vec();
231                let res = std::pin::pin!(store.as_mut().run_concurrent(async |accessor| {
232                    guest.call_handle_message(accessor, payload).await
233                }))
234                .await;
235
236                res.map_err(|e| anyhow::anyhow!("{e}"))
237                    .context("Redis handler returned an error (run_concurrent)")?
238                    .map_err(|e| anyhow::anyhow!("{e}"))
239                    .context("Redis handler returned an error")?
240                    .context("Redis handler returned an error")
241            }
242        }
243    }
244}
245
246/// The type of Redis handler export used by a component.
247pub enum HandlerType /*<S: HandlerState>*/ {
248    V1(v1::GuestIndices),
249    V3(v3::GuestIndices),
250}
251
252impl HandlerType {
253    /// Determine the handler type from the exports of a component.
254    pub fn from_instance_pre<T: 'static>(
255        pre: &spin_factors::wasmtime::component::InstancePre<T>, /*, handler_state: S*/
256    ) -> anyhow::Result<Self> {
257        let mut candidates = Vec::new();
258        if let Ok(indices) = v1::GuestIndices::new(pre) {
259            candidates.push(HandlerType::V1(indices));
260        }
261        if let Ok(indices) = v3::GuestIndices::new(pre) {
262            candidates.push(HandlerType::V3(indices));
263        }
264
265        match candidates.len() {
266            0 => anyhow::bail!("component does not export a Redis interface"),
267            1 => Ok(candidates.pop().unwrap()),
268            _ => anyhow::bail!("component exports multiple Redis interfaces"),
269        }
270    }
271}