Skip to main content

laminardb/
ai.rs

1//! Build the AI subsystem (model registry + provider clients + cache + call
2//! log) from `[ai]` / `[models]` server configuration.
3//!
4//! Secrets are resolved here, at startup: `api_key_env` names an environment
5//! variable, never the key itself. A provider's transport is its `kind` (or,
6//! when omitted, inferred from the provider name): `anthropic`, `local`, or
7//! otherwise an OpenAI-compatible endpoint (`openai`, Azure, vLLM, …). A single
8//! `local` provider (its `cache_dir`) backs every local model; remote providers
9//! are keyed by name and matched to each remote model's `provider`.
10
11use std::collections::HashMap;
12use std::path::PathBuf;
13use std::sync::Arc;
14
15use std::num::NonZeroU32;
16
17use laminar_ai::backends::{
18    local, AnthropicProvider, LocalProvider, OpenAiProvider, RateLimitedProvider,
19};
20use laminar_ai::{
21    AiCallLog, AiResultCache, AiRuntime, InferenceProvider, ModelBackend, ModelEntry,
22    ModelRegistry, Task,
23};
24
25use crate::config::{ProviderConfig, ServerConfig};
26use crate::server::ServerError;
27
/// How many `laminar.ai_calls` records the AI call log keeps.
const CALL_LOG_CAPACITY: usize = 10 * 1_000;
30
31/// Build the AI runtime from configuration, or `None` if no models are
32/// configured.
33///
34/// # Errors
35///
36/// Returns [`ServerError::Build`] for an unset `api_key_env`, an unknown task
37/// name, a malformed model entry, or a provider client that fails to construct.
38#[allow(clippy::result_large_err)] // matches the crate's ServerError convention
39pub(crate) fn build_ai_runtime(
40    config: &ServerConfig,
41) -> Result<Option<Arc<AiRuntime>>, ServerError> {
42    if config.models.is_empty() {
43        return Ok(None);
44    }
45
46    // One local provider serves every local model. More than one is ambiguous
47    // (which cache_dir wins?), and picking via HashMap iteration order would be
48    // nondeterministic — so require exactly zero or one.
49    let mut locals = config
50        .ai
51        .providers
52        .iter()
53        .filter(|(name, cfg)| provider_kind(name, cfg) == "local");
54    let local = locals.next();
55    if locals.next().is_some() {
56        return Err(build_err(
57            "more than one local AI provider configured; only one is supported".to_string(),
58        ));
59    }
60    let local_cache_dir: Option<PathBuf> = local
61        .and_then(|(_, cfg)| cfg.cache_dir.clone())
62        .map(PathBuf::from);
63    let local_provider: Option<Arc<dyn InferenceProvider>> = local_cache_dir
64        .clone()
65        .map(|dir| Arc::new(LocalProvider::new(dir)) as Arc<dyn InferenceProvider>);
66
67    let mut providers: HashMap<String, Arc<dyn InferenceProvider>> = HashMap::new();
68    for (name, provider) in &config.ai.providers {
69        if let Some(client) = build_provider(name, provider)? {
70            providers.insert(name.clone(), client);
71        }
72    }
73
74    let mut registry = ModelRegistry::new();
75    for (name, model) in &config.models {
76        let tasks = model
77            .task
78            .tasks()
79            .iter()
80            .map(|t| {
81                t.parse::<Task>()
82                    .map_err(|e| build_err(format!("model '{name}': {e}")))
83            })
84            .collect::<Result<Vec<_>, _>>()?;
85        let backend = match model.kind.as_str() {
86            "remote" => ModelBackend::Remote {
87                provider: model.provider.clone().ok_or_else(|| {
88                    build_err(format!("model '{name}': remote model requires a provider"))
89                })?,
90                model: model.model.clone().ok_or_else(|| {
91                    build_err(format!("model '{name}': remote model requires a model id"))
92                })?,
93            },
94            "local" => {
95                let source = model.source.clone().ok_or_else(|| {
96                    build_err(format!("model '{name}': local model requires a source"))
97                })?;
98                // Auto-derive classifier labels from the model's config.json.
99                let labels = local_cache_dir
100                    .as_ref()
101                    .map(|dir| local::load_labels(dir, &source))
102                    .filter(|l| !l.is_empty());
103                ModelBackend::Local { source, labels }
104            }
105            other => {
106                return Err(build_err(format!(
107                    "model '{name}': kind must be 'local' or 'remote', got '{other}'"
108                )))
109            }
110        };
111        registry
112            .register(ModelEntry {
113                id: name.clone(),
114                tasks,
115                backend,
116            })
117            .map_err(|e| build_err(e.to_string()))?;
118    }
119
120    for (task_name, model_name) in &config.ai.defaults {
121        let task = task_name
122            .parse::<Task>()
123            .map_err(|e| build_err(format!("ai.defaults: {e}")))?;
124        registry.set_default(task, model_name.clone());
125    }
126
127    let cache = Arc::new(AiResultCache::with_defaults());
128    let call_log = Arc::new(AiCallLog::new(CALL_LOG_CAPACITY));
129    Ok(Some(Arc::new(AiRuntime::new(
130        registry,
131        providers,
132        local_provider,
133        cache,
134        call_log,
135    ))))
136}
137
138/// A provider's transport kind: its explicit `kind`, else inferred from the name.
139fn provider_kind<'a>(name: &'a str, cfg: &'a ProviderConfig) -> &'a str {
140    cfg.kind.as_deref().unwrap_or(name)
141}
142
143/// Build one remote provider client. Returns `None` for the local backend (it is
144/// the runtime's separate `local_provider`, not an entry in the providers map).
145#[allow(clippy::result_large_err)]
146fn build_provider(
147    name: &str,
148    cfg: &ProviderConfig,
149) -> Result<Option<Arc<dyn InferenceProvider>>, ServerError> {
150    let base: Arc<dyn InferenceProvider> = match provider_kind(name, cfg) {
151        "local" => return Ok(None),
152        "anthropic" => {
153            let key = resolve_key(name, cfg)?;
154            let base_url = cfg
155                .base_url
156                .clone()
157                .unwrap_or_else(|| "https://api.anthropic.com".to_string());
158            Arc::new(
159                AnthropicProvider::new(base_url, key, cfg.max_concurrency)
160                    .map_err(|e| build_err(e.to_string()))?,
161            )
162        }
163        // OpenAI-compatible (openai, Azure, vLLM, local servers via base_url).
164        _ => {
165            let key = resolve_key(name, cfg)?;
166            let base_url = cfg
167                .base_url
168                .clone()
169                .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
170            Arc::new(
171                OpenAiProvider::new(base_url, key, cfg.max_concurrency)
172                    .map_err(|e| build_err(e.to_string()))?,
173            )
174        }
175    };
176    Ok(Some(maybe_rate_limit(base, cfg)))
177}
178
179/// Pace a provider to `requests_per_second` when configured, else leave it as-is.
180fn maybe_rate_limit(
181    provider: Arc<dyn InferenceProvider>,
182    cfg: &ProviderConfig,
183) -> Arc<dyn InferenceProvider> {
184    match cfg.requests_per_second.and_then(NonZeroU32::new) {
185        Some(rps) => Arc::new(RateLimitedProvider::new(provider, rps)),
186        None => provider,
187    }
188}
189
190/// Resolve a provider's API key from its `api_key_env` environment variable.
191#[allow(clippy::result_large_err)]
192fn resolve_key(name: &str, cfg: &ProviderConfig) -> Result<String, ServerError> {
193    let var = cfg.api_key_env.as_deref().ok_or_else(|| {
194        build_err(format!(
195            "provider '{name}': remote provider requires api_key_env"
196        ))
197    })?;
198    std::env::var(var).map_err(|_| {
199        build_err(format!(
200            "provider '{name}': environment variable '{var}' (api_key_env) is not set"
201        ))
202    })
203}
204
205fn build_err(msg: String) -> ServerError {
206    ServerError::Build(msg)
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserialize a TOML fixture into a [`ServerConfig`]. A failure here means
    /// the fixture itself is broken, so the test panics.
    fn config_from(src: &str) -> ServerConfig {
        toml::from_str(src).unwrap()
    }

    #[test]
    fn no_models_yields_none() {
        let runtime = build_ai_runtime(&config_from("[server]\n")).unwrap();
        assert!(runtime.is_none());
    }

    #[test]
    fn local_model_builds_without_a_key() {
        const FIXTURE: &str = r#"
[server]
[ai.providers.local]
cache_dir = "/tmp/models"
[models.finbert]
kind = "local"
source = "hf:onnx-community/finbert"
task = "classify"
"#;
        let runtime = build_ai_runtime(&config_from(FIXTURE)).unwrap();
        assert!(runtime.is_some());
    }

    /// The shipped crypto-sentiment demo config must keep parsing, validating,
    /// and building a runtime. Its `sentiment` default must resolve to the
    /// local backend, because the demo is wired to a local ONNX model.
    #[test]
    fn crypto_sentiment_demo_builds_a_local_runtime() {
        let demo_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../../examples/demos/crypto_sentiment/pipeline.toml");
        let config =
            crate::config::load_config(&demo_path).expect("demo config parses and validates");
        let runtime = build_ai_runtime(&config)
            .expect("AI runtime builds")
            .expect("the demo configures a model");

        assert_eq!(
            runtime.registry().default_for(laminar_ai::Task::Sentiment),
            Some("sentiment"),
            "ai_sentiment resolves to the configured default"
        );
        assert_eq!(
            runtime.resolve("sentiment").unwrap().kind,
            laminar_ai::BackendKind::Local,
            "the demo scores sentiment on a local model"
        );
    }

    #[test]
    fn remote_model_without_env_key_fails_fast() {
        const FIXTURE: &str = r#"
[server]
[ai.providers.openai]
api_key_env = "LAMINAR_TEST_DEFINITELY_UNSET_KEY_XYZ"
base_url = "http://localhost:1234/v1"
[models.embed]
kind = "remote"
provider = "openai"
model = "text-embedding-3-small"
task = "embed"
"#;
        match build_ai_runtime(&config_from(FIXTURE)) {
            Ok(_) => panic!("expected an error for the unset api_key_env"),
            Err(err) => assert!(err.to_string().contains("not set"), "{err}"),
        }
    }
}