spin_llm_remote_http/
lib.rs1use anyhow::Result;
2use futures::stream::TryStreamExt as _;
3use reqwest::Url;
4use spin_world::{
5 async_trait,
6 v2::llm::{self as wasi_llm},
7};
8
9mod default;
10mod open_ai;
11
12async fn read_body(
13 resp: reqwest::Response,
14 max_result_bytes: usize,
15) -> Result<Vec<u8>, wasi_llm::Error> {
16 let mut body = Vec::new();
17 let mut stream = resp.bytes_stream();
18 while let Some(chunk) = stream
19 .try_next()
20 .await
21 .map_err(|err| wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}")))?
22 {
23 body.extend(chunk);
24 if body.len() > max_result_bytes {
25 return Err(wasi_llm::Error::RuntimeError(format!(
26 "query result exceeds limit of {max_result_bytes} bytes"
27 )));
28 }
29 }
30 Ok(body)
31}
32
/// An LLM engine that forwards inferencing and embedding requests to a
/// remote HTTP service, dispatching through an API-flavor-specific worker.
pub struct RemoteHttpLlmEngine {
    // Boxed trait object so the concrete protocol implementation
    // (`open_ai` or `default`) is selected at runtime from `ApiType`.
    worker: Box<dyn LlmWorker>,
}
36
37impl RemoteHttpLlmEngine {
38 pub fn new(url: Url, auth_token: String, api_type: ApiType) -> Self {
39 let worker: Box<dyn LlmWorker> = match api_type {
40 ApiType::OpenAi => Box::new(open_ai::AgentEngine::new(auth_token, url, None)),
41 ApiType::Default => Box::new(default::AgentEngine::new(auth_token, url, None)),
42 };
43 Self { worker }
44 }
45}
46
/// Abstraction over a remote LLM HTTP API flavor.
///
/// Implementations (see the `default` and `open_ai` modules) translate the
/// WASI LLM interface calls into requests against a specific remote API.
#[async_trait]
pub trait LlmWorker: Send + Sync {
    /// Runs an inferencing request for `model` with the given `prompt` and
    /// `params`. `max_result_bytes` is the maximum response-body size the
    /// worker should accept (enforced by `read_body`).
    async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
        max_result_bytes: usize,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error>;

    /// Generates embeddings for each string in `data` using `model`.
    /// `max_result_bytes` is the maximum response-body size the worker
    /// should accept (enforced by `read_body`).
    async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
        max_result_bytes: usize,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>;

    /// The remote endpoint URL this worker targets.
    fn url(&self) -> Url;
}
66
impl RemoteHttpLlmEngine {
    /// Runs an inferencing request by delegating to the configured worker.
    pub async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
        max_result_bytes: usize,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
        self.worker
            .infer(model, prompt, params, max_result_bytes)
            .await
    }

    /// Generates embeddings by delegating to the configured worker.
    pub async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
        max_result_bytes: usize,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
        self.worker
            .generate_embeddings(model, data, max_result_bytes)
            .await
    }

    /// The remote endpoint URL this engine sends requests to.
    pub fn url(&self) -> Url {
        self.worker.url()
    }
}
95
96#[derive(Debug, Default, serde::Deserialize, PartialEq)]
97#[serde(rename_all = "snake_case")]
98pub enum ApiType {
99 OpenAi,
101 #[default]
102 Default,
103}