spin_llm_remote_http/
lib.rs

1use anyhow::Result;
2use reqwest::Url;
3use spin_world::{
4    async_trait,
5    v2::llm::{self as wasi_llm},
6};
7
8mod default;
9mod open_ai;
10
11pub struct RemoteHttpLlmEngine {
12    worker: Box<dyn LlmWorker>,
13}
14
15impl RemoteHttpLlmEngine {
16    pub fn new(url: Url, auth_token: String, api_type: ApiType) -> Self {
17        let worker: Box<dyn LlmWorker> = match api_type {
18            ApiType::OpenAi => Box::new(open_ai::AgentEngine::new(auth_token, url, None)),
19            ApiType::Default => Box::new(default::AgentEngine::new(auth_token, url, None)),
20        };
21        Self { worker }
22    }
23}
24
25#[async_trait]
26pub trait LlmWorker: Send + Sync {
27    async fn infer(
28        &mut self,
29        model: wasi_llm::InferencingModel,
30        prompt: String,
31        params: wasi_llm::InferencingParams,
32    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error>;
33
34    async fn generate_embeddings(
35        &mut self,
36        model: wasi_llm::EmbeddingModel,
37        data: Vec<String>,
38    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>;
39
40    fn url(&self) -> Url;
41}
42
43impl RemoteHttpLlmEngine {
44    pub async fn infer(
45        &mut self,
46        model: wasi_llm::InferencingModel,
47        prompt: String,
48        params: wasi_llm::InferencingParams,
49    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
50        self.worker.infer(model, prompt, params).await
51    }
52
53    pub async fn generate_embeddings(
54        &mut self,
55        model: wasi_llm::EmbeddingModel,
56        data: Vec<String>,
57    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
58        self.worker.generate_embeddings(model, data).await
59    }
60
61    pub fn url(&self) -> Url {
62        self.worker.url()
63    }
64}
65
66#[derive(Debug, Default, serde::Deserialize, PartialEq)]
67#[serde(rename_all = "snake_case")]
68pub enum ApiType {
69    /// Compatible with OpenAI's API alongside some other LLMs
70    OpenAi,
71    #[default]
72    Default,
73}