mod bert;
mod llama;

use anyhow::Context;
use bert::{BertModel, Config};
use candle::{safetensors::load_buffer, DType};
use candle_nn::VarBuilder;
use spin_common::ui::quoted_path;
use spin_core::async_trait;
use spin_world::v2::llm::{self as wasi_llm};
use std::{
    collections::{hash_map::Entry, HashMap},
    path::{Path, PathBuf},
    str::FromStr,
    sync::Arc,
};
use tokenizers::PaddingParams;

const MODEL_ALL_MINILM_L6_V2: &str = "all-minilm-l6-v2";
type ModelName = String;

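/// A local LLM engine that loads inferencing and embeddings models from a
/// registry directory on disk and caches them in memory after first use.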
#[derive(Clone)]
pub struct LocalLlmEngine {
    registry: PathBuf,
    inferencing_models: HashMap<ModelName, Arc<dyn InferencingModel>>,
    embeddings_models: HashMap<String, Arc<(tokenizers::Tokenizer, BertModel)>>,
}

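/// The model architectures supported for local inferencing.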
#[derive(Debug)]
enum InferencingModelArch {
    Llama,
}

impl FromStr for InferencingModelArch {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "llama" => Ok(InferencingModelArch::Llama),
            _ => Err(()),
        }
    }
}

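/// An inferencing model that can be prompted asynchronously.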
#[async_trait]
trait InferencingModel: Send + Sync {
    async fn infer(
        &self,
        prompt: String,
        params: wasi_llm::InferencingParams,
    ) -> anyhow::Result<wasi_llm::InferencingResult>;
}

impl LocalLlmEngine {
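    /// Runs inferencing with the given prompt and parameters against the
    /// named model, loading (and caching) the model on first use.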
    pub async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
        let model = self.inferencing_model(model).await?;

        model
            .infer(prompt, params)
            .await
            .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))
    }

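    /// Generates an embedding vector for each input string using the named
    /// embeddings model, loading (and caching) the model on first use.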
    pub async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
        let model = self.embeddings_model(model).await?;
        generate_embeddings(data, model).await.map_err(|e| {
            wasi_llm::Error::RuntimeError(format!("Error occurred generating embeddings: {e}"))
        })
    }
}

impl LocalLlmEngine {
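    /// Creates a new engine that loads models from the given registry
    /// directory. Models are loaded lazily on first use.
    ///
    /// A minimal usage sketch; the registry path and model name are
    /// illustrative, and `prompt`/`params` are supplied by the caller:
    ///
    /// ```ignore
    /// let mut engine = LocalLlmEngine::new(PathBuf::from(".spin/ai-models"));
    /// let result = engine.infer("llama2-chat".into(), prompt, params).await?;
    /// ```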
    pub fn new(registry: PathBuf) -> Self {
        Self {
            registry,
            inferencing_models: Default::default(),
            embeddings_models: Default::default(),
        }
    }

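    /// Returns the cached embeddings model for the given name, loading the
    /// tokenizer and BERT weights from the registry if not yet cached.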
    async fn embeddings_model(
        &mut self,
        model: wasi_llm::EmbeddingModel,
    ) -> Result<Arc<(tokenizers::Tokenizer, BertModel)>, wasi_llm::Error> {
        let key = match model.as_str() {
            MODEL_ALL_MINILM_L6_V2 => model,
            _ => return Err(wasi_llm::Error::ModelNotSupported),
        };
        let registry_path = self.registry.join(&key);
        let r = match self.embeddings_models.entry(key) {
            Entry::Occupied(o) => o.get().clone(),
            Entry::Vacant(v) => v
                .insert({
                    tokio::task::spawn_blocking(move || {
                        if !registry_path.exists() {
                            return Err(wasi_llm::Error::RuntimeError(format!(
                                "The directory expected to house the embeddings models '{}' does not exist.",
                                registry_path.display()
                            )));
                        }
                        let tokenizer_file = registry_path.join("tokenizer.json");
                        let model_file = registry_path.join("model.safetensors");
                        let tokenizer = load_tokenizer(&tokenizer_file).map_err(|_| {
                            wasi_llm::Error::RuntimeError(format!(
                                "Failed to load embeddings tokenizer from '{}'",
                                tokenizer_file.display()
                            ))
                        })?;
                        let model = load_model(&model_file).map_err(|_| {
                            wasi_llm::Error::RuntimeError(format!(
                                "Failed to load embeddings model from '{}'",
                                model_file.display()
                            ))
                        })?;
                        Ok(Arc::new((tokenizer, model)))
                    })
                    .await
                    .map_err(|_| {
133 wasi_llm::Error::RuntimeError("Error loading inferencing model".into())
                    })??
                })
                .clone(),
        };
        Ok(r)
    }

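    /// Returns the cached inferencing model for the given name, locating and
    /// loading it from the registry directory if not yet cached.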
    async fn inferencing_model(
        &mut self,
        model: wasi_llm::InferencingModel,
    ) -> Result<Arc<dyn InferencingModel>, wasi_llm::Error> {
        let model = match self.inferencing_models.entry(model.clone()) {
            Entry::Occupied(o) => o.get().clone(),
            Entry::Vacant(v) => {
                let (model_dir, arch) =
                    walk_registry_for_model(&self.registry, model.clone()).await?;
                let model = match arch {
                    InferencingModelArch::Llama => Arc::new(
                        llama::LlamaModels::new(&model_dir)
                            .await
                            .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))?,
                    ),
                };

                v.insert(model.clone());

                model
            }
        };
        Ok(model)
    }
}

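/// Walks the registry directory, which is expected to be laid out as
/// `<registry>/<architecture>/<model-name>`, and returns the directory of the
/// requested model along with its parsed architecture.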
async fn walk_registry_for_model(
    registry_path: &Path,
    model: String,
) -> Result<(PathBuf, InferencingModelArch), wasi_llm::Error> {
    let mut arch_dirs = tokio::fs::read_dir(registry_path).await.map_err(|e| {
        wasi_llm::Error::RuntimeError(format!(
            "Could not read model registry directory '{}': {e}",
            registry_path.display()
        ))
    })?;
    let mut result = None;
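    // Scan each architecture directory, then each model directory within it,
    // looking for a directory whose name matches the requested model.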
    'outer: while let Some(arch_dir) = arch_dirs.next_entry().await.map_err(|e| {
        wasi_llm::Error::RuntimeError(format!(
            "Failed to read arch directory in model registry: {e}"
        ))
    })? {
        if arch_dir
            .file_type()
            .await
            .map_err(|e| {
                wasi_llm::Error::RuntimeError(format!(
                    "Could not read file type of '{}' dir: {e}",
                    arch_dir.path().display()
                ))
            })?
            .is_file()
        {
            continue;
        }
        let mut model_dirs = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| {
            wasi_llm::Error::RuntimeError(format!(
                "Error reading architecture directory in model registry: {e}"
            ))
        })?;
        while let Some(model_dir) = model_dirs.next_entry().await.map_err(|e| {
            wasi_llm::Error::RuntimeError(format!(
                "Error reading model folder in model registry: {e}"
            ))
        })? {
            if model_dir
                .file_type()
                .await
                .map_err(|e| {
                    wasi_llm::Error::RuntimeError(format!(
                        "Could not read file type of '{}' dir: {e}",
                        model_dir.path().display()
                    ))
                })?
                .is_file()
            {
                continue;
            }
            if model_dir
                .file_name()
                .to_str()
                .map(|m| m == model)
                .unwrap_or_default()
            {
                let arch = arch_dir.file_name();
                let arch = arch
                    .to_str()
                    .ok_or(wasi_llm::Error::ModelNotSupported)?
                    .parse()
                    .map_err(|_| wasi_llm::Error::ModelNotSupported)?;
                result = Some((model_dir.path(), arch));
                break 'outer;
            }
        }
    }

    result.ok_or_else(|| {
        wasi_llm::Error::InvalidInput(format!(
            "no model directory found in registry for model '{model}'"
        ))
    })
}

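/// Generates embeddings on a blocking thread: tokenizes the batch, runs the
/// BERT forward pass, mean-pools the token embeddings, and normalizes each
/// sentence embedding to unit length.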
async fn generate_embeddings(
    data: Vec<String>,
    model: Arc<(tokenizers::Tokenizer, BertModel)>,
) -> anyhow::Result<wasi_llm::EmbeddingsResult> {
    let n_sentences = data.len();
    tokio::task::spawn_blocking(move || {
        let mut tokenizer = model.0.clone();
        let model = &model.1;
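        // Inputs in a batch may have different lengths; configure the tokenizer
        // to pad every encoding to the length of the longest input in the batch.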
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest;
        } else {
            let pp = PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let tokens = tokenizer
            .encode_batch(data, true)
            .map_err(|e| anyhow::anyhow!("{e}"))?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Ok(candle::Tensor::new(
                    tokens.as_slice(),
                    &candle::Device::Cpu,
                )?)
            })
            .collect::<anyhow::Result<Vec<_>>>()?;

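        // Stack the per-sentence token tensors into a single batch and run the
        // model's forward pass to produce per-token embeddings.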
        let token_ids = candle::Tensor::stack(&token_ids, 0)?;
        let embeddings = model.forward(&token_ids, &token_ids.zeros_like()?)?;

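        // Mean-pool across the token dimension to get one fixed-size embedding
        // per sentence.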
        let (_, n_tokens, _) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;

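        // L2-normalize each sentence embedding so the resulting vectors have
        // unit length.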
        let mut results: Vec<Vec<f32>> = Vec::new();
        for j in 0..n_sentences {
            let e_j = embeddings.get(j)?;
            let mut emb: Vec<f32> = e_j.to_vec1()?;
            let length: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            emb.iter_mut().for_each(|x| *x /= length);
            results.push(emb);
        }

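        // Note: `n_tokens` is the padded per-sentence token count from the
        // batch, so this is an approximation of the true prompt token usage.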
        let result = wasi_llm::EmbeddingsResult {
            embeddings: results,
            usage: wasi_llm::EmbeddingsUsage {
                prompt_token_count: n_tokens as u32,
            },
        };
        Ok(result)
    })
    .await?
}

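/// Loads a `tokenizers` tokenizer from the given `tokenizer.json` file.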
fn load_tokenizer(tokenizer_file: &Path) -> anyhow::Result<tokenizers::Tokenizer> {
    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_file).map_err(|e| {
        anyhow::anyhow!(
            "Failed to read tokenizer file {}: {e}",
            quoted_path(tokenizer_file)
        )
    })?;
    Ok(tokenizer)
}

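/// Loads BERT model weights from a safetensors file onto the CPU.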
fn load_model(model_file: &Path) -> anyhow::Result<BertModel> {
    let device = &candle::Device::Cpu;
    let data = std::fs::read(model_file)?;
    let tensors = load_buffer(&data, device)?;
    let vb = VarBuilder::from_tensors(tensors, DType::F32, device);
    let model = BertModel::load(vb, &Config::default()).context("error loading bert model")?;
    Ok(model)
}