//! spin_llm_local — lib.rs

1mod bert;
2mod llama;
3
4use anyhow::Context;
5use bert::{BertModel, Config};
6use candle::{safetensors::load_buffer, DType};
7use candle_nn::VarBuilder;
8use spin_common::ui::quoted_path;
9use spin_core::async_trait;
10use spin_world::v2::llm::{self as wasi_llm};
11use std::{
12    collections::{hash_map::Entry, HashMap},
13    path::{Path, PathBuf},
14    str::FromStr,
15    sync::Arc,
16};
17use tokenizers::PaddingParams;
18
/// Name of the only embeddings model currently supported by this engine.
const MODEL_ALL_MINILM_L6_V2: &str = "all-minilm-l6-v2";
/// Key type for the inferencing-model cache (the model's registry name).
type ModelName = String;
21
/// Engine that serves LLM inferencing and embeddings from models stored in a
/// local filesystem registry, caching each loaded model in memory.
#[derive(Clone)]
pub struct LocalLlmEngine {
    /// Root directory of the on-disk model registry.
    registry: PathBuf,
    /// Cache of loaded inferencing models, keyed by model name.
    inferencing_models: HashMap<ModelName, Arc<dyn InferencingModel>>,
    /// Cache of loaded (tokenizer, BERT model) pairs, keyed by model name.
    embeddings_models: HashMap<String, Arc<(tokenizers::Tokenizer, BertModel)>>,
}
28
/// Model architectures understood by the local inferencing engine.
///
/// The architecture is encoded as a directory name in the model registry
/// (see `walk_registry_for_model`).
#[derive(Debug)]
enum InferencingModelArch {
    Llama,
}

impl FromStr for InferencingModelArch {
    type Err = ();

