// spin_factor_llm/src/spin.rs
1use std::path::PathBuf;
2use std::sync::Arc;
3
4use spin_factors::runtime_config::toml::GetTomlValue;
5use spin_llm_remote_http::{ApiType, RemoteHttpLlmEngine};
6use spin_world::async_trait;
7use spin_world::v1::llm::{self as v1};
8use spin_world::v2::llm::{self as v2};
9use tokio::sync::Mutex;
10use url::Url;
11
12use crate::{LlmEngine, LlmEngineCreator, RuntimeConfig};
13
#[cfg(feature = "llm")]
mod local {
    use super::*;
    pub use spin_llm_local::LocalLlmEngine;

    /// Adapts the in-process local LLM engine to the factor's `LlmEngine`
    /// trait by forwarding each trait method to `LocalLlmEngine`'s inherent
    /// method of the same name. Inherent methods take precedence over trait
    /// methods in Rust method resolution, so these calls are not recursive.
    #[async_trait]
    impl LlmEngine for LocalLlmEngine {
        async fn infer(
            &mut self,
            model: v2::InferencingModel,
            prompt: String,
            params: v2::InferencingParams,
            max_result_bytes: usize,
        ) -> Result<v2::InferencingResult, v2::Error> {
            // Forwards to the inherent `LocalLlmEngine::infer`.
            self.infer(model, prompt, params, max_result_bytes).await
        }

        async fn generate_embeddings(
            &mut self,
            model: v2::EmbeddingModel,
            data: Vec<String>,
            max_result_bytes: usize,
        ) -> Result<v2::EmbeddingsResult, v2::Error> {
            // Forwards to the inherent `LocalLlmEngine::generate_embeddings`.
            self.generate_embeddings(model, data, max_result_bytes)
                .await
        }

        fn summary(&self) -> Option<String> {
            Some("local model".to_string())
        }
    }
}
46
47/// The default engine creator for the LLM factor when used in the Spin CLI.
48pub fn default_engine_creator(
49    state_dir: Option<PathBuf>,
50) -> anyhow::Result<impl LlmEngineCreator + 'static> {
51    #[cfg(feature = "llm")]
52    let engine = {
53        use anyhow::Context as _;
54        let models_dir_parent = match state_dir {
55            Some(ref dir) => dir.clone(),
56            None => std::env::current_dir().context("failed to get current working directory")?,
57        };
58        spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models"))
59    };
60    #[cfg(not(feature = "llm"))]
61    let engine = {
62        let _ = state_dir;
63        noop::NoopLlmEngine
64    };
65    let engine = Arc::new(Mutex::new(engine)) as Arc<Mutex<dyn LlmEngine>>;
66    Ok(move || engine.clone())
67}
68
69#[async_trait]
70impl LlmEngine for RemoteHttpLlmEngine {
71    async fn infer(
72        &mut self,
73        model: v1::InferencingModel,
74        prompt: String,
75        params: v2::InferencingParams,
76        max_result_bytes: usize,
77    ) -> Result<v2::InferencingResult, v2::Error> {
78        spin_telemetry::monotonic_counter!(spin.llm_infer = 1, model_name = model);
79        self.infer(model, prompt, params, max_result_bytes).await
80    }
81
82    async fn generate_embeddings(
83        &mut self,
84        model: v2::EmbeddingModel,
85        data: Vec<String>,
86        max_result_bytes: usize,
87    ) -> Result<v2::EmbeddingsResult, v2::Error> {
88        self.generate_embeddings(model, data, max_result_bytes)
89            .await
90    }
91
92    fn summary(&self) -> Option<String> {
93        Some(format!("model at {}", self.url()))
94    }
95}
96
97pub fn runtime_config_from_toml(
98    table: &impl GetTomlValue,
99    state_dir: Option<PathBuf>,
100) -> anyhow::Result<Option<RuntimeConfig>> {
101    let Some(value) = table.get("llm_compute") else {
102        return Ok(None);
103    };
104    let config: LlmCompute = value.clone().try_into()?;
105
106    Ok(Some(RuntimeConfig {
107        engine: config.into_engine(state_dir)?,
108    }))
109}
110
/// The `llm_compute` runtime-config options.
///
/// Deserialized from TOML using a snake_case `type` tag,
/// e.g. `type = "spin"` or `type = "remote_http"`.
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum LlmCompute {
    /// The engine built into Spin: local when the `llm` feature is enabled,
    /// otherwise a noop engine that errors on every operation.
    Spin,
    /// Delegate LLM operations to a remote HTTP service.
    RemoteHttp(RemoteHttpCompute),
}
117
impl LlmCompute {
    /// Converts this compute config into a shared, locked engine instance.
    fn into_engine(self, state_dir: Option<PathBuf>) -> anyhow::Result<Arc<Mutex<dyn LlmEngine>>> {
        let engine: Arc<Mutex<dyn LlmEngine>> = match self {
            // Without the `llm` feature, "spin" compute falls back to the
            // noop engine, which rejects every operation at runtime.
            #[cfg(not(feature = "llm"))]
            LlmCompute::Spin => {
                let _ = state_dir; // only the local engine needs the state dir
                Arc::new(Mutex::new(noop::NoopLlmEngine))
            }
            #[cfg(feature = "llm")]
            LlmCompute::Spin => default_engine_creator(state_dir)?.create(),
            LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
                config.url,
                config.auth_token,
                config.api_type,
            ))),
        };
        Ok(engine)
    }
}
137
/// Configuration for a remote HTTP LLM compute backend.
#[derive(Debug, serde::Deserialize)]
pub struct RemoteHttpCompute {
    /// Base URL of the remote LLM service.
    url: Url,
    /// Token handed to `RemoteHttpLlmEngine` for authenticating with the
    /// remote service.
    auth_token: String,
    /// Remote API flavor; falls back to `ApiType::default()` when omitted
    /// from the TOML.
    #[serde(default)]
    api_type: ApiType,
}
145
146/// A noop engine used when the local engine feature is disabled.
147#[cfg(not(feature = "llm"))]
148mod noop {
149    use super::*;
150
151    #[derive(Clone, Copy)]
152    pub(super) struct NoopLlmEngine;
153
154    #[async_trait]
155    impl LlmEngine for NoopLlmEngine {
156        async fn infer(
157            &mut self,
158            _model: v2::InferencingModel,
159            _prompt: String,
160            _params: v2::InferencingParams,
161            _max_result_bytes: usize,
162        ) -> Result<v2::InferencingResult, v2::Error> {
163            Err(v2::Error::RuntimeError(
164                "Local LLM operations are not supported in this version of Spin.".into(),
165            ))
166        }
167
168        async fn generate_embeddings(
169            &mut self,
170            _model: v2::EmbeddingModel,
171            _data: Vec<String>,
172            _max_result_bytes: usize,
173        ) -> Result<v2::EmbeddingsResult, v2::Error> {
174            Err(v2::Error::RuntimeError(
175                "Local LLM operations are not supported in this version of Spin.".into(),
176            ))
177        }
178
179        fn summary(&self) -> Option<String> {
180            Some("noop model".to_owned())
181        }
182    }
183}