Skip to main content

laminar_ai/
registry.rs

//! The model registry: the curated catalog that maps a model name referenced
//! from SQL to the backend that runs it and the tasks it can serve.
//!
//! The registry is built once at startup from server configuration and is
//! immutable thereafter. AI functions resolve `model => '<name>'` against it at
//! plan time, and the query fails there if the model is unknown or cannot
//! perform the requested task. This check happens at plan time, never per event.
8
9use std::collections::HashMap;
10use std::fmt;
11use std::str::FromStr;
12
13use thiserror::Error;
14
/// The contract fulfilled by an AI SQL function.
///
/// Every `ai_*` function corresponds to exactly one task, and each model
/// declares the subset of tasks it is able to serve.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Task {
    /// `ai_classify`: pick a single label out of a set of candidates.
    Classify,
    /// `ai_sentiment`: classification over a fixed set of sentiment labels.
    Sentiment,
    /// `ai_embed`: produce a dense embedding vector.
    Embed,
    /// `ai_extract`: pull structured fields out of text.
    Extract,
    /// `ai_complete`: free-form completion.
    Complete,
    /// `ai_summarize`: summarize text.
    Summarize,
    /// `ai_translate`: translate text.
    Translate,
    /// `ai_gen`: open-ended generation.
    Gen,
}
36
37impl Task {
38    /// Canonical lower-case name, matching the `task = "…"` config spelling.
39    #[must_use]
40    pub fn as_str(self) -> &'static str {
41        match self {
42            Task::Classify => "classify",
43            Task::Sentiment => "sentiment",
44            Task::Embed => "embed",
45            Task::Extract => "extract",
46            Task::Complete => "complete",
47            Task::Summarize => "summarize",
48            Task::Translate => "translate",
49            Task::Gen => "gen",
50        }
51    }
52}
53
54impl fmt::Display for Task {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        f.write_str(self.as_str())
57    }
58}
59
60impl FromStr for Task {
61    type Err = RegistryError;
62
63    fn from_str(s: &str) -> Result<Self, Self::Err> {
64        match s {
65            "classify" => Ok(Task::Classify),
66            "sentiment" => Ok(Task::Sentiment),
67            "embed" => Ok(Task::Embed),
68            "extract" => Ok(Task::Extract),
69            "complete" => Ok(Task::Complete),
70            "summarize" => Ok(Task::Summarize),
71            "translate" => Ok(Task::Translate),
72            "gen" => Ok(Task::Gen),
73            other => Err(RegistryError::UnknownTask(other.to_string())),
74        }
75    }
76}
77
/// The kind of runtime that executes a model.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BackendKind {
    /// An ONNX model run locally, in-process, via tract.
    Local,
    /// A model served by a configured remote provider.
    Remote,
}
86
/// How a registered model is wired to the backend that runs it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ModelBackend {
    /// The model is a local ONNX model.
    Local {
        /// Where the weights come from — e.g. `hf:onnx-community/finbert`, or
        /// a file / object_store URI. The local backend resolves and mmaps it.
        source: String,
        /// The classifier's intrinsic labels (`id2label`) when they are known
        /// at registration time; `None` until they are derived from the
        /// model's `config.json`.
        labels: Option<Vec<String>>,
    },
    /// The model is served by a remote provider.
    Remote {
        /// Key into the configured providers map (e.g. `anthropic`).
        provider: String,
        /// The provider's own id for the model
        /// (e.g. `claude-haiku-4-5-20251001`).
        model: String,
    },
}
107
108/// One registered model: its name, the tasks it serves, and its backend.
109#[derive(Debug, Clone, PartialEq, Eq)]
110pub struct ModelEntry {
111    /// Registry key — the name referenced as `model => '<id>'` in SQL.
112    pub id: String,
113    /// Tasks this model can perform.
114    pub tasks: Vec<Task>,
115    /// The backend that runs it.
116    pub backend: ModelBackend,
117}
118
119impl ModelEntry {
120    /// The backend kind (local vs remote).
121    #[must_use]
122    pub fn kind(&self) -> BackendKind {
123        match self.backend {
124            ModelBackend::Local { .. } => BackendKind::Local,
125            ModelBackend::Remote { .. } => BackendKind::Remote,
126        }
127    }
128
129    /// Whether this model can perform `task`.
130    #[must_use]
131    pub fn supports(&self, task: Task) -> bool {
132        self.tasks.contains(&task)
133    }
134
135    /// Deterministic results are cacheable permanently for correctness; only
136    /// local backends are deterministic.
137    #[must_use]
138    pub fn is_deterministic(&self) -> bool {
139        matches!(self.kind(), BackendKind::Local)
140    }
141
142    /// Whether calls incur metered cost (remote) versus free (local). Costed
143    /// calls are logged to `laminar.ai_calls` with tokens and cost.
144    #[must_use]
145    pub fn is_costed(&self) -> bool {
146        matches!(self.kind(), BackendKind::Remote)
147    }
148
149    /// Intrinsic labels of a local classifier, if any.
150    #[must_use]
151    pub fn labels(&self) -> Option<&[String]> {
152        match &self.backend {
153            ModelBackend::Local { labels, .. } => labels.as_deref(),
154            ModelBackend::Remote { .. } => None,
155        }
156    }
157}
158
159/// Curated map of model name → entry, plus a default model per task.
160#[derive(Debug, Default)]
161pub struct ModelRegistry {
162    models: HashMap<String, ModelEntry>,
163    defaults: HashMap<Task, String>,
164}
165
166impl ModelRegistry {
167    /// An empty registry.
168    #[must_use]
169    pub fn new() -> Self {
170        Self::default()
171    }
172
173    /// Register a model.
174    ///
175    /// # Errors
176    ///
177    /// Returns [`RegistryError::DuplicateModel`] if a model with the same id is
178    /// already registered.
179    pub fn register(&mut self, entry: ModelEntry) -> Result<(), RegistryError> {
180        if self.models.contains_key(&entry.id) {
181            return Err(RegistryError::DuplicateModel(entry.id.clone()));
182        }
183        self.models.insert(entry.id.clone(), entry);
184        Ok(())
185    }
186
187    /// Set the default model for a task (the `[ai.defaults]` config block).
188    pub fn set_default(&mut self, task: Task, model: impl Into<String>) {
189        self.defaults.insert(task, model.into());
190    }
191
192    /// Default model name for a task, if one is configured.
193    #[must_use]
194    pub fn default_for(&self, task: Task) -> Option<&str> {
195        self.defaults.get(&task).map(String::as_str)
196    }
197
198    /// Resolve a model by name.
199    ///
200    /// # Errors
201    ///
202    /// Returns [`RegistryError::UnknownModel`] if no model with that name is
203    /// registered.
204    pub fn resolve(&self, name: &str) -> Result<&ModelEntry, RegistryError> {
205        self.models
206            .get(name)
207            .ok_or_else(|| RegistryError::UnknownModel(name.to_string()))
208    }
209
210    /// Resolve a model and confirm it supports `task`. This is the plan-time
211    /// gate behind every AI function call.
212    ///
213    /// # Errors
214    ///
215    /// Returns [`RegistryError::UnknownModel`] if the model is not registered,
216    /// or [`RegistryError::TaskUnsupported`] if it cannot perform `task`.
217    pub fn validate(&self, name: &str, task: Task) -> Result<&ModelEntry, RegistryError> {
218        let entry = self.resolve(name)?;
219        if entry.supports(task) {
220            Ok(entry)
221        } else {
222            Err(RegistryError::TaskUnsupported {
223                model: name.to_string(),
224                task,
225                supported: entry.tasks.clone(),
226            })
227        }
228    }
229
230    /// Number of registered models.
231    #[must_use]
232    pub fn len(&self) -> usize {
233        self.models.len()
234    }
235
236    /// Whether no models are registered.
237    #[must_use]
238    pub fn is_empty(&self) -> bool {
239        self.models.is_empty()
240    }
241
242    /// Iterate registered entries in unspecified order. Backs `laminar.models`.
243    pub fn iter(&self) -> impl Iterator<Item = &ModelEntry> {
244        self.models.values()
245    }
246}
247
248/// Errors from resolving or validating a model. Surfaced at plan time.
249#[derive(Debug, Error, PartialEq, Eq)]
250pub enum RegistryError {
251    /// No model with the given name is registered.
252    #[error("unknown model '{0}'")]
253    UnknownModel(String),
254
255    /// The model exists but does not support the requested task.
256    #[error("model '{model}' does not support task '{task}' (supports: {supported:?})")]
257    TaskUnsupported {
258        /// The referenced model name.
259        model: String,
260        /// The task that was requested.
261        task: Task,
262        /// The tasks the model actually supports.
263        supported: Vec<Task>,
264    },
265
266    /// A model with this id is already registered.
267    #[error("model '{0}' is already registered")]
268    DuplicateModel(String),
269
270    /// A `task = "…"` string did not name a known task.
271    #[error("unknown task '{0}'")]
272    UnknownTask(String),
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a local model serving classification and sentiment, with its
    /// intrinsic labels set.
    fn local_classifier() -> ModelEntry {
        let labels = vec![
            "positive".to_owned(),
            "negative".to_owned(),
            "neutral".to_owned(),
        ];
        ModelEntry {
            id: "finbert".to_owned(),
            tasks: vec![Task::Classify, Task::Sentiment],
            backend: ModelBackend::Local {
                source: "hf:onnx-community/finbert".to_owned(),
                labels: Some(labels),
            },
        }
    }

    /// Fixture: a remote model serving classification, extraction and
    /// completion.
    fn remote_llm() -> ModelEntry {
        let backend = ModelBackend::Remote {
            provider: "anthropic".to_owned(),
            model: "claude-haiku-4-5-20251001".to_owned(),
        };
        ModelEntry {
            id: "haiku".to_owned(),
            tasks: vec![Task::Classify, Task::Extract, Task::Complete],
            backend,
        }
    }

    #[test]
    fn resolve_and_validate() {
        let mut registry = ModelRegistry::new();
        assert!(registry.is_empty());
        registry.register(local_classifier()).unwrap();
        registry.register(remote_llm()).unwrap();
        registry.set_default(Task::Classify, "finbert");
        assert_eq!(registry.len(), 2);

        // A name that was never registered does not resolve.
        let unknown = registry.resolve("missing").unwrap_err();
        assert_eq!(unknown, RegistryError::UnknownModel("missing".to_owned()));

        // A task the model declares passes validation.
        let entry = registry.validate("finbert", Task::Sentiment).unwrap();
        assert_eq!(entry.id, "finbert");

        // A task it does not declare is rejected, and the error carries the
        // set of tasks the model does support.
        let rejected = registry.validate("finbert", Task::Complete).unwrap_err();
        assert_eq!(
            rejected,
            RegistryError::TaskUnsupported {
                model: "finbert".to_owned(),
                task: Task::Complete,
                supported: vec![Task::Classify, Task::Sentiment],
            }
        );

        // Defaults are per task: only the one that was set is reported.
        assert_eq!(registry.default_for(Task::Classify), Some("finbert"));
        assert_eq!(registry.default_for(Task::Embed), None);
    }

    #[test]
    fn duplicate_registration_rejected() {
        let mut registry = ModelRegistry::new();
        registry.register(local_classifier()).unwrap();

        // Registering the same id a second time must fail.
        let second = registry.register(local_classifier());
        assert_eq!(
            second.unwrap_err(),
            RegistryError::DuplicateModel("finbert".to_owned())
        );
    }
}