// spin_llm_remote_http/lib.rs
1use anyhow::Result;
2use futures::stream::TryStreamExt as _;
3use reqwest::Url;
4use spin_world::{
5    async_trait,
6    v2::llm::{self as wasi_llm},
7};
8
9mod default;
10mod open_ai;
11
12async fn read_body(
13    resp: reqwest::Response,
14    max_result_bytes: usize,
15) -> Result<Vec<u8>, wasi_llm::Error> {
16    let mut body = Vec::new();
17    let mut stream = resp.bytes_stream();
18    while let Some(chunk) = stream
19        .try_next()
20        .await
21        .map_err(|err| wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}")))?
22    {
23        body.extend(chunk);
24        if body.len() > max_result_bytes {
25            return Err(wasi_llm::Error::RuntimeError(format!(
26                "query result exceeds limit of {max_result_bytes} bytes"
27            )));
28        }
29    }
30    Ok(body)
31}
32
/// LLM engine that forwards inferencing and embeddings requests to a remote
/// HTTP service through a pluggable [`LlmWorker`] backend.
pub struct RemoteHttpLlmEngine {
    // Backend selected at construction time from the configured `ApiType`
    // (OpenAI-compatible vs. the default protocol).
    worker: Box<dyn LlmWorker>,
}
36
37impl RemoteHttpLlmEngine {
38    pub fn new(url: Url, auth_token: String, api_type: ApiType) -> Self {
39        let worker: Box<dyn LlmWorker> = match api_type {
40            ApiType::OpenAi => Box::new(open_ai::AgentEngine::new(auth_token, url, None)),
41            ApiType::Default => Box::new(default::AgentEngine::new(auth_token, url, None)),
42        };
43        Self { worker }
44    }
45}
46
/// Backend-agnostic interface to a remote LLM HTTP service.
///
/// Implemented by the workers in the `default` and `open_ai` modules; the
/// concrete implementation is chosen in [`RemoteHttpLlmEngine::new`] based
/// on the configured [`ApiType`].
#[async_trait]
pub trait LlmWorker: Send + Sync {
    /// Runs an inferencing request for `prompt` against `model`.
    ///
    /// `max_result_bytes` caps the size of the response body the worker may
    /// buffer (presumably enforced via `read_body` — confirm in the worker
    /// implementations, which are not visible here).
    async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
        max_result_bytes: usize,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error>;

    /// Computes embeddings for each string in `data` using `model`.
    ///
    /// `max_result_bytes` caps the response body size, as for [`Self::infer`].
    async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
        max_result_bytes: usize,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>;

    /// The remote endpoint this worker targets.
    fn url(&self) -> Url;
}
66
67impl RemoteHttpLlmEngine {
68    pub async fn infer(
69        &mut self,
70        model: wasi_llm::InferencingModel,
71        prompt: String,
72        params: wasi_llm::InferencingParams,
73        max_result_bytes: usize,
74    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
75        self.worker
76            .infer(model, prompt, params, max_result_bytes)
77            .await
78    }
79
80    pub async fn generate_embeddings(
81        &mut self,
82        model: wasi_llm::EmbeddingModel,
83        data: Vec<String>,
84        max_result_bytes: usize,
85    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
86        self.worker
87            .generate_embeddings(model, data, max_result_bytes)
88            .await
89    }
90
91    pub fn url(&self) -> Url {
92        self.worker.url()
93    }
94}
95
96#[derive(Debug, Default, serde::Deserialize, PartialEq)]
97#[serde(rename_all = "snake_case")]
98pub enum ApiType {
99    /// Compatible with OpenAI's API alongside some other LLMs
100    OpenAi,
101    #[default]
102    Default,
103}