spin_factors_derive/
lib.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Data, DeriveInput, Error};
4
5#[proc_macro_derive(RuntimeFactors)]
6pub fn derive_factors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let expanded = expand_factors(&input).unwrap_or_else(|err| err.into_compile_error());
9
10    #[cfg(feature = "expander")]
11    let expanded = if let Some(dest_dir) = std::env::var_os("SPIN_FACTORS_DERIVE_EXPAND_DIR") {
12        expander::Expander::new("factors")
13            .write_to(expanded, std::path::Path::new(&dest_dir))
14            .unwrap()
15    } else {
16        expanded
17    };
18
19    expanded.into()
20}
21
22#[allow(non_snake_case)]
23fn expand_factors(input: &DeriveInput) -> syn::Result<TokenStream> {
24    let name = &input.ident;
25    let vis = &input.vis;
26
27    let app_state_name = format_ident!("{name}AppState");
28    let builders_name = format_ident!("{name}InstanceBuilders");
29    let state_name = format_ident!("{name}InstanceState");
30    let runtime_config_name = format_ident!("{name}RuntimeConfig");
31
32    if !input.generics.params.is_empty() {
33        return Err(Error::new_spanned(
34            input,
35            "cannot derive Factors for generic structs",
36        ));
37    }
38
39    // Get struct fields
40    let fields = match &input.data {
41        Data::Struct(struct_data) => &struct_data.fields,
42        _ => {
43            return Err(Error::new_spanned(
44                input,
45                "can only derive Factors for structs",
46            ))
47        }
48    };
49    let mut factor_names = Vec::with_capacity(fields.len());
50    let mut factor_types = Vec::with_capacity(fields.len());
51    for field in fields.iter() {
52        factor_names.push(
53            field
54                .ident
55                .as_ref()
56                .ok_or_else(|| Error::new_spanned(input, "tuple structs are not supported"))?,
57        );
58        factor_types.push(&field.ty);
59    }
60
61    let Any = quote!(::std::any::Any);
62    let Send = quote!(::std::marker::Send);
63    let TypeId = quote!(::std::any::TypeId);
64    let factors_crate = format_ident!("spin_factors");
65    let factors_path = quote!(::#factors_crate);
66    let wasmtime = quote!(#factors_path::wasmtime);
67    let ResourceTable = quote!(#wasmtime::component::ResourceTable);
68    let Result = quote!(#factors_path::Result);
69    let Error = quote!(#factors_path::Error);
70    let Factor = quote!(#factors_path::Factor);
71    let ConfiguredApp = quote!(#factors_path::ConfiguredApp);
72    let FactorInstanceBuilder = quote!(#factors_path::FactorInstanceBuilder);
73
74    Ok(quote! {
75        impl #factors_path::RuntimeFactors for #name {
76            type AppState = #app_state_name;
77            type InstanceBuilders = #builders_name;
78            type InstanceState = #state_name;
79            type RuntimeConfig = #runtime_config_name;
80
81            fn init<T: #factors_path::AsInstanceState<Self::InstanceState> + Send + 'static>(
82                &mut self,
83                linker: &mut #wasmtime::component::Linker<T>,
84            ) -> #Result<()> {
85                let factor_type_ids = [#(
86                    (stringify!(#factor_types), #TypeId::of::<(<#factor_types as #Factor>::InstanceBuilder, <#factor_types as #Factor>::AppState)>()),
87                )*];
88
89                let mut unique = ::std::collections::HashSet::new();
90                for (name, type_id) in factor_type_ids {
91                    if !unique.insert(type_id) {
92                        return Err(#Error::DuplicateFactorTypes(name.to_owned()));
93                    }
94                }
95
96                #(
97                    #[allow(non_camel_case_types)]
98                    struct #factor_names;
99
100                    impl #factors_path::FactorField for #factor_names {
101                        type State = #state_name;
102                        type Factor = #factor_types;
103
104                        fn get(state: &mut #state_name) -> (
105                            &mut #factors_path::FactorInstanceState<#factor_types>,
106                            &mut #wasmtime::component::ResourceTable,
107                        ) {
108                            (&mut state.#factor_names, &mut state.__table)
109                        }
110                    }
111
112                    #Factor::init(
113                        &mut self.#factor_names,
114                        &mut #factors_path::FactorInitContext::<'_, T, #factor_names> {
115                            linker,
116                            _marker: std::marker::PhantomData,
117                        },
118                    ).map_err(#Error::factor_init_error::<#factor_types>)?;
119                )*
120                Ok(())
121            }
122
123            fn configure_app(
124                &self,
125                app: #factors_path::App,
126                runtime_config: Self::RuntimeConfig,
127            ) -> #Result<#ConfiguredApp<Self>> {
128                let mut app_state = #app_state_name {
129                    #( #factor_names: None, )*
130                };
131                #(
132                    app_state.#factor_names = Some(
133                        #Factor::configure_app(
134                            &self.#factor_names,
135                            #factors_path::ConfigureAppContext::<Self, #factor_types>::new(
136                                &app,
137                                &app_state,
138                                runtime_config.#factor_names,
139                            )?,
140                        ).map_err(#Error::factor_configure_app_error::<#factor_types>)?
141                    );
142                )*
143                Ok(#ConfiguredApp::new(app, app_state))
144            }
145
146            fn prepare(
147                &self, configured_app: &#ConfiguredApp<Self>,
148                component_id: &str,
149            ) -> #Result<Self::InstanceBuilders> {
150                let app_component = configured_app.app().get_component(component_id).ok_or_else(|| {
151                    #factors_path::Error::UnknownComponent(component_id.to_string())
152                })?;
153                let mut builders = #builders_name {
154                    #( #factor_names: None, )*
155                };
156                #(
157                    builders.#factor_names = Some(
158                        #Factor::prepare::<Self>(
159                            &self.#factor_names,
160                            #factors_path::PrepareContext::new(
161                                configured_app.app_state::<#factor_types>().unwrap(),
162                                &app_component,
163                                &mut builders,
164                            ),
165                        ).map_err(#Error::factor_prepare_error::<#factor_types>)?
166                    );
167                )*
168                Ok(builders)
169            }
170
171            fn build_instance_state(
172                &self,
173                builders: Self::InstanceBuilders,
174            ) -> #Result<Self::InstanceState> {
175                Ok(#state_name {
176                    __table: #ResourceTable::new(),
177                    #(
178                        #factor_names: #FactorInstanceBuilder::build(
179                            builders.#factor_names.unwrap()
180                        ).map_err(#Error::factor_build_error::<#factor_types>)?,
181                    )*
182                })
183            }
184
185            fn app_state<F: #Factor>(app_state: &Self::AppState) -> Option<&F::AppState> {
186                #(
187                    if let Some(state) = &app_state.#factor_names {
188                        if let Some(state) = <dyn #Any>::downcast_ref(state) {
189                            return Some(state)
190                        }
191                    }
192                )*
193                None
194            }
195
196            fn instance_builder_mut<F: #Factor>(
197                builders: &mut Self::InstanceBuilders,
198            ) -> Option<Option<&mut F::InstanceBuilder>> {
199                let type_id = #TypeId::of::<(F::InstanceBuilder, F::AppState)>();
200                #(
201                    if type_id == #TypeId::of::<(<#factor_types as #Factor>::InstanceBuilder, <#factor_types as #Factor>::AppState)>() {
202                        return Some(
203                            builders.#factor_names.as_mut().map(|builder| {
204                                <dyn #Any>::downcast_mut(builder).unwrap()
205                            })
206                        );
207                    }
208                )*
209                None
210            }
211        }
212
213        #vis struct #app_state_name {
214            #(
215                pub #factor_names: Option<<#factor_types as #Factor>::AppState>,
216            )*
217        }
218
219        #vis struct #builders_name {
220            #(
221                #factor_names: Option<<#factor_types as #Factor>::InstanceBuilder>,
222            )*
223        }
224
225        #[allow(dead_code)]
226        impl #builders_name {
227            #(
228                pub fn #factor_names(&mut self) -> &mut <#factor_types as #Factor>::InstanceBuilder {
229                    self.#factor_names.as_mut().unwrap()
230                }
231            )*
232        }
233
234        impl #factors_path::HasInstanceBuilder for #builders_name {
235            fn for_factor<F: #Factor>(
236                &mut self
237            ) -> Option<&mut F::InstanceBuilder> {
238                let type_id = #TypeId::of::<F::InstanceBuilder>();
239                #(
240                    if type_id == #TypeId::of::<<#factor_types as #Factor>::InstanceBuilder>() {
241                        let builder = self.#factor_names.as_mut().unwrap();
242                        return Some(
243                            <dyn #Any>::downcast_mut(builder).unwrap()
244                        );
245                    }
246                )*
247                None
248            }
249        }
250
251        #vis struct #state_name {
252            __table: #ResourceTable,
253            #(
254                pub #factor_names: #factors_path::FactorInstanceState<#factor_types>,
255            )*
256        }
257
258        impl #factors_path::RuntimeFactorsInstanceState for #state_name {
259            fn get_with_table<F: #Factor>(
260                &mut self
261            ) -> ::std::option::Option<(&mut #factors_path::FactorInstanceState<F>, &mut #ResourceTable)> {
262                #(
263                    if let Some(state) = (&mut self.#factor_names as &mut (dyn #Any + #Send)).downcast_mut() {
264                        return Some((state, &mut self.__table))
265                    }
266                )*
267                None
268            }
269
270            fn table(&self) -> &#ResourceTable {
271                &self.__table
272            }
273
274            fn table_mut(&mut self) -> &mut #ResourceTable {
275                &mut self.__table
276            }
277        }
278
279        impl #factors_path::AsInstanceState<#state_name> for #state_name {
280            fn as_instance_state(&mut self) -> &mut Self {
281                self
282            }
283        }
284
285        #[derive(Default)]
286        #vis struct #runtime_config_name {
287            #(
288                pub #factor_names: Option<<#factor_types as #Factor>::RuntimeConfig>,
289            )*
290        }
291
292        impl #runtime_config_name {
293            /// Get the runtime configuration from the given source.
294            #[allow(dead_code)]
295            pub fn from_source<T>(mut source: T) -> anyhow::Result<Self>
296                where T: #(#factors_path::FactorRuntimeConfigSource<#factor_types> +)* #factors_path::RuntimeConfigSourceFinalizer
297            {
298                #(
299                    let #factor_names = <T as #factors_path::FactorRuntimeConfigSource<#factor_types>>::get_runtime_config(&mut source)?;
300                )*
301                source.finalize()?;
302                Ok(#runtime_config_name {
303                    #(
304                        #factor_names,
305                    )*
306                })
307            }
308        }
309    })
310}