mod bert;
mod llama;

use anyhow::Context;
use bert::{BertModel, Config};
use candle::{safetensors::load_buffer, DType};
use candle_nn::VarBuilder;
use spin_common::ui::quoted_path;
use spin_core::async_trait;
use spin_world::v2::llm::{self as wasi_llm};
use std::{
    collections::{hash_map::Entry, HashMap},
    path::{Path, PathBuf},
    str::FromStr,
    sync::Arc,
};
use tokenizers::PaddingParams;

const MODEL_ALL_MINILM_L6_V2: &str = "all-minilm-l6-v2";
type ModelName = String;

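/// A local LLM engine backed by models stored in a registry directory on
/// disk. Loaded inferencing and embeddings models are cached per model name.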
#[derive(Clone)]
pub struct LocalLlmEngine {
    registry: PathBuf,
    inferencing_models: HashMap<ModelName, Arc<dyn InferencingModel>>,
    embeddings_models: HashMap<String, Arc<(tokenizers::Tokenizer, BertModel)>>,
}

#[derive(Debug)]
enum InferencingModelArch {
    Llama,
}

impl FromStr for InferencingModelArch {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "llama" => Ok(InferencingModelArch::Llama),
            _ => Err(()),
        }
    }
}

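/// An inferencing model that has been loaded into memory and can answer
/// prompts.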
#[async_trait]
trait InferencingModel: Send + Sync {
    async fn infer(
        &self,
        prompt: String,
        params: wasi_llm::InferencingParams,
        max_result_bytes: usize,
    ) -> anyhow::Result<wasi_llm::InferencingResult>;
}

impl LocalLlmEngine {
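    /// Runs inference on `prompt` against the named model, loading (and
    /// caching) the model on first use. Failures surface as
    /// `wasi_llm::Error::RuntimeError`.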
    pub async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
        max_result_bytes: usize,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
        let model = self.inferencing_model(model).await?;

        model
            .infer(prompt, params, max_result_bytes)
            .await
            .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))
    }

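    /// Generates embeddings for each string in `data` using the named
    /// embeddings model, loading (and caching) the model on first use.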
    pub async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
        max_result_bytes: usize,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
        let model = self.embeddings_model(model).await?;
        generate_embeddings(data, model, max_result_bytes)
            .await
            .map_err(|e| {
                wasi_llm::Error::RuntimeError(format!("Error occurred generating embeddings: {e}"))
            })
    }
}

impl LocalLlmEngine {
    pub fn new(registry: PathBuf) -> Self {
        Self {
            registry,
            inferencing_models: Default::default(),
            embeddings_models: Default::default(),
        }
    }

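    /// Returns the embeddings model (tokenizer plus BERT weights), loading it
    /// from `<registry>/all-minilm-l6-v2` on first use and caching it
    /// thereafter.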
    async fn embeddings_model(
        &mut self,
        model: wasi_llm::EmbeddingModel,
    ) -> Result<Arc<(tokenizers::Tokenizer, BertModel)>, wasi_llm::Error> {
        let key = match model.as_str() {
            MODEL_ALL_MINILM_L6_V2 => model,
            _ => return Err(wasi_llm::Error::ModelNotSupported),
        };
        let registry_path = self.registry.join(&key);
        let r = match self.embeddings_models.entry(key) {
            Entry::Occupied(o) => o.get().clone(),
            Entry::Vacant(v) => v
                .insert({
                    tokio::task::spawn_blocking(move || {
                        if !registry_path.exists() {
                            return Err(wasi_llm::Error::RuntimeError(format!(
                                "The directory expected to house the embeddings models '{}' does not exist.",
                                registry_path.display()
                            )));
                        }
                        let tokenizer_file = registry_path.join("tokenizer.json");
                        let model_file = registry_path.join("model.safetensors");
                        let tokenizer = load_tokenizer(&tokenizer_file).map_err(|_| {
                            wasi_llm::Error::RuntimeError(format!(
                                "Failed to load embeddings tokenizer from '{}'",
                                tokenizer_file.display()
                            ))
                        })?;
                        let model = load_model(&model_file).map_err(|_| {
                            wasi_llm::Error::RuntimeError(format!(
                                "Failed to load embeddings model from '{}'",
                                model_file.display()
                            ))
                        })?;
                        Ok(Arc::new((tokenizer, model)))
                    })
                    .await
                    .map_err(|_| {
                        wasi_llm::Error::RuntimeError("Error loading embeddings model".into())
                    })??
                })
                .clone(),
        };
        Ok(r)
    }

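    /// Returns the inferencing model with the given name, walking the
    /// registry to locate and load it on first use; loaded models are cached.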
    async fn inferencing_model(
        &mut self,
        model: wasi_llm::InferencingModel,
    ) -> Result<Arc<dyn InferencingModel>, wasi_llm::Error> {
        let model = match self.inferencing_models.entry(model.clone()) {
            Entry::Occupied(o) => o.get().clone(),
            Entry::Vacant(v) => {
                let (model_dir, arch) =
                    walk_registry_for_model(&self.registry, model.clone()).await?;
                let model = match arch {
                    InferencingModelArch::Llama => Arc::new(
                        llama::LlamaModels::new(&model_dir)
                            .await
                            .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))?,
                    ),
                };

                v.insert(model.clone());

                model
            }
        };
        Ok(model)
    }
}

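/// Walks the registry directory structure (`<registry>/<architecture>/<model-name>`)
/// and returns the directory for the given model along with its parsed
/// architecture.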
async fn walk_registry_for_model(
    registry_path: &Path,
    model: String,
) -> Result<(PathBuf, InferencingModelArch), wasi_llm::Error> {
    let mut arch_dirs = tokio::fs::read_dir(registry_path).await.map_err(|e| {
        wasi_llm::Error::RuntimeError(format!(
            "Could not read model registry directory '{}': {e}",
            registry_path.display()
        ))
    })?;
    let mut result = None;
    'outer: while let Some(arch_dir) = arch_dirs.next_entry().await.map_err(|e| {
        wasi_llm::Error::RuntimeError(format!(
            "Failed to read arch directory in model registry: {e}"
        ))
    })? {
        if arch_dir
            .file_type()
            .await
            .map_err(|e| {
                wasi_llm::Error::RuntimeError(format!(
                    "Could not read file type of '{}' dir: {e}",
                    arch_dir.path().display()
                ))
            })?
            .is_file()
        {
            continue;
        }
        let mut model_dirs = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| {
            wasi_llm::Error::RuntimeError(format!(
                "Error reading architecture directory in model registry: {e}"
            ))
        })?;
        while let Some(model_dir) = model_dirs.next_entry().await.map_err(|e| {
            wasi_llm::Error::RuntimeError(format!(
                "Error reading model folder in model registry: {e}"
            ))
        })? {
            if model_dir
                .file_type()
                .await
                .map_err(|e| {
                    wasi_llm::Error::RuntimeError(format!(
                        "Could not read file type of '{}' dir: {e}",
                        model_dir.path().display()
                    ))
                })?
                .is_file()
            {
                continue;
            }
            if model_dir
                .file_name()
                .to_str()
                .map(|m| m == model)
                .unwrap_or_default()
            {
                let arch = arch_dir.file_name();
                let arch = arch
                    .to_str()
                    .ok_or(wasi_llm::Error::ModelNotSupported)?
                    .parse()
                    .map_err(|_| wasi_llm::Error::ModelNotSupported)?;
                result = Some((model_dir.path(), arch));
                break 'outer;
            }
        }
    }

    result.ok_or_else(|| {
        wasi_llm::Error::InvalidInput(format!(
            "no model directory found in registry for model '{model}'"
        ))
    })
}

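/// Generates embeddings for a batch of sentences on a blocking thread:
/// tokenizes the batch with padding, runs the BERT model, mean-pools the
/// token embeddings, and L2-normalizes each sentence vector.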
async fn generate_embeddings(
    data: Vec<String>,
    model: Arc<(tokenizers::Tokenizer, BertModel)>,
    max_result_bytes: usize,
) -> anyhow::Result<wasi_llm::EmbeddingsResult> {
    let n_sentences = data.len();
    tokio::task::spawn_blocking(move || {
        let mut tokenizer = model.0.clone();
        let model = &model.1;
        // Pad the batch to the longest sentence so the inputs can be stacked
        // into a single tensor.
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest;
        } else {
            let pp = PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let tokens = tokenizer
            .encode_batch(data, true)
            .map_err(|e| anyhow::anyhow!("{e}"))?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Ok(candle::Tensor::new(
                    tokens.as_slice(),
                    &candle::Device::Cpu,
                )?)
            })
            .collect::<anyhow::Result<Vec<_>>>()?;

        let token_ids = candle::Tensor::stack(&token_ids, 0)?;
        let embeddings = model.forward(&token_ids, &token_ids.zeros_like()?)?;

        // Mean pooling: average the token embeddings across the sequence
        // dimension to get one vector per sentence.
        let (_, n_tokens, _) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;

        // L2-normalize each sentence embedding.
        let mut results: Vec<Vec<f32>> = Vec::new();
        for j in 0..n_sentences {
            let e_j = embeddings.get(j)?;
            let mut emb: Vec<f32> = e_j.to_vec1()?;
            let length: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            emb.iter_mut().for_each(|x| *x /= length);
            results.push(emb);
        }

        // Approximate the in-memory size of the result and enforce the
        // caller-supplied limit.
        if std::mem::size_of::<wasi_llm::EmbeddingsResult>()
            + results
                .iter()
                .map(|v| std::mem::size_of::<Vec<f32>>() + (v.len() * std::mem::size_of::<f32>()))
                .sum::<usize>()
            > max_result_bytes
        {
            anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
        }

        let result = wasi_llm::EmbeddingsResult {
            embeddings: results,
            usage: wasi_llm::EmbeddingsUsage {
                prompt_token_count: n_tokens as u32,
            },
        };
        Ok(result)
    })
    .await?
}

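/// Loads a tokenizer definition from a `tokenizer.json` file, wrapping any
/// failure with the offending path.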
fn load_tokenizer(tokenizer_file: &Path) -> anyhow::Result<tokenizers::Tokenizer> {
    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_file).map_err(|e| {
        anyhow::anyhow!(
            "Failed to read tokenizer file {}: {e}",
            quoted_path(tokenizer_file)
        )
    })?;
    Ok(tokenizer)
}

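/// Loads BERT weights from a `.safetensors` file into a CPU-backed model
/// using the default configuration.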
fn load_model(model_file: &Path) -> anyhow::Result<BertModel> {
    let device = &candle::Device::Cpu;
    let data = std::fs::read(model_file)?;
    let tensors = load_buffer(&data, device)?;
    let vb = VarBuilder::from_tensors(tensors, DType::F32, device);
    let model = BertModel::load(vb, &Config::default()).context("error loading bert model")?;
    Ok(model)
}
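
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal usage sketch, not part of the original module: it assumes a
    // registry at `/tmp/model-registry` (an illustrative path) that already
    // contains `all-minilm-l6-v2/tokenizer.json` and
    // `all-minilm-l6-v2/model.safetensors`, so it is `#[ignore]`d by default.
    #[tokio::test]
    #[ignore = "requires a local model registry with the all-minilm-l6-v2 files"]
    async fn generate_embeddings_smoke_test() {
        let mut engine = LocalLlmEngine::new(PathBuf::from("/tmp/model-registry"));
        let result = engine
            .generate_embeddings(
                MODEL_ALL_MINILM_L6_V2.to_string(),
                vec!["hello world".to_string()],
                usize::MAX,
            )
            .await;
        let Ok(result) = result else {
            panic!("embedding generation failed");
        };
        // One embedding per input sentence, each L2-normalized above, so the
        // vector's norm should be approximately 1.
        assert_eq!(result.embeddings.len(), 1);
        let norm: f32 = result.embeddings[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4);
    }
}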