1use std::collections::HashMap;
12use std::path::PathBuf;
13use std::sync::Arc;
14
15use std::num::NonZeroU32;
16
17use laminar_ai::backends::{
18 local, AnthropicProvider, LocalProvider, OpenAiProvider, RateLimitedProvider,
19};
20use laminar_ai::{
21 AiCallLog, AiResultCache, AiRuntime, InferenceProvider, ModelBackend, ModelEntry,
22 ModelRegistry, Task,
23};
24
25use crate::config::{ProviderConfig, ServerConfig};
26use crate::server::ServerError;
27
/// Capacity passed to `AiCallLog::new` when the runtime's call log is created.
const CALL_LOG_CAPACITY: usize = 10_000;
30
31#[allow(clippy::result_large_err)] pub(crate) fn build_ai_runtime(
40 config: &ServerConfig,
41) -> Result<Option<Arc<AiRuntime>>, ServerError> {
42 if config.models.is_empty() {
43 return Ok(None);
44 }
45
46 let mut locals = config
50 .ai
51 .providers
52 .iter()
53 .filter(|(name, cfg)| provider_kind(name, cfg) == "local");
54 let local = locals.next();
55 if locals.next().is_some() {
56 return Err(build_err(
57 "more than one local AI provider configured; only one is supported".to_string(),
58 ));
59 }
60 let local_cache_dir: Option<PathBuf> = local
61 .and_then(|(_, cfg)| cfg.cache_dir.clone())
62 .map(PathBuf::from);
63 let local_provider: Option<Arc<dyn InferenceProvider>> = local_cache_dir
64 .clone()
65 .map(|dir| Arc::new(LocalProvider::new(dir)) as Arc<dyn InferenceProvider>);
66
67 let mut providers: HashMap<String, Arc<dyn InferenceProvider>> = HashMap::new();
68 for (name, provider) in &config.ai.providers {
69 if let Some(client) = build_provider(name, provider)? {
70 providers.insert(name.clone(), client);
71 }
72 }
73
74 let mut registry = ModelRegistry::new();
75 for (name, model) in &config.models {
76 let tasks = model
77 .task
78 .tasks()
79 .iter()
80 .map(|t| {
81 t.parse::<Task>()
82 .map_err(|e| build_err(format!("model '{name}': {e}")))
83 })
84 .collect::<Result<Vec<_>, _>>()?;
85 let backend = match model.kind.as_str() {
86 "remote" => ModelBackend::Remote {
87 provider: model.provider.clone().ok_or_else(|| {
88 build_err(format!("model '{name}': remote model requires a provider"))
89 })?,
90 model: model.model.clone().ok_or_else(|| {
91 build_err(format!("model '{name}': remote model requires a model id"))
92 })?,
93 },
94 "local" => {
95 let source = model.source.clone().ok_or_else(|| {
96 build_err(format!("model '{name}': local model requires a source"))
97 })?;
98 let labels = local_cache_dir
100 .as_ref()
101 .map(|dir| local::load_labels(dir, &source))
102 .filter(|l| !l.is_empty());
103 ModelBackend::Local { source, labels }
104 }
105 other => {
106 return Err(build_err(format!(
107 "model '{name}': kind must be 'local' or 'remote', got '{other}'"
108 )))
109 }
110 };
111 registry
112 .register(ModelEntry {
113 id: name.clone(),
114 tasks,
115 backend,
116 })
117 .map_err(|e| build_err(e.to_string()))?;
118 }
119
120 for (task_name, model_name) in &config.ai.defaults {
121 let task = task_name
122 .parse::<Task>()
123 .map_err(|e| build_err(format!("ai.defaults: {e}")))?;
124 registry.set_default(task, model_name.clone());
125 }
126
127 let cache = Arc::new(AiResultCache::with_defaults());
128 let call_log = Arc::new(AiCallLog::new(CALL_LOG_CAPACITY));
129 Ok(Some(Arc::new(AiRuntime::new(
130 registry,
131 providers,
132 local_provider,
133 cache,
134 call_log,
135 ))))
136}
137
138fn provider_kind<'a>(name: &'a str, cfg: &'a ProviderConfig) -> &'a str {
140 cfg.kind.as_deref().unwrap_or(name)
141}
142
143#[allow(clippy::result_large_err)]
146fn build_provider(
147 name: &str,
148 cfg: &ProviderConfig,
149) -> Result<Option<Arc<dyn InferenceProvider>>, ServerError> {
150 let base: Arc<dyn InferenceProvider> = match provider_kind(name, cfg) {
151 "local" => return Ok(None),
152 "anthropic" => {
153 let key = resolve_key(name, cfg)?;
154 let base_url = cfg
155 .base_url
156 .clone()
157 .unwrap_or_else(|| "https://api.anthropic.com".to_string());
158 Arc::new(
159 AnthropicProvider::new(base_url, key, cfg.max_concurrency)
160 .map_err(|e| build_err(e.to_string()))?,
161 )
162 }
163 _ => {
165 let key = resolve_key(name, cfg)?;
166 let base_url = cfg
167 .base_url
168 .clone()
169 .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
170 Arc::new(
171 OpenAiProvider::new(base_url, key, cfg.max_concurrency)
172 .map_err(|e| build_err(e.to_string()))?,
173 )
174 }
175 };
176 Ok(Some(maybe_rate_limit(base, cfg)))
177}
178
179fn maybe_rate_limit(
181 provider: Arc<dyn InferenceProvider>,
182 cfg: &ProviderConfig,
183) -> Arc<dyn InferenceProvider> {
184 match cfg.requests_per_second.and_then(NonZeroU32::new) {
185 Some(rps) => Arc::new(RateLimitedProvider::new(provider, rps)),
186 None => provider,
187 }
188}
189
190#[allow(clippy::result_large_err)]
192fn resolve_key(name: &str, cfg: &ProviderConfig) -> Result<String, ServerError> {
193 let var = cfg.api_key_env.as_deref().ok_or_else(|| {
194 build_err(format!(
195 "provider '{name}': remote provider requires api_key_env"
196 ))
197 })?;
198 std::env::var(var).map_err(|_| {
199 build_err(format!(
200 "provider '{name}': environment variable '{var}' (api_key_env) is not set"
201 ))
202 })
203}
204
/// Wraps `msg` in the `ServerError::Build` variant.
fn build_err(msg: String) -> ServerError {
    ServerError::Build(msg)
}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes an inline TOML document into a `ServerConfig`, panicking
    /// if deserialization fails.
    fn parse(document: &str) -> ServerConfig {
        toml::from_str(document).unwrap()
    }

    /// A config without any models must not produce a runtime.
    #[test]
    fn no_models_yields_none() {
        let config = parse("[server]\n");
        let runtime = build_ai_runtime(&config).unwrap();
        assert!(runtime.is_none());
    }

    /// A local model with a local provider builds without any API key.
    #[test]
    fn local_model_builds_without_a_key() {
        let config = parse(
            r#"
[server]
[ai.providers.local]
cache_dir = "/tmp/models"
[models.finbert]
kind = "local"
source = "hf:onnx-community/finbert"
task = "classify"
"#,
        );
        let runtime = build_ai_runtime(&config).unwrap();
        assert!(runtime.is_some());
    }

    /// The shipped crypto-sentiment demo config yields a runtime whose
    /// sentiment default resolves to a local backend.
    #[test]
    fn crypto_sentiment_demo_builds_a_local_runtime() {
        let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
        let demo_path = manifest_dir.join("../../examples/demos/crypto_sentiment/pipeline.toml");
        let config =
            crate::config::load_config(&demo_path).expect("demo config parses and validates");
        let built = build_ai_runtime(&config).expect("AI runtime builds");
        let runtime = built.expect("the demo configures a model");

        assert_eq!(
            runtime.registry().default_for(laminar_ai::Task::Sentiment),
            Some("sentiment"),
            "ai_sentiment resolves to the configured default"
        );
        assert_eq!(
            runtime.resolve("sentiment").unwrap().kind,
            laminar_ai::BackendKind::Local,
            "the demo scores sentiment on a local model"
        );
    }

    /// A remote provider whose key variable is unset fails at build time.
    #[test]
    fn remote_model_without_env_key_fails_fast() {
        let config = parse(
            r#"
[server]
[ai.providers.openai]
api_key_env = "LAMINAR_TEST_DEFINITELY_UNSET_KEY_XYZ"
base_url = "http://localhost:1234/v1"
[models.embed]
kind = "remote"
provider = "openai"
model = "text-embedding-3-small"
task = "embed"
"#,
        );
        let err = match build_ai_runtime(&config) {
            Err(err) => err,
            Ok(_) => panic!("expected an error for the unset api_key_env"),
        };
        let message = err.to_string();
        assert!(message.contains("not set"), "{err}");
    }
}