Skip to main content

laminar_ai/
runtime.rs

1//! The assembled AI subsystem: the model registry, the provider clients that
2//! back it, the shared result cache, and the call log.
3//!
4//! Built once from server configuration and threaded into the engine. Given a
5//! model name referenced in SQL, [`AiRuntime::resolve`] returns everything the
6//! inference operator needs to run it: the backend kind, a stable cache id, the
7//! provider client, the provider-side model id, and any labels. The registry is
8//! kept whole (it also backs the `laminar.models` catalog view), even for models
9//! whose backend has no provider wired yet.
10
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use thiserror::Error;
15
16use crate::cache::AiResultCache;
17use crate::call_log::AiCallLog;
18use crate::provider::InferenceProvider;
19use crate::registry::{BackendKind, ModelBackend, ModelRegistry, RegistryError};
20
21/// Everything the inference operator needs to run one model.
22#[derive(Clone)]
23pub struct ResolvedModel {
24    /// Backend kind (selects the adapter path).
25    pub kind: BackendKind,
26    /// Stable per-run integer id for the result-cache key.
27    pub model_id: u32,
28    /// The transport client.
29    pub provider: Arc<dyn InferenceProvider>,
30    /// The provider-side model identifier passed in the request.
31    pub provider_model: String,
32    /// Intrinsic labels (local classifiers), if known.
33    pub labels: Option<Vec<String>>,
34}
35
36/// Errors from resolving a model to a runnable backend.
37#[derive(Debug, Error)]
38pub enum AiRuntimeError {
39    /// The model is not registered, or it cannot perform the requested task.
40    #[error(transparent)]
41    Registry(#[from] RegistryError),
42
43    /// A remote model names a provider that was not configured.
44    #[error("model '{model}' references provider '{provider}', which is not configured")]
45    UnknownProvider {
46        /// The model name.
47        model: String,
48        /// The missing provider name.
49        provider: String,
50    },
51
52    /// A local model was referenced but the local backend is not available.
53    #[error("model '{0}' is local, but the local backend is not enabled in this build")]
54    LocalBackendUnavailable(String),
55}
56
57/// The assembled AI subsystem.
58pub struct AiRuntime {
59    registry: ModelRegistry,
60    providers: HashMap<String, Arc<dyn InferenceProvider>>,
61    local_provider: Option<Arc<dyn InferenceProvider>>,
62    cache: Arc<AiResultCache>,
63    call_log: Arc<AiCallLog>,
64    model_ids: HashMap<String, u32>,
65}
66
67impl AiRuntime {
68    /// Assemble a runtime. `providers` is keyed by provider name (matching a
69    /// model's `provider`); `local_provider`, when present, serves every local
70    /// model. Each registered model is assigned a stable cache id.
71    #[must_use]
72    pub fn new(
73        registry: ModelRegistry,
74        providers: impl IntoIterator<Item = (String, Arc<dyn InferenceProvider>)>,
75        local_provider: Option<Arc<dyn InferenceProvider>>,
76        cache: Arc<AiResultCache>,
77        call_log: Arc<AiCallLog>,
78    ) -> Self {
79        let model_ids = registry
80            .iter()
81            .enumerate()
82            .map(|(i, entry)| (entry.id.clone(), u32::try_from(i).unwrap_or(u32::MAX)))
83            .collect();
84        Self {
85            registry,
86            providers: providers.into_iter().collect(),
87            local_provider,
88            cache,
89            call_log,
90            model_ids,
91        }
92    }
93
94    /// The model registry (backs `laminar.models` and plan-time validation).
95    #[must_use]
96    pub fn registry(&self) -> &ModelRegistry {
97        &self.registry
98    }
99
100    /// The shared result cache.
101    #[must_use]
102    pub fn cache(&self) -> &Arc<AiResultCache> {
103        &self.cache
104    }
105
106    /// The call log (backs `laminar.ai_calls`).
107    #[must_use]
108    pub fn call_log(&self) -> &Arc<AiCallLog> {
109        &self.call_log
110    }
111
112    /// Resolve a model name to a runnable backend.
113    ///
114    /// # Errors
115    ///
116    /// Returns [`AiRuntimeError`] if the model is unknown, names an unconfigured
117    /// provider, or is local while the local backend is unavailable.
118    pub fn resolve(&self, model_name: &str) -> Result<ResolvedModel, AiRuntimeError> {
119        let entry = self.registry.resolve(model_name)?;
120        let model_id = self.model_ids.get(model_name).copied().unwrap_or(u32::MAX);
121        match &entry.backend {
122            ModelBackend::Remote { provider, model } => {
123                let client = self.providers.get(provider).ok_or_else(|| {
124                    AiRuntimeError::UnknownProvider {
125                        model: model_name.to_string(),
126                        provider: provider.clone(),
127                    }
128                })?;
129                Ok(ResolvedModel {
130                    kind: BackendKind::Remote,
131                    model_id,
132                    provider: Arc::clone(client),
133                    provider_model: model.clone(),
134                    labels: None,
135                })
136            }
137            ModelBackend::Local { labels, source } => {
138                let client = self.local_provider.as_ref().ok_or_else(|| {
139                    AiRuntimeError::LocalBackendUnavailable(model_name.to_string())
140                })?;
141                Ok(ResolvedModel {
142                    kind: BackendKind::Local,
143                    model_id,
144                    provider: Arc::clone(client),
145                    provider_model: source.clone(),
146                    labels: labels.clone(),
147                })
148            }
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{
        InferenceOutputs, InferenceRequest, InferenceResponse, ProviderError, Usage,
    };
    use crate::registry::{ModelEntry, Task};
    use async_trait::async_trait;

    /// Provider double: answers every batch with one empty string per input
    /// and reports zero usage.
    struct Stub;

    #[async_trait]
    impl InferenceProvider for Stub {
        async fn infer_batch(
            &self,
            request: InferenceRequest,
        ) -> Result<InferenceResponse, ProviderError> {
            let blanks = vec![String::new(); request.inputs.len()];
            Ok(InferenceResponse {
                outputs: InferenceOutputs::Text(blanks),
                usage: Usage::ZERO,
            })
        }

        fn name(&self) -> &'static str {
            "stub"
        }
    }

    /// Builds a runtime with one remote model (`haiku`, served by the stubbed
    /// `anthropic` provider) and one local model (`finbert`), with no local
    /// provider installed.
    fn runtime() -> AiRuntime {
        let remote_entry = ModelEntry {
            id: "haiku".into(),
            tasks: vec![Task::Classify],
            backend: ModelBackend::Remote {
                provider: "anthropic".into(),
                model: "claude-haiku-4-5-20251001".into(),
            },
        };
        let local_entry = ModelEntry {
            id: "finbert".into(),
            tasks: vec![Task::Classify],
            backend: ModelBackend::Local {
                source: "hf:onnx-community/finbert".into(),
                labels: Some(vec!["positive".into(), "negative".into()]),
            },
        };

        // Registration order is kept: remote first, then local.
        let mut registry = ModelRegistry::new();
        registry.register(remote_entry).unwrap();
        registry.register(local_entry).unwrap();

        let stub: Arc<dyn InferenceProvider> = Arc::new(Stub);
        let providers = vec![(String::from("anthropic"), stub)];

        let cache = Arc::new(AiResultCache::with_defaults());
        let call_log = Arc::new(AiCallLog::with_defaults());
        AiRuntime::new(registry, providers, None, cache, call_log)
    }

    #[test]
    fn resolves_remote_model_to_its_provider() {
        let rt = runtime();
        let ResolvedModel {
            kind,
            provider,
            provider_model,
            ..
        } = rt.resolve("haiku").unwrap();
        assert_eq!(kind, BackendKind::Remote);
        assert_eq!(provider.name(), "stub");
        assert_eq!(provider_model, "claude-haiku-4-5-20251001");
    }

    #[test]
    fn local_model_without_backend_errors() {
        let rt = runtime();
        let outcome = rt.resolve("finbert");
        assert!(matches!(
            outcome,
            Err(AiRuntimeError::LocalBackendUnavailable(_))
        ));
    }

    #[test]
    fn unknown_model_errors() {
        let rt = runtime();
        let outcome = rt.resolve("ghost");
        assert!(matches!(
            outcome,
            Err(AiRuntimeError::Registry(RegistryError::UnknownModel(_)))
        ));
    }
}