spin_factor_llm/host.rs
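//! Host-side implementations of the v1 and v2 `llm` interfaces for Spin
//! components: gate model access, then delegate to the configured engine.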

use spin_world::v1::llm::{self as v1};
use spin_world::v2::llm::{self as v2};
use tracing::field::Empty;
use tracing::{instrument, Level};

use crate::InstanceState;

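/// The v2 interface talks to the engine directly; the v1 interface further
/// down delegates to this implementation.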
impl v2::Host for InstanceState {
    #[instrument(name = "spin_llm.infer", skip(self, prompt), err(level = Level::INFO), fields(otel.kind = "client", llm.backend = Empty))]
    async fn infer(
        &mut self,
        model: v2::InferencingModel,
        prompt: String,
        params: Option<v2::InferencingParams>,
    ) -> Result<v2::InferencingResult, v2::Error> {
        if !self.allowed_models.contains(&model) {
            return Err(access_denied_error(&model));
        }
        let mut engine = self.engine.lock().await;
        tracing::Span::current().record("llm.backend", engine.summary());
        engine
            .infer(
                model,
                prompt,
                params.unwrap_or(v2::InferencingParams {
                    max_tokens: 100,
                    repeat_penalty: 1.1,
                    repeat_penalty_last_n_token_count: 64,
                    temperature: 0.8,
                    top_k: 40,
                    top_p: 0.9,
                }),
            )
            .await
    }

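    /// As with `infer`, embedding generation is gated on the component's
    /// model allow list, and the backend is recorded on the tracing span.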
    #[instrument(name = "spin_llm.generate_embeddings", skip(self, data), err(level = Level::INFO), fields(otel.kind = "client", llm.backend = Empty))]
    async fn generate_embeddings(
        &mut self,
        model: v1::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<v2::EmbeddingsResult, v2::Error> {
        if !self.allowed_models.contains(&model) {
            return Err(access_denied_error(&model));
        }
        let mut engine = self.engine.lock().await;
        tracing::Span::current().record("llm.backend", engine.summary());
        engine.generate_embeddings(model, data).await
    }

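    /// Pass errors through unchanged so they reach the guest as `llm::Error`
    /// values instead of trapping the instance.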
    fn convert_error(&mut self, error: v2::Error) -> anyhow::Result<v2::Error> {
        Ok(error)
    }
}

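/// The v1 interface is implemented by delegating to v2, converting
/// parameters, results, and errors through their `Into` conversions.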
impl v1::Host for InstanceState {
    async fn infer(
        &mut self,
        model: v1::InferencingModel,
        prompt: String,
        params: Option<v1::InferencingParams>,
    ) -> Result<v1::InferencingResult, v1::Error> {
        <Self as v2::Host>::infer(self, model, prompt, params.map(Into::into))
            .await
            .map(Into::into)
            .map_err(Into::into)
    }

    async fn generate_embeddings(
        &mut self,
        model: v1::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<v1::EmbeddingsResult, v1::Error> {
        <Self as v2::Host>::generate_embeddings(self, model, data)
            .await
            .map(Into::into)
            .map_err(Into::into)
    }

    fn convert_error(&mut self, error: v1::Error) -> anyhow::Result<v1::Error> {
        Ok(error)
    }
}

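/// Builds the guest-visible error for a model missing from the allow list.
/// A minimal sketch of granting access in `spin.toml` (the component and
/// model names here are illustrative, not taken from this crate):
///
/// ```toml
/// [component.my-component]
/// ai_models = ["llama2-chat"]
/// ```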
fn access_denied_error(model: &str) -> v2::Error {
    v2::Error::InvalidInput(format!(
        "The component does not have access to use '{model}'. To give the component access, add '{model}' to the 'ai_models' key for the component in your spin.toml manifest"
    ))
}