//! spin_llm_remote_http/lib.rs

use anyhow::Result;
2use reqwest::Url;
3use spin_world::{
4 async_trait,
5 v2::llm::{self as wasi_llm},
6};
7
8mod default;
9mod open_ai;
10
/// LLM engine that fulfils inferencing and embedding requests by calling
/// a remote HTTP service instead of running a model locally.
///
/// The concrete wire protocol is hidden behind the boxed [`LlmWorker`],
/// selected at construction time from an [`ApiType`].
pub struct RemoteHttpLlmEngine {
    // Dynamic dispatch so one engine type can back multiple API dialects.
    worker: Box<dyn LlmWorker>,
}
14
15impl RemoteHttpLlmEngine {
16 pub fn new(url: Url, auth_token: String, api_type: ApiType) -> Self {
17 let worker: Box<dyn LlmWorker> = match api_type {
18 ApiType::OpenAi => Box::new(open_ai::AgentEngine::new(auth_token, url, None)),
19 ApiType::Default => Box::new(default::AgentEngine::new(auth_token, url, None)),
20 };
21 Self { worker }
22 }
23}
24
/// Abstraction over a remote LLM backend protocol.
///
/// Implementations live in the `default` and `open_ai` modules; the engine
/// holds one as `Box<dyn LlmWorker>`, hence the `Send + Sync` bound.
#[async_trait]
pub trait LlmWorker: Send + Sync {
    /// Performs an inferencing (completion) request against the remote
    /// service and returns the generated result, or a `wasi_llm::Error`
    /// describing the failure.
    async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error>;

    /// Computes embeddings for each string in `data` using `model` on the
    /// remote service.
    async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>;

    /// The remote endpoint this worker sends requests to.
    fn url(&self) -> Url;
}
42
// Public API of the engine: thin forwarders that delegate every call to the
// protocol-specific worker chosen at construction time.
impl RemoteHttpLlmEngine {
    /// Runs an inferencing request; see [`LlmWorker::infer`].
    pub async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
        self.worker.infer(model, prompt, params).await
    }

    /// Generates embeddings; see [`LlmWorker::generate_embeddings`].
    pub async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
        self.worker.generate_embeddings(model, data).await
    }

    /// The remote endpoint the underlying worker targets.
    pub fn url(&self) -> Url {
        self.worker.url()
    }
}
65
66#[derive(Debug, Default, serde::Deserialize, PartialEq)]
67#[serde(rename_all = "snake_case")]
68pub enum ApiType {
69 OpenAi,
71 #[default]
72 Default,
73}