    /// Parses a registry directory name into a known architecture.
    /// Unrecognized names yield `Err(())`, which callers map to
    /// "model not supported".
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "llama" {
            Ok(Self::Llama)
        } else {
            Err(())
        }
    }
}
44
/// A model that is prepared and cached after loading.
///
/// This trait does not specify anything about if the results are cached.
#[async_trait]
trait InferencingModel: Send + Sync {
    /// Runs inference on `prompt` with the given parameters.
    ///
    /// NOTE(review): enforcement of `max_result_bytes` is left to the
    /// implementation — confirm each backend actually applies the limit.
    async fn infer(
        &self,
        prompt: String,
        params: wasi_llm::InferencingParams,
        max_result_bytes: usize,
    ) -> anyhow::Result<wasi_llm::InferencingResult>;
}
57
58impl LocalLlmEngine {
59    pub async fn infer(
60        &mut self,
61        model: wasi_llm::InferencingModel,
62        prompt: String,
63        params: wasi_llm::InferencingParams,
64        max_result_bytes: usize,
65    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
66        let model = self.inferencing_model(model).await?;
67
68        model
69            .infer(prompt, params, max_result_bytes)
70            .await
71            .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))
72    }
73
74    pub async fn generate_embeddings(
75        &mut self,
76        model: wasi_llm::EmbeddingModel,
77        data: Vec<String>,
78        max_result_bytes: usize,
79    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
80        let model = self.embeddings_model(model).await?;
81        generate_embeddings(data, model, max_result_bytes)
82            .await
83            .map_err(|e| {
84                wasi_llm::Error::RuntimeError(format!("Error occurred generating embeddings: {e}"))
85            })
86    }
87}
88
89impl LocalLlmEngine {
90    pub fn new(registry: PathBuf) -> Self {
91        Self {
92            registry,
93            inferencing_models: Default::default(),
94            embeddings_models: Default::default(),
95        }
96    }
97
98    /// Get embeddings model from cache or load from disk
99    async fn embeddings_model(
100        &mut self,
101        model: wasi_llm::EmbeddingModel,
102    ) -> Result<Arc<(tokenizers::Tokenizer, BertModel)>, wasi_llm::Error> {
103        let key = match model.as_str() {
104            MODEL_ALL_MINILM_L6_V2 => model,
105            _ => return Err(wasi_llm::Error::ModelNotSupported),
106        };
107        let registry_path = self.registry.join(&key);
108        let r = match self.embeddings_models.entry(key) {
109            Entry::Occupied(o) => o.get().clone(),
110            Entry::Vacant(v) => v
111                .insert({
112                    tokio::task::spawn_blocking(move || {
113                        if !registry_path.exists() {
114                            return Err(
115                                wasi_llm::Error::RuntimeError(format!(
116                                "The directory expected to house the embeddings models '{}' does not exist.",
117                                registry_path.display()
118                            )));
119                        }
120                        let tokenizer_file = registry_path.join("tokenizer.json");
121                        let model_file = registry_path.join("model.safetensors");
122                        let tokenizer = load_tokenizer(&tokenizer_file).map_err(|_| {
123                            wasi_llm::Error::RuntimeError(format!(
124                                "Failed to load embeddings tokenizer from '{}'",
125                                tokenizer_file.display()
126                            ))
127                        })?;
128                        let model = load_model(&model_file).map_err(|_| {
129                            wasi_llm::Error::RuntimeError(format!(
130                                "Failed to load embeddings model from '{}'",
131                                model_file.display()
132                            ))
133                        })?;
134                        Ok(Arc::new((tokenizer, model)))
135                    })
136                    .await
137                    .map_err(|_| {
138                        wasi_llm::Error::RuntimeError("Error loading inferencing model".into())
139                    })??
140                })
141                .clone(),
142        };
143        Ok(r)
144    }
145
146    /// Get inferencing model from cache or load from disk
147    async fn inferencing_model(
148        &mut self,
149        model: wasi_llm::InferencingModel,
150    ) -> Result<Arc<dyn InferencingModel>, wasi_llm::Error> {
151        let model = match self.inferencing_models.entry(model.clone()) {
152            Entry::Occupied(o) => o.get().clone(),
153            Entry::Vacant(v) => {
154                let (model_dir, arch) =
155                    walk_registry_for_model(&self.registry, model.clone()).await?;
156                let model = match arch {
157                    InferencingModelArch::Llama => Arc::new(
158                        llama::LlamaModels::new(&model_dir)
159                            .await
160                            .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))?,
161                    ),
162                };
163
164                v.insert(model.clone());
165
166                model
167            }
168        };
169        Ok(model)
170    }
171}
172
173/// Walks the registry file structure and returns the directory the model is
174/// present along with its architecture
175async fn walk_registry_for_model(
176    registry_path: &Path,
177    model: String,
178) -> Result<(PathBuf, InferencingModelArch), wasi_llm::Error> {
179    let mut arch_dirs = tokio::fs::read_dir(registry_path).await.map_err(|e| {
180        wasi_llm::Error::RuntimeError(format!(
181            "Could not read model registry directory '{}': {e}",
182            registry_path.display()
183        ))
184    })?;
185    let mut result = None;
186    'outer: while let Some(arch_dir) = arch_dirs.next_entry().await.map_err(|e| {
187        wasi_llm::Error::RuntimeError(format!(
188            "Failed to read arch directory in model registry: {e}"
189        ))
190    })? {
191        if arch_dir
192            .file_type()
193            .await
194            .map_err(|e| {
195                wasi_llm::Error::RuntimeError(format!(
196                    "Could not read file type of '{}' dir: {e}",
197                    arch_dir.path().display()
198                ))
199            })?
200            .is_file()
201        {
202            continue;
203        }
204        let mut model_dirs = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| {
205            wasi_llm::Error::RuntimeError(format!(
206                "Error reading architecture directory in model registry: {e}"
207            ))
208        })?;
209        while let Some(model_dir) = model_dirs.next_entry().await.map_err(|e| {
210            wasi_llm::Error::RuntimeError(format!(
211                "Error reading model folder in model registry: {e}"
212            ))
213        })? {
214            // Models need to be a directory. So ignore any files.
215            if model_dir
216                .file_type()
217                .await
218                .map_err(|e| {
219                    wasi_llm::Error::RuntimeError(format!(
220                        "Could not read file type of '{}' dir: {e}",
221                        model_dir.path().display()
222                    ))
223                })?
224                .is_file()
225            {
226                continue;
227            }
228            if model_dir
229                .file_name()
230                .to_str()
231                .map(|m| m == model)
232                .unwrap_or_default()
233            {
234                let arch = arch_dir.file_name();
235                let arch = arch
236                    .to_str()
237                    .ok_or(wasi_llm::Error::ModelNotSupported)?
238                    .parse()
239                    .map_err(|_| wasi_llm::Error::ModelNotSupported)?;
240                result = Some((model_dir.path(), arch));
241                break 'outer;
242            }
243        }
244    }
245
246    result.ok_or_else(|| {
247        wasi_llm::Error::InvalidInput(format!(
248            "no model directory found in registry for model '{model}'"
249        ))
250    })
251}
252
/// Generates a normalized embedding vector for each string in `data` using
/// the given (tokenizer, BERT model) pair, failing if the estimated size of
/// the buffered result exceeds `max_result_bytes`.
async fn generate_embeddings(
    data: Vec<String>,
    model: Arc<(tokenizers::Tokenizer, BertModel)>,
    max_result_bytes: usize,
) -> anyhow::Result<wasi_llm::EmbeddingsResult> {
    let n_sentences = data.len();
    // Tokenization and the forward pass are CPU-bound, so run them on a
    // blocking thread rather than stalling the async executor.
    tokio::task::spawn_blocking(move || {
        // The tokenizer is cloned because padding configuration mutates it.
        let mut tokenizer = model.0.clone();
        let model = &model.1;
        // This function attempts to generate the embeddings for a batch of inputs, most
        // likely of different lengths.
        // The tokenizer expects all inputs in a batch to have the same length, so the
        // following is configuring the tokenizer to pad (add trailing zeros) each input
        // to match the length of the longest in the batch.
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let tokens = tokenizer
            .encode_batch(data, true)
            .map_err(|e| anyhow::anyhow!("{e}"))?;
        // Convert each sentence's token ids into a CPU tensor.
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Ok(candle::Tensor::new(
                    tokens.as_slice(),
                    &candle::Device::Cpu,
                )?)
            })
            .collect::<anyhow::Result<Vec<_>>>()?;

        // Execute the model's forward propagation function, which generates the raw embeddings.
        let token_ids = candle::Tensor::stack(&token_ids, 0)?;
        let embeddings = model.forward(&token_ids, &token_ids.zeros_like()?)?;

        // SBERT adds a pooling operation to the raw output to derive a fixed sized sentence embedding.
        // The BERT models suggest using mean pooling, which is what the operation below performs.
        // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2#usage-huggingface-transformers
        let (_, n_tokens, _) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;

        // Take each sentence embedding from the batch and arrange it in the final result tensor.
        // Normalize each embedding as the last step (this generates vectors with length 1, which
        // makes the cosine similarity function significantly more efficient (it becomes a simple
        // dot product).
        let mut results: Vec<Vec<f32>> = Vec::new();
        for j in 0..n_sentences {
            let e_j = embeddings.get(j)?;
            let mut emb: Vec<f32> = e_j.to_vec1()?;
            let length: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            emb.iter_mut().for_each(|x| *x /= length);
            results.push(emb);
        }

        // There doesn't seem to currently be a way to stream the embeddings
        // without buffering, so the damage (in terms of host memory usage) is
        // already done, but we can still enforce the limit:
        if std::mem::size_of::<wasi_llm::EmbeddingsResult>()
            + results
                .iter()
                .map(|v| std::mem::size_of::<Vec<f32>>() + (v.len() * std::mem::size_of::<f32>()))
                .sum::<usize>()
            > max_result_bytes
        {
            anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
        }

        // NOTE(review): `n_tokens` is the padded per-sentence length from the
        // embeddings tensor, not the total token count across the batch —
        // confirm this is the intended meaning of `prompt_token_count`.
        let result = wasi_llm::EmbeddingsResult {
            embeddings: results,
            usage: wasi_llm::EmbeddingsUsage {
                prompt_token_count: n_tokens as u32,
            },
        };
        Ok(result)
    })
    .await?
}
336
337fn load_tokenizer(tokenizer_file: &Path) -> anyhow::Result<tokenizers::Tokenizer> {
338    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_file).map_err(|e| {
339        anyhow::anyhow!(
340            "Failed to read tokenizer file {}: {e}",
341            quoted_path(tokenizer_file)
342        )
343    })?;
344    Ok(tokenizer)
345}
346
347fn load_model(model_file: &Path) -> anyhow::Result<BertModel> {
348    let device = &candle::Device::Cpu;
349    let data = std::fs::read(model_file)?;
350    let tensors = load_buffer(&data, device)?;
351    let vb = VarBuilder::from_tensors(tensors, DType::F32, device);
352    let model = BertModel::load(vb, &Config::default()).context("error loading bert model")?;
353    Ok(model)
354}