// spin_llm_remote_http/lib.rs

use anyhow::Result;
use reqwest::{
    header::{HeaderMap, HeaderValue},
    Client, Url,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use spin_world::v2::llm::{self as wasi_llm};

/// A client-side engine that forwards LLM inferencing and embedding
/// requests to a remote HTTP service.
#[derive(Clone)]
pub struct RemoteHttpLlmEngine {
    // Bearer token sent in the `authorization` header of every request.
    auth_token: String,
    // Base URL of the remote service; endpoint paths are joined onto it.
    url: Url,
    // Lazily-created HTTP client, reused across requests once initialized.
    client: Option<Client>,
}

/// Inferencing options serialized into the `options` field of the
/// `POST /infer` request body (field names are sent as camelCase).
#[derive(Serialize)]
#[serde(rename_all(serialize = "camelCase"))]
struct InferRequestBodyParams {
    max_tokens: u32,
    repeat_penalty: f32,
    repeat_penalty_last_n_token_count: u32,
    temperature: f32,
    top_k: u32,
    top_p: f32,
}

/// Token-usage accounting in the remote `POST /infer` response
/// (field names are received as camelCase).
#[derive(Deserialize)]
#[serde(rename_all(deserialize = "camelCase"))]
struct InferUsage {
    prompt_token_count: u32,
    generated_token_count: u32,
}

/// Response body of the remote `POST /infer` endpoint.
#[derive(Deserialize)]
struct InferResponseBody {
    // Generated completion text.
    text: String,
    usage: InferUsage,
}

/// Token-usage accounting in the remote `POST /embed` response
/// (field names are received as camelCase).
#[derive(Deserialize)]
#[serde(rename_all(deserialize = "camelCase"))]
struct EmbeddingUsage {
    prompt_token_count: u32,
}

/// Response body of the remote `POST /embed` endpoint.
#[derive(Deserialize)]
struct EmbeddingResponseBody {
    // One embedding vector per input string — presumably in input
    // order; confirm against the remote service's contract.
    embeddings: Vec<Vec<f32>>,
    usage: EmbeddingUsage,
}

53impl RemoteHttpLlmEngine {
54 pub async fn infer(
55 &mut self,
56 model: wasi_llm::InferencingModel,
57 prompt: String,
58 params: wasi_llm::InferencingParams,
59 ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
60 let client = self.client.get_or_insert_with(Default::default);
61
62 let mut headers = HeaderMap::new();
63 headers.insert(
64 "authorization",
65 HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
66 wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
67 })?,
68 );
69 spin_telemetry::inject_trace_context(&mut headers);
70
71 let inference_options = InferRequestBodyParams {
72 max_tokens: params.max_tokens,
73 repeat_penalty: params.repeat_penalty,
74 repeat_penalty_last_n_token_count: params.repeat_penalty_last_n_token_count,
75 temperature: params.temperature,
76 top_k: params.top_k,
77 top_p: params.top_p,
78 };
79 let body = serde_json::to_string(&json!({
80 "model": model,
81 "prompt": prompt,
82 "options": inference_options
83 }))
84 .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
85
86 let infer_url = self
87 .url
88 .join("/infer")
89 .map_err(|_| wasi_llm::Error::RuntimeError("Failed to create URL".to_string()))?;
90 tracing::info!("Sending remote inference request to {infer_url}");
91
92 let resp = client
93 .request(reqwest::Method::POST, infer_url)
94 .headers(headers)
95 .body(body)
96 .send()
97 .await
98 .map_err(|err| {
99 wasi_llm::Error::RuntimeError(format!("POST /infer request error: {err}"))
100 })?;
101
102 match resp.json::<InferResponseBody>().await {
103 Ok(val) => Ok(wasi_llm::InferencingResult {
104 text: val.text,
105 usage: wasi_llm::InferencingUsage {
106 prompt_token_count: val.usage.prompt_token_count,
107 generated_token_count: val.usage.generated_token_count,
108 },
109 }),
110 Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
111 "Failed to deserialize response for \"POST /index\": {err}"
112 ))),
113 }
114 }
115
116 pub async fn generate_embeddings(
117 &mut self,
118 model: wasi_llm::EmbeddingModel,
119 data: Vec<String>,
120 ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
121 let client = self.client.get_or_insert_with(Default::default);
122
123 let mut headers = HeaderMap::new();
124 headers.insert(
125 "authorization",
126 HeaderValue::from_str(&format!("bearer {}", self.auth_token)).map_err(|_| {
127 wasi_llm::Error::RuntimeError("Failed to create authorization header".to_string())
128 })?,
129 );
130 spin_telemetry::inject_trace_context(&mut headers);
131
132 let body = serde_json::to_string(&json!({
133 "model": model,
134 "input": data
135 }))
136 .map_err(|_| wasi_llm::Error::RuntimeError("Failed to serialize JSON".to_string()))?;
137
138 let resp = client
139 .request(
140 reqwest::Method::POST,
141 self.url.join("/embed").map_err(|_| {
142 wasi_llm::Error::RuntimeError("Failed to create URL".to_string())
143 })?,
144 )
145 .headers(headers)
146 .body(body)
147 .send()
148 .await
149 .map_err(|err| {
150 wasi_llm::Error::RuntimeError(format!("POST /embed request error: {err}"))
151 })?;
152
153 match resp.json::<EmbeddingResponseBody>().await {
154 Ok(val) => Ok(wasi_llm::EmbeddingsResult {
155 embeddings: val.embeddings,
156 usage: wasi_llm::EmbeddingsUsage {
157 prompt_token_count: val.usage.prompt_token_count,
158 },
159 }),
160 Err(err) => Err(wasi_llm::Error::RuntimeError(format!(
161 "Failed to deserialize response for \"POST /embed\": {err}"
162 ))),
163 }
164 }
165
166 pub fn url(&self) -> Url {
167 self.url.clone()
168 }
169}
170
171impl RemoteHttpLlmEngine {
172 pub fn new(url: Url, auth_token: String) -> Self {
173 RemoteHttpLlmEngine {
174 url,
175 auth_token,
176 client: None,
177 }
178 }
179}