spin_llm_local/lib.rs

mod bert;
mod llama;

use anyhow::Context;
use bert::{BertModel, Config};
use candle::{safetensors::load_buffer, DType};
use candle_nn::VarBuilder;
use spin_common::ui::quoted_path;
use spin_core::async_trait;
use spin_world::v2::llm::{self as wasi_llm};
use std::{
    collections::{hash_map::Entry, HashMap},
    path::{Path, PathBuf},
    str::FromStr,
    sync::Arc,
};
use tokenizers::PaddingParams;

const MODEL_ALL_MINILM_L6_V2: &str = "all-minilm-l6-v2";
type ModelName = String;

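/// A local LLM engine that serves inferencing and embeddings requests from
/// models stored in a registry directory on disk.
///
/// Loaded models are cached in memory by name, so each model is only read
/// from disk once per engine instance.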
#[derive(Clone)]
pub struct LocalLlmEngine {
    registry: PathBuf,
    inferencing_models: HashMap<ModelName, Arc<dyn InferencingModel>>,
    embeddings_models: HashMap<String, Arc<(tokenizers::Tokenizer, BertModel)>>,
}

#[derive(Debug)]
enum InferencingModelArch {
    Llama,
}

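// Architecture names are parsed from the first-level directory names in the
// model registry (see `walk_registry_for_model`); only `llama` is recognized.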
impl FromStr for InferencingModelArch {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "llama" => Ok(InferencingModelArch::Llama),
            _ => Err(()),
        }
    }
}

/// A model that is prepared and cached after loading.
///
/// This trait does not specify anything about whether inference results are cached.
#[async_trait]
trait InferencingModel: Send + Sync {
    async fn infer(
        &self,
        prompt: String,
        params: wasi_llm::InferencingParams,
    ) -> anyhow::Result<wasi_llm::InferencingResult>;
}

impl LocalLlmEngine {
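    /// Runs inferencing on the named model with the given prompt and
    /// parameters, loading (and caching) the model on first use.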
    pub async fn infer(
        &mut self,
        model: wasi_llm::InferencingModel,
        prompt: String,
        params: wasi_llm::InferencingParams,
    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
        let model = self.inferencing_model(model).await?;

        model
            .infer(prompt, params)
            .await
            .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))
    }

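    /// Generates one embedding vector per input string using the embeddings
    /// model (currently only `all-minilm-l6-v2` is supported).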
    pub async fn generate_embeddings(
        &mut self,
        model: wasi_llm::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error> {
        let model = self.embeddings_model(model).await?;
        generate_embeddings(data, model).await.map_err(|e| {
            wasi_llm::Error::RuntimeError(format!("Error occurred generating embeddings: {e}"))
        })
    }
}

impl LocalLlmEngine {
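    /// Creates an engine rooted at the given model registry directory.
    ///
    /// Models are loaded lazily on first use rather than at construction time.
    /// Illustrative only (the registry path below is an assumption, not
    /// something this crate prescribes):
    ///
    /// ```ignore
    /// let mut engine = LocalLlmEngine::new(PathBuf::from(".spin/ai-models"));
    /// ```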
    pub fn new(registry: PathBuf) -> Self {
        Self {
            registry,
            inferencing_models: Default::default(),
            embeddings_models: Default::default(),
        }
    }

    /// Gets the embeddings model from the cache, or loads it from disk.
    async fn embeddings_model(
        &mut self,
        model: wasi_llm::EmbeddingModel,
    ) -> Result<Arc<(tokenizers::Tokenizer, BertModel)>, wasi_llm::Error> {
        let key = match model.as_str() {
            MODEL_ALL_MINILM_L6_V2 => model,
            _ => return Err(wasi_llm::Error::ModelNotSupported),
        };
        let registry_path = self.registry.join(&key);
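        // Reuse the cached model if it has already been loaded; otherwise load
        // the tokenizer and weights from disk on a blocking thread and cache them.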
        let r = match self.embeddings_models.entry(key) {
            Entry::Occupied(o) => o.get().clone(),
            Entry::Vacant(v) => v
                .insert({
                    tokio::task::spawn_blocking(move || {
                        if !registry_path.exists() {
                            return Err(
                                wasi_llm::Error::RuntimeError(format!(
                                "The directory expected to house the embeddings models '{}' does not exist.",
                                registry_path.display()
                            )));
                        }
                        let tokenizer_file = registry_path.join("tokenizer.json");
                        let model_file = registry_path.join("model.safetensors");
                        let tokenizer = load_tokenizer(&tokenizer_file).map_err(|_| {
                            wasi_llm::Error::RuntimeError(format!(
                                "Failed to load embeddings tokenizer from '{}'",
                                tokenizer_file.display()
                            ))
                        })?;
                        let model = load_model(&model_file).map_err(|_| {
                            wasi_llm::Error::RuntimeError(format!(
                                "Failed to load embeddings model from '{}'",
                                model_file.display()
                            ))
                        })?;
                        Ok(Arc::new((tokenizer, model)))
                    })
                    .await
                    .map_err(|_| {
                        wasi_llm::Error::RuntimeError("Error loading embeddings model".into())
                    })??
                })
                .clone(),
        };
        Ok(r)
    }

    /// Gets the inferencing model from the cache, or loads it from disk.
    async fn inferencing_model(
        &mut self,
        model: wasi_llm::InferencingModel,
    ) -> Result<Arc<dyn InferencingModel>, wasi_llm::Error> {
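        // The registry is only walked (and the model only loaded) on a cache
        // miss; hits reuse the already loaded model.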
        let model = match self.inferencing_models.entry(model.clone()) {
            Entry::Occupied(o) => o.get().clone(),
            Entry::Vacant(v) => {
                let (model_dir, arch) =
                    walk_registry_for_model(&self.registry, model.clone()).await?;
                let model = match arch {
                    InferencingModelArch::Llama => Arc::new(
                        llama::LlamaModels::new(&model_dir)
                            .await
                            .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))?,
                    ),
                };

                v.insert(model.clone());

                model
            }
        };
        Ok(model)
    }
}

/// Walks the registry file structure and returns the directory the model is
/// present in, along with its architecture.
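///
/// The registry is expected to be laid out as `<registry>/<architecture>/<model-name>/`;
/// plain files at either level are skipped.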
async fn walk_registry_for_model(
    registry_path: &Path,
    model: String,
) -> Result<(PathBuf, InferencingModelArch), wasi_llm::Error> {
    let mut arch_dirs = tokio::fs::read_dir(registry_path).await.map_err(|e| {
        wasi_llm::Error::RuntimeError(format!(
            "Could not read model registry directory '{}': {e}",
            registry_path.display()
        ))
    })?;
    let mut result = None;
    'outer: while let Some(arch_dir) = arch_dirs.next_entry().await.map_err(|e| {
        wasi_llm::Error::RuntimeError(format!(
            "Failed to read arch directory in model registry: {e}"
        ))
    })? {
        if arch_dir
            .file_type()
            .await
            .map_err(|e| {
                wasi_llm::Error::RuntimeError(format!(
                    "Could not read file type of '{}' dir: {e}",
                    arch_dir.path().display()
                ))
            })?
            .is_file()
        {
            continue;
        }
        let mut model_dirs = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| {
            wasi_llm::Error::RuntimeError(format!(
                "Error reading architecture directory in model registry: {e}"
            ))
        })?;
        while let Some(model_dir) = model_dirs.next_entry().await.map_err(|e| {
            wasi_llm::Error::RuntimeError(format!(
                "Error reading model folder in model registry: {e}"
            ))
        })? {
            // Models need to be directories, so ignore any files.
            if model_dir
                .file_type()
                .await
                .map_err(|e| {
                    wasi_llm::Error::RuntimeError(format!(
                        "Could not read file type of '{}' dir: {e}",
                        model_dir.path().display()
                    ))
                })?
                .is_file()
            {
                continue;
            }
            if model_dir
                .file_name()
                .to_str()
                .map(|m| m == model)
                .unwrap_or_default()
            {
                let arch = arch_dir.file_name();
                let arch = arch
                    .to_str()
                    .ok_or(wasi_llm::Error::ModelNotSupported)?
                    .parse()
                    .map_err(|_| wasi_llm::Error::ModelNotSupported)?;
                result = Some((model_dir.path(), arch));
                break 'outer;
            }
        }
    }

    result.ok_or_else(|| {
        wasi_llm::Error::InvalidInput(format!(
            "no model directory found in registry for model '{model}'"
        ))
    })
}

async fn generate_embeddings(
    data: Vec<String>,
    model: Arc<(tokenizers::Tokenizer, BertModel)>,
) -> anyhow::Result<wasi_llm::EmbeddingsResult> {
    let n_sentences = data.len();
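    // Tokenization and the model's forward pass are CPU-bound, so run them on a
    // blocking thread rather than on the async executor.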
    tokio::task::spawn_blocking(move || {
        let mut tokenizer = model.0.clone();
        let model = &model.1;
        // This function attempts to generate the embeddings for a batch of inputs, most
        // likely of different lengths.
        // The tokenizer expects all inputs in a batch to have the same length, so the
        // following configures the tokenizer to pad (add trailing zeros) each input
        // to match the length of the longest in the batch.
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest;
        } else {
            let pp = PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let tokens = tokenizer
            .encode_batch(data, true)
            .map_err(|e| anyhow::anyhow!("{e}"))?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Ok(candle::Tensor::new(
                    tokens.as_slice(),
                    &candle::Device::Cpu,
                )?)
            })
            .collect::<anyhow::Result<Vec<_>>>()?;

        // Execute the model's forward propagation function, which generates the raw embeddings.
        let token_ids = candle::Tensor::stack(&token_ids, 0)?;
        let embeddings = model.forward(&token_ids, &token_ids.zeros_like()?)?;

        // SBERT adds a pooling operation to the raw output to derive a fixed-size sentence embedding.
        // The BERT models suggest using mean pooling, which is what the operation below performs.
        // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2#usage-huggingface-transformers
        let (_, n_tokens, _) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;

        // Take each sentence embedding from the batch and arrange it in the final result.
        // Normalize each embedding as the last step (this generates vectors with length 1, which
        // makes the cosine similarity function significantly more efficient, as it becomes a
        // simple dot product).
        let mut results: Vec<Vec<f32>> = Vec::new();
        for j in 0..n_sentences {
            let e_j = embeddings.get(j)?;
            let mut emb: Vec<f32> = e_j.to_vec1()?;
            let length: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            emb.iter_mut().for_each(|x| *x /= length);
            results.push(emb);
        }

        let result = wasi_llm::EmbeddingsResult {
            embeddings: results,
            usage: wasi_llm::EmbeddingsUsage {
                prompt_token_count: n_tokens as u32,
            },
        };
        Ok(result)
    })
    .await?
}

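/// Loads a HuggingFace tokenizer definition (`tokenizer.json`) from disk.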
fn load_tokenizer(tokenizer_file: &Path) -> anyhow::Result<tokenizers::Tokenizer> {
    let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_file).map_err(|e| {
        anyhow::anyhow!(
            "Failed to read tokenizer file {}: {e}",
            quoted_path(tokenizer_file)
        )
    })?;
    Ok(tokenizer)
}

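/// Loads BERT model weights from a `.safetensors` file onto the CPU as `f32` tensors.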
fn load_model(model_file: &Path) -> anyhow::Result<BertModel> {
    let device = &candle::Device::Cpu;
    let data = std::fs::read(model_file)?;
    let tensors = load_buffer(&data, device)?;
    let vb = VarBuilder::from_tensors(tensors, DType::F32, device);
    let model = BertModel::load(vb, &Config::default()).context("error loading bert model")?;
    Ok(model)
}