spin_trigger_redis/
lib.rs

1use std::{collections::HashMap, sync::Arc};
2
3use anyhow::Context;
4use futures::{StreamExt, TryFutureExt};
5use redis::{Client, Msg};
6use serde::Deserialize;
7use spin_factor_variables::VariablesFactor;
8use spin_factors::RuntimeFactors;
9use spin_trigger::{cli::NoCliArgs, App, Trigger, TriggerApp};
10use spin_world::exports::fermyon::spin::inbound_redis;
11use tracing::{instrument, Level};
12
/// The Redis pub/sub trigger.
///
/// Stateless marker type; all behavior lives in its `Trigger` implementation,
/// which subscribes to the configured Redis channels and invokes components.
pub struct RedisTrigger;
14
15/// Redis trigger metadata.
16#[derive(Clone, Debug, Default, Deserialize)]
17#[serde(deny_unknown_fields)]
18struct TriggerMetadata {
19    address: String,
20}
21
22/// Redis trigger configuration.
23#[derive(Clone, Debug, Default, Deserialize)]
24#[serde(deny_unknown_fields)]
25struct TriggerConfig {
26    /// Component ID to invoke
27    component: String,
28    /// Channel to subscribe to
29    channel: String,
30    /// Optionally override address for trigger
31    address: Option<String>,
32}
33
34impl<F: RuntimeFactors> Trigger<F> for RedisTrigger {
35    const TYPE: &'static str = "redis";
36
37    type CliArgs = NoCliArgs;
38
39    type InstanceState = ();
40
41    fn new(_cli_args: Self::CliArgs, _app: &App) -> anyhow::Result<Self> {
42        Ok(Self)
43    }
44
45    async fn run(self, trigger_app: spin_trigger::TriggerApp<Self, F>) -> anyhow::Result<()> {
46        let app_variables = trigger_app
47            .configured_app()
48            .app_state::<VariablesFactor>()
49            .context("RedisTrigger depends on VariablesFactor")?;
50
51        let app = trigger_app.app();
52        let trigger_type = <Self as Trigger<F>>::TYPE;
53        let metadata = app
54            .get_trigger_metadata::<TriggerMetadata>(trigger_type)?
55            .unwrap_or_default();
56        let default_address_expr = &metadata.address;
57        let default_address = app_variables
58            .resolve_expression(default_address_expr.clone())
59            .await
60            .with_context(|| {
61                format!("failed to resolve redis trigger default address {default_address_expr:?}")
62            })?;
63
64        // Maps <server address> -> <channel> -> <component IDs>
65        let mut server_channel_components: HashMap<String, ChannelComponents> = HashMap::new();
66
67        // Resolve trigger configs before starting any subscribers
68        for (_, config) in app
69            .trigger_configs::<TriggerConfig>(trigger_type)?
70            .into_iter()
71            .collect::<Vec<_>>()
72        {
73            let component_id = config.component;
74
75            let address_expr = config.address.as_ref().unwrap_or(&default_address);
76            let address = app_variables
77                .resolve_expression(address_expr.clone())
78                .await
79                .with_context(|| {
80                    format!(
81                        "failed to resolve redis trigger address {address_expr:?} for component {component_id}"
82                    )
83                })?;
84
85            let channel_expr = &config.channel;
86            let channel = app_variables
87                .resolve_expression(channel_expr.clone())
88                .await
89                .with_context(|| {
90                    format!(
91                        "failed to resolve redis trigger channel {channel_expr:?} for component {component_id}"
92                    )
93                })?;
94
95            server_channel_components
96                .entry(address)
97                .or_default()
98                .entry(channel)
99                .or_default()
100                .push(component_id);
101        }
102
103        // Start subscriber(s)
104        let trigger_app = Arc::new(trigger_app);
105        let mut subscriber_tasks = Vec::new();
106        for (address, channel_components) in server_channel_components {
107            let subscriber = Subscriber::new(address, trigger_app.clone(), channel_components)?;
108            let task = tokio::spawn(subscriber.run_listener());
109            subscriber_tasks.push(task);
110        }
111
112        // Wait for any task to complete
113        let (res, _, _) = futures::future::select_all(subscriber_tasks).await;
114        res?
115    }
116}
117
/// Routing table for one Redis server: channel name -> IDs of the components
/// to invoke for each message published on that channel.
type ChannelComponents = HashMap<String, Vec<String>>;
120
121/// Subscribes to channels from a single Redis server.
122struct Subscriber<F: RuntimeFactors> {
123    client: Client,
124    trigger_app: Arc<TriggerApp<RedisTrigger, F>>,
125    channel_components: ChannelComponents,
126}
127
128impl<F: RuntimeFactors> Subscriber<F> {
129    fn new(
130        address: String,
131        trigger_app: Arc<TriggerApp<RedisTrigger, F>>,
132        channel_components: ChannelComponents,
133    ) -> anyhow::Result<Self> {
134        let client = Client::open(address)?;
135        Ok(Self {
136            client,
137            trigger_app,
138            channel_components,
139        })
140    }
141
142    async fn run_listener(self) -> anyhow::Result<()> {
143        let server_addr = &self.client.get_connection_info().addr;
144
145        tracing::info!("Connecting to Redis server at {server_addr}");
146        let mut pubsub = self
147            .client
148            .get_async_pubsub()
149            .await
150            .with_context(|| format!("Redis trigger failed to connect to {server_addr}"))?;
151
152        println!("Active Channels on {server_addr}:");
153
154        // Subscribe to channels
155        for (channel, components) in &self.channel_components {
156            tracing::info!("Subscribing to {channel:?} on {server_addr}");
157            pubsub.subscribe(channel).await.with_context(|| {
158                format!("Redis trigger failed to subscribe to channel {channel:?} on {server_addr}")
159            })?;
160            println!("\t{server_addr}/{channel}: [{}]", components.join(","));
161        }
162
163        let mut message_stream = pubsub.on_message();
164        while let Some(msg) = message_stream.next().await {
165            if let Err(err) = self.handle_message(msg).await {
166                tracing::error!("Error handling message from {server_addr}: {err}");
167            }
168        }
169        Err(anyhow::anyhow!("disconnected from {server_addr}"))
170    }
171
172    #[instrument(name = "spin_trigger_redis.handle_message", skip_all, err(level = Level::INFO), fields(
173        otel.name = format!("{} receive", msg.get_channel_name()),
174        otel.kind = "consumer",
175        messaging.operation = "receive",
176        messaging.system = "redis"
177    ))]
178    async fn handle_message(&self, msg: Msg) -> anyhow::Result<()> {
179        let server_addr = &self.client.get_connection_info().addr;
180        let channel = msg.get_channel_name();
181        tracing::trace!(%server_addr, %channel, "Received message");
182
183        let Some(component_ids) = self.channel_components.get(channel) else {
184            anyhow::bail!("message from unexpected channel {channel:?}");
185        };
186
187        let dispatch_futures = component_ids.iter().map(|component_id| {
188            tracing::trace!("Executing Redis component {component_id}");
189            self.dispatch_handler(&msg, component_id)
190                .inspect_err(move |err| {
191                    tracing::info!("Component {component_id} handler failed: {err}");
192                })
193        });
194        futures::future::join_all(dispatch_futures).await;
195
196        Ok(())
197    }
198
199    async fn dispatch_handler(&self, msg: &Msg, component_id: &str) -> anyhow::Result<()> {
200        spin_telemetry::metrics::monotonic_counter!(
201            spin.request_count = 1,
202            trigger_type = "redis",
203            app_id = self.trigger_app.app().id(),
204            component_id = component_id
205        );
206
207        let (instance, mut store) = self
208            .trigger_app
209            .prepare(component_id)?
210            .instantiate(())
211            .await?;
212
213        let guest_indices = inbound_redis::GuestIndices::new_instance(&mut store, &instance)?;
214        let guest = guest_indices.load(&mut store, &instance)?;
215
216        let payload = msg.get_payload_bytes().to_vec();
217
218        guest
219            .call_handle_message(&mut store, &payload)
220            .await?
221            .context("Redis handler returned an error")
222    }
223}