spin_llm_remote_http/
lib.rs

1use anyhow::Result;
2use reqwest::{
3    header::{HeaderMap, HeaderValue},
4    Client, Url,
5};
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use spin_world::v2::llm::{self as wasi_llm};
9
/// An LLM engine that forwards inferencing and embedding requests to a
/// remote HTTP service, authenticating with a bearer token.
#[derive(Clone)]
pub struct RemoteHttpLlmEngine {
    // Bearer token sent in the `authorization` header of every request.
    auth_token: String,
    // Base URL of the remote service; `/infer` and `/embed` are joined onto it.
    url: Url,
    // Lazily-created HTTP client, reused across requests once initialized.
    client: Option<Client>,
}
16
/// Inference tuning options serialized into the `options` field of the
/// `POST /infer` request body. Field names are camelCased on the wire.
#[derive(Serialize)]
#[serde(rename_all(serialize = "camelCase"))]
struct InferRequestBodyParams {
    max_tokens: u32,
    repeat_penalty: f32,
    repeat_penalty_last_n_token_count: u32,
    temperature: f32,
    top_k: u32,
    top_p: f32,
}
27
/// Token-usage accounting returned by the remote `/infer` endpoint.
/// Field names are camelCased on the wire.
#[derive(Deserialize)]
#[serde(rename_all(deserialize = "camelCase"))]
struct InferUsage {
    prompt_token_count: u32,
    generated_token_count: u32,
}
34
/// Response body of `POST /infer`: the generated text plus usage counts.
#[derive(Deserialize)]
struct InferResponseBody {
    text: String,
    usage: InferUsage,
}
40
/// Token-usage accounting returned by the remote `/embed` endpoint.
/// Field names are camelCased on the wire.
#[derive(Deserialize)]
#[serde(rename_all(deserialize = "camelCase"))]
struct EmbeddingUsage {
    prompt_token_count: u32,
}
46
/// Response body of `POST /embed`: one embedding vector per input string,
/// plus usage counts.
#[derive(Deserialize)]
struct EmbeddingResponseBody {
    embeddings: Vec<Vec<f32>>,
    usage: EmbeddingUsage,
}
52
53impl RemoteHttpLlmEngine {
54    pub async fn infer(
55        &mut self,
56        model: wasi_llm::InferencingModel,
57        prompt: String,
58        params: wasi_llm::InferencingParams,
59    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
60        let client = self.client.get_or_insert_with(Default::default);
61
62        let mut headers = HeaderMap::new();
63        headers.insert(
64            "authorization",
65            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
66                wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
67            })?,
68        );
69        spin_telemetry::inject_trace_context(&mut headers);
70
71        let inference_options = InferRequestBodyParams {
72            max_tokens: params.max_tokens,
73            repeat_penalty: params.repeat_penalty,
74            repeat_penalty_last_n_token_count: params.repeat_penalty_last_n_token_count,
75            temperature: params.temperature,
76            top_k: params.top_k,
77            top_p: params.top_p,
78        };
79        let body = serde_json::to_string(&json!({
80            "model": model,
81            "prompt": prompt,
82            "options": inference_options
83        }))
84        .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
85
86        let infer_url = self
87            .url
88            .join("/infer")
89            .map_err(|_| wasi_llm::Error::RuntimeError("Failed to create URL".to_string()))?;
90        tracing::info!("Sending remote inference request to {infer_url}");
91
92        let resp = client
93            .request(reqwest::Method::POST, infer_url)
94            .headers(headers)
95            .body(body)
96            .send()
97            .await
98            .map_err(|err| {
99                wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}"))
100            })?;
101
102        match resp.json::<InferResponseBody>().await {
103            Ok(val) => Ok(wasi_llm::InferencingResult {
104                text: val.text,
105                usage: wasi_llm::InferencingUsage {
106                    prompt_token_count: val.usage.prompt_token_count,
107                    generated_token_count: val.usage.generated_token_count,
108                },
109            }),
110            Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
111                "Failed to deserialize response for \"POST  /index\": {err}"
112            ))),
113        }
114    }
115
116    pub async fn generate_embeddings(
117        &mut self,
118        model: wasi_llm::EmbeddingModel,
119        data: Vec<String>,
120    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
121        let client = self.client.get_or_insert_with(Default::default);
122
123        let mut headers = HeaderMap::new();
124        headers.insert(
125            "authorization",
126            HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
127                wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
128            })?,
129        );
130        spin_telemetry::inject_trace_context(&mut headers);
131
132        let body = serde_json::to_string(&json!({
133            "model": model,
134            "input": data
135        }))
136        .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
137
138        let resp = client
139            .request(
140                reqwest::Method::POST,
141                self.url.join("/embed").map_err(|_| {
142                    wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
143                })?,
144            )
145            .headers(headers)
146            .body(body)
147            .send()
148            .await
149            .map_err(|err| {
150                wasi_llm::Error::RuntimeError(format!("POST /embed request error: {err}"))
151            })?;
152
153        match resp.json::<EmbeddingResponseBody>().await {
154            Ok(val) => Ok(wasi_llm::EmbeddingsResult {
155                embeddings: val.embeddings,
156                usage: wasi_llm::EmbeddingsUsage {
157                    prompt_token_count: val.usage.prompt_token_count,
158                },
159            }),
160            Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
161                "Failed to deserialize response  for \"POST  /embed\": {err}"
162            ))),
163        }
164    }
165
166    pub fn url(&self) -> Url {
167        self.url.clone()
168    }
169}
170
171impl RemoteHttpLlmEngine {
172    pub fn new(url: Url, auth_token: String) -> Self {
173        RemoteHttpLlmEngine {
174            url,
175            auth_token,
176            client: None,
177        }
178    }
179}