1use std::collections::HashMap;
12use std::sync::Arc;
13
14use thiserror::Error;
15
16use crate::cache::AiResultCache;
17use crate::call_log::AiCallLog;
18use crate::provider::InferenceProvider;
19use crate::registry::{BackendKind, ModelBackend, ModelRegistry, RegistryError};
20
21#[derive(Clone)]
23pub struct ResolvedModel {
24 pub kind: BackendKind,
26 pub model_id: u32,
28 pub provider: Arc<dyn InferenceProvider>,
30 pub provider_model: String,
32 pub labels: Option<Vec<String>>,
34}
35
36#[derive(Debug, Error)]
38pub enum AiRuntimeError {
39 #[error(transparent)]
41 Registry(#[from] RegistryError),
42
43 #[error("model '{model}' references provider '{provider}', which is not configured")]
45 UnknownProvider {
46 model: String,
48 provider: String,
50 },
51
52 #[error("model '{0}' is local, but the local backend is not enabled in this build")]
54 LocalBackendUnavailable(String),
55}
56
57pub struct AiRuntime {
59 registry: ModelRegistry,
60 providers: HashMap<String, Arc<dyn InferenceProvider>>,
61 local_provider: Option<Arc<dyn InferenceProvider>>,
62 cache: Arc<AiResultCache>,
63 call_log: Arc<AiCallLog>,
64 model_ids: HashMap<String, u32>,
65}
66
67impl AiRuntime {
68 #[must_use]
72 pub fn new(
73 registry: ModelRegistry,
74 providers: impl IntoIterator<Item = (String, Arc<dyn InferenceProvider>)>,
75 local_provider: Option<Arc<dyn InferenceProvider>>,
76 cache: Arc<AiResultCache>,
77 call_log: Arc<AiCallLog>,
78 ) -> Self {
79 let model_ids = registry
80 .iter()
81 .enumerate()
82 .map(|(i, entry)| (entry.id.clone(), u32::try_from(i).unwrap_or(u32::MAX)))
83 .collect();
84 Self {
85 registry,
86 providers: providers.into_iter().collect(),
87 local_provider,
88 cache,
89 call_log,
90 model_ids,
91 }
92 }
93
94 #[must_use]
96 pub fn registry(&self) -> &ModelRegistry {
97 &self.registry
98 }
99
100 #[must_use]
102 pub fn cache(&self) -> &Arc<AiResultCache> {
103 &self.cache
104 }
105
106 #[must_use]
108 pub fn call_log(&self) -> &Arc<AiCallLog> {
109 &self.call_log
110 }
111
112 pub fn resolve(&self, model_name: &str) -> Result<ResolvedModel, AiRuntimeError> {
119 let entry = self.registry.resolve(model_name)?;
120 let model_id = self.model_ids.get(model_name).copied().unwrap_or(u32::MAX);
121 match &entry.backend {
122 ModelBackend::Remote { provider, model } => {
123 let client = self.providers.get(provider).ok_or_else(|| {
124 AiRuntimeError::UnknownProvider {
125 model: model_name.to_string(),
126 provider: provider.clone(),
127 }
128 })?;
129 Ok(ResolvedModel {
130 kind: BackendKind::Remote,
131 model_id,
132 provider: Arc::clone(client),
133 provider_model: model.clone(),
134 labels: None,
135 })
136 }
137 ModelBackend::Local { labels, source } => {
138 let client = self.local_provider.as_ref().ok_or_else(|| {
139 AiRuntimeError::LocalBackendUnavailable(model_name.to_string())
140 })?;
141 Ok(ResolvedModel {
142 kind: BackendKind::Local,
143 model_id,
144 provider: Arc::clone(client),
145 provider_model: source.clone(),
146 labels: labels.clone(),
147 })
148 }
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{
        InferenceOutputs, InferenceRequest, InferenceResponse, ProviderError, Usage,
    };
    use crate::registry::{ModelEntry, Task};
    use async_trait::async_trait;

    /// Provider double: answers every input with an empty string and zero usage.
    struct Stub;

    #[async_trait]
    impl InferenceProvider for Stub {
        async fn infer_batch(
            &self,
            request: InferenceRequest,
        ) -> Result<InferenceResponse, ProviderError> {
            let blanks = vec![String::new(); request.inputs.len()];
            Ok(InferenceResponse {
                outputs: InferenceOutputs::Text(blanks),
                usage: Usage::ZERO,
            })
        }

        fn name(&self) -> &'static str {
            "stub"
        }
    }

    /// Runtime with one remote model ("haiku", served by the "anthropic" stub)
    /// and one local model ("finbert"), but no local provider.
    fn runtime() -> AiRuntime {
        let remote_entry = ModelEntry {
            id: "haiku".into(),
            tasks: vec![Task::Classify],
            backend: ModelBackend::Remote {
                provider: "anthropic".into(),
                model: "claude-haiku-4-5-20251001".into(),
            },
        };
        let local_entry = ModelEntry {
            id: "finbert".into(),
            tasks: vec![Task::Classify],
            backend: ModelBackend::Local {
                source: "hf:onnx-community/finbert".into(),
                labels: Some(vec!["positive".into(), "negative".into()]),
            },
        };

        // Registration order matters: it determines each model's numeric id.
        let mut registry = ModelRegistry::new();
        registry.register(remote_entry).unwrap();
        registry.register(local_entry).unwrap();

        let mut providers: HashMap<String, Arc<dyn InferenceProvider>> = HashMap::new();
        providers.insert("anthropic".into(), Arc::new(Stub));

        let cache = Arc::new(AiResultCache::with_defaults());
        let call_log = Arc::new(AiCallLog::with_defaults());
        AiRuntime::new(registry, providers, None, cache, call_log)
    }

    #[test]
    fn resolves_remote_model_to_its_provider() {
        let resolved = runtime().resolve("haiku").unwrap();

        assert_eq!(resolved.kind, BackendKind::Remote);
        assert_eq!(resolved.provider.name(), "stub");
        assert_eq!(resolved.provider_model, "claude-haiku-4-5-20251001");
    }

    #[test]
    fn local_model_without_backend_errors() {
        let outcome = runtime().resolve("finbert");

        assert!(matches!(
            outcome,
            Err(AiRuntimeError::LocalBackendUnavailable(_))
        ));
    }

    #[test]
    fn unknown_model_errors() {
        let outcome = runtime().resolve("ghost");

        assert!(matches!(
            outcome,
            Err(AiRuntimeError::Registry(RegistryError::UnknownModel(_)))
        ));
    }
}