1mod host;
2pub mod spin;
3
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use anyhow::Context as _;
use async_trait::async_trait;
use spin_factors::{
    ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
};
use spin_locked_app::MetadataKey;
use spin_world::v1::llm::{self as v1};
use spin_world::v2::llm::{self as v2};
use tokio::sync::Mutex;
15
16pub const ALLOWED_MODELS_KEY: MetadataKey<Vec<String>> = MetadataKey::new("ai_models");
17
18pub struct LlmFactor {
20 default_engine_creator: Box<dyn LlmEngineCreator>,
21}
22
23impl LlmFactor {
24 pub fn new<F: LlmEngineCreator + 'static>(default_engine_creator: F) -> Self {
28 Self {
29 default_engine_creator: Box::new(default_engine_creator),
30 }
31 }
32}
33
34impl Factor for LlmFactor {
35 type RuntimeConfig = RuntimeConfig;
36 type AppState = AppState;
37 type InstanceBuilder = InstanceState;
38
39 fn init<T: Send + 'static>(
40 &mut self,
41 mut ctx: spin_factors::InitContext<T, Self>,
42 ) -> anyhow::Result<()> {
43 ctx.link_bindings(spin_world::v1::llm::add_to_linker)?;
44 ctx.link_bindings(spin_world::v2::llm::add_to_linker)?;
45 Ok(())
46 }
47
48 fn configure_app<T: RuntimeFactors>(
49 &self,
50 mut ctx: ConfigureAppContext<T, Self>,
51 ) -> anyhow::Result<Self::AppState> {
52 let component_allowed_models = ctx
53 .app()
54 .components()
55 .map(|component| {
56 Ok((
57 component.id().to_string(),
58 component
59 .get_metadata(ALLOWED_MODELS_KEY)?
60 .unwrap_or_default()
61 .into_iter()
62 .collect::<HashSet<_>>()
63 .into(),
64 ))
65 })
66 .collect::<anyhow::Result<_>>()?;
67 let engine = ctx
68 .take_runtime_config()
69 .map(|c| c.engine)
70 .unwrap_or_else(|| self.default_engine_creator.create());
71 Ok(AppState {
72 engine,
73 component_allowed_models,
74 })
75 }
76
77 fn prepare<T: RuntimeFactors>(
78 &self,
79 ctx: PrepareContext<T, Self>,
80 ) -> anyhow::Result<Self::InstanceBuilder> {
81 let allowed_models = ctx
82 .app_state()
83 .component_allowed_models
84 .get(ctx.app_component().id())
85 .cloned()
86 .unwrap_or_default();
87 let engine = ctx.app_state().engine.clone();
88
89 Ok(InstanceState {
90 engine,
91 allowed_models,
92 })
93 }
94}
95
96pub struct AppState {
98 engine: Arc<Mutex<dyn LlmEngine>>,
99 component_allowed_models: HashMap<String, Arc<HashSet<String>>>,
100}
101
102pub struct InstanceState {
104 engine: Arc<Mutex<dyn LlmEngine>>,
105 pub allowed_models: Arc<HashSet<String>>,
106}
107
108pub struct RuntimeConfig {
110 engine: Arc<Mutex<dyn LlmEngine>>,
111}
112
113impl SelfInstanceBuilder for InstanceState {}
114
115#[async_trait]
117pub trait LlmEngine: Send + Sync {
118 async fn infer(
119 &mut self,
120 model: v1::InferencingModel,
121 prompt: String,
122 params: v2::InferencingParams,
123 ) -> Result<v2::InferencingResult, v2::Error>;
124
125 async fn generate_embeddings(
126 &mut self,
127 model: v2::EmbeddingModel,
128 data: Vec<String>,
129 ) -> Result<v2::EmbeddingsResult, v2::Error>;
130
131 fn summary(&self) -> Option<String> {
135 None
136 }
137}
138
139pub trait LlmEngineCreator: Send + Sync {
141 fn create(&self) -> Arc<Mutex<dyn LlmEngine>>;
142}
143
144impl<F> LlmEngineCreator for F
145where
146 F: Fn() -> Arc<Mutex<dyn LlmEngine>> + Send + Sync,
147{
148 fn create(&self) -> Arc<Mutex<dyn LlmEngine>> {
149 self()
150 }
151}