// spin_factor_llm — host.rs
use spin_world::v1::llm::{self as v1};
use spin_world::v2::llm::{self as v2};
use spin_world::MAX_HOST_BUFFERED_BYTES;
use tracing::field::Empty;
use tracing::{instrument, Level};

use crate::InstanceState;
8
9impl v2::Host for InstanceState {
10    #[instrument(name = "spin_llm.infer", skip(self, prompt), err(level = Level::INFO), fields(otel.kind = "client", llm.backend = Empty))]
11    async fn infer(
12        &mut self,
13        model: v2::InferencingModel,
14        prompt: String,
15        params: Option<v2::InferencingParams>,
16    ) -> Result<v2::InferencingResult, v2::Error> {
17        self.otel.reparent_tracing_span();
18
19        if !self.allowed_models.contains(&model) {
20            return Err(access_denied_error(&model));
21        }
22        let mut engine = self.engine.lock().await;
23        tracing::Span::current().record("llm.backend", engine.summary());
24        engine
25            .infer(
26                model,
27                prompt,
28                params.unwrap_or(v2::InferencingParams {
29                    max_tokens: 100,
30                    repeat_penalty: 1.1,
31                    repeat_penalty_last_n_token_count: 64,
32                    temperature: 0.8,
33                    top_k: 40,
34                    top_p: 0.9,
35                }),
36                MAX_HOST_BUFFERED_BYTES,
37            )
38            .await
39    }
40
41    #[instrument(name = "spin_llm.generate_embeddings", skip(self, data), err(level = Level::INFO), fields(otel.kind = "client", llm.backend = Empty))]
42    async fn generate_embeddings(
43        &mut self,
44        model: v1::EmbeddingModel,
45        data: Vec<String>,
46    ) -> Result<v2::EmbeddingsResult, v2::Error> {
47        self.otel.reparent_tracing_span();
48
49        if !self.allowed_models.contains(&model) {
50            return Err(access_denied_error(&model));
51        }
52        let mut engine = self.engine.lock().await;
53        tracing::Span::current().record("llm.backend", engine.summary());
54        engine
55            .generate_embeddings(model, data, MAX_HOST_BUFFERED_BYTES)
56            .await
57    }
58
59    fn convert_error(&mut self, error: v2::Error) -> anyhow::Result<v2::Error> {
60        Ok(error)
61    }
62}
63
64impl v1::Host for InstanceState {
65    async fn infer(
66        &mut self,
67        model: v1::InferencingModel,
68        prompt: String,
69        params: Option<v1::InferencingParams>,
70    ) -> Result<v1::InferencingResult, v1::Error> {
71        <Self as v2::Host>::infer(self, model, prompt, params.map(Into::into))
72            .await
73            .map(Into::into)
74            .map_err(Into::into)
75    }
76
77    async fn generate_embeddings(
78        &mut self,
79        model: v1::EmbeddingModel,
80        data: Vec<String>,
81    ) -> Result<v1::EmbeddingsResult, v1::Error> {
82        <Self as v2::Host>::generate_embeddings(self, model, data)
83            .await
84            .map(Into::into)
85            .map_err(Into::into)
86    }
87
88    fn convert_error(&mut self, error: v1::Error) -> anyhow::Result<v1::Error> {
89        Ok(error)
90    }
91}
92
93fn access_denied_error(model: &str) -> v2::Error {
94    v2::Error::InvalidInput(format!(
95        "The component does not have access to use '{model}'. To give the component access, add '{model}' to the 'ai_models' key for the component in your spin.toml manifest"
96    ))
97}