1mod host;
2pub mod spin;
3
4use std::collections::{HashMap, HashSet};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use spin_factors::{
9 ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
10};
11use spin_locked_app::MetadataKey;
12use spin_world::v1::llm::{self as v1};
13use spin_world::v2::llm::{self as v2};
14use tokio::sync::Mutex;
15
/// Component metadata key (`ai_models`) listing the model names a component declares.
///
/// `LlmFactor::configure_app` reads it for every component; a component without this
/// key is treated as declaring an empty list.
pub const ALLOWED_MODELS_KEY: MetadataKey<Vec<String>> = MetadataKey::new("ai_models");
17
18pub struct LlmFactor {
20 default_engine_creator: Box<dyn LlmEngineCreator>,
21}
22
23impl LlmFactor {
24 pub fn new<F: LlmEngineCreator + 'static>(default_engine_creator: F) -> Self {
28 Self {
29 default_engine_creator: Box::new(default_engine_creator),
30 }
31 }
32}
33
34impl Factor for LlmFactor {
35 type RuntimeConfig = RuntimeConfig;
36 type AppState = AppState;
37 type InstanceBuilder = InstanceState;
38
39 fn init(&mut self, ctx: &mut impl spin_factors::InitContext<Self>) -> anyhow::Result<()> {
40 ctx.link_bindings(spin_world::v1::llm::add_to_linker)?;
41 ctx.link_bindings(spin_world::v2::llm::add_to_linker)?;
42 Ok(())
43 }
44
45 fn configure_app<T: RuntimeFactors>(
46 &self,
47 mut ctx: ConfigureAppContext<T, Self>,
48 ) -> anyhow::Result<Self::AppState> {
49 let component_allowed_models = ctx
50 .app()
51 .components()
52 .map(|component| {
53 Ok((
54 component.id().to_string(),
55 component
56 .get_metadata(ALLOWED_MODELS_KEY)?
57 .unwrap_or_default()
58 .into_iter()
59 .collect::<HashSet<_>>()
60 .into(),
61 ))
62 })
63 .collect::<anyhow::Result<_>>()?;
64 let engine = ctx
65 .take_runtime_config()
66 .map(|c| c.engine)
67 .unwrap_or_else(|| self.default_engine_creator.create());
68 Ok(AppState {
69 engine,
70 component_allowed_models,
71 })
72 }
73
74 fn prepare<T: RuntimeFactors>(
75 &self,
76 ctx: PrepareContext<T, Self>,
77 ) -> anyhow::Result<Self::InstanceBuilder> {
78 let allowed_models = ctx
79 .app_state()
80 .component_allowed_models
81 .get(ctx.app_component().id())
82 .cloned()
83 .unwrap_or_default();
84 let engine = ctx.app_state().engine.clone();
85
86 Ok(InstanceState {
87 engine,
88 allowed_models,
89 })
90 }
91}
92
/// App-wide state produced by `LlmFactor::configure_app`.
pub struct AppState {
    /// Engine handle shared (via `Arc` clones) by every instance of the app.
    engine: Arc<Mutex<dyn LlmEngine>>,
    /// Allowed model names for each component, keyed by component ID.
    component_allowed_models: HashMap<String, Arc<HashSet<String>>>,
}
98
/// Per-instance state produced by `LlmFactor::prepare`; it also serves as the factor's
/// instance builder.
pub struct InstanceState {
    /// Clone of the app's shared engine handle.
    engine: Arc<Mutex<dyn LlmEngine>>,
    /// Model names the component listed under its `ai_models` metadata (empty if none).
    pub allowed_models: Arc<HashSet<String>>,
}
104
/// Runtime configuration for the LLM factor.
///
/// The field is private, so only this module and its child modules (such as `spin`)
/// can construct it.
pub struct RuntimeConfig {
    /// Engine used in place of the factor's default engine creator.
    engine: Arc<Mutex<dyn LlmEngine>>,
}
109
// `InstanceState` is used directly as `Factor::InstanceBuilder`, so it is its own builder.
impl SelfInstanceBuilder for InstanceState {}
111
/// Backend that performs LLM inferencing and embedding generation.
///
/// Engines are held as `Arc<Mutex<dyn LlmEngine>>` and shared across instances. Callers
/// therefore lock the mutex before invoking the `&mut self` methods.
#[async_trait]
pub trait LlmEngine: Send + Sync {
    /// Runs inference with `model` on `prompt` using the v2 `params`.
    ///
    /// Returns the v2 inferencing result, or a v2 `Error` on failure.
    async fn infer(
        &mut self,
        model: v1::InferencingModel,
        prompt: String,
        params: v2::InferencingParams,
    ) -> Result<v2::InferencingResult, v2::Error>;

    /// Generates embeddings with `model` for the strings in `data`.
    ///
    /// Returns the v2 embeddings result, or a v2 `Error` on failure.
    async fn generate_embeddings(
        &mut self,
        model: v2::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<v2::EmbeddingsResult, v2::Error>;

    /// Optional summary of this engine. The default implementation returns `None`.
    ///
    /// NOTE(review): presumably a human-readable description of the engine's
    /// configuration for logging; confirm against callers.
    fn summary(&self) -> Option<String> {
        None
    }
}
135
/// Factory for the engine `LlmFactor` falls back to when no runtime config provides one.
pub trait LlmEngineCreator: Send + Sync {
    /// Builds a new shared engine handle.
    fn create(&self) -> Arc<Mutex<dyn LlmEngine>>;
}
140
141impl<F> LlmEngineCreator for F
142where
143 F: Fn() -> Arc<Mutex<dyn LlmEngine>> + Send + Sync,
144{
145 fn create(&self) -> Arc<Mutex<dyn LlmEngine>> {
146 self()
147 }
148}