1use std::collections::HashMap;
10use std::fmt;
11use std::str::FromStr;
12
13use thiserror::Error;
14
/// A kind of work a registered model can be asked to perform.
///
/// Each variant has a canonical lowercase name (see `Task::as_str`), which is
/// also the text accepted when parsing a `Task` from a string.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum Task {
    /// Canonical name `"classify"`.
    Classify,
    /// Canonical name `"sentiment"`.
    Sentiment,
    /// Canonical name `"embed"`.
    Embed,
    /// Canonical name `"extract"`.
    Extract,
    /// Canonical name `"complete"`.
    Complete,
    /// Canonical name `"summarize"`.
    Summarize,
    /// Canonical name `"translate"`.
    Translate,
    /// Canonical name `"gen"`.
    Gen,
}
36
37impl Task {
38 #[must_use]
40 pub fn as_str(self) -> &'static str {
41 match self {
42 Task::Classify => "classify",
43 Task::Sentiment => "sentiment",
44 Task::Embed => "embed",
45 Task::Extract => "extract",
46 Task::Complete => "complete",
47 Task::Summarize => "summarize",
48 Task::Translate => "translate",
49 Task::Gen => "gen",
50 }
51 }
52}
53
54impl fmt::Display for Task {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.write_str(self.as_str())
57 }
58}
59
60impl FromStr for Task {
61 type Err = RegistryError;
62
63 fn from_str(s: &str) -> Result<Self, Self::Err> {
64 match s {
65 "classify" => Ok(Task::Classify),
66 "sentiment" => Ok(Task::Sentiment),
67 "embed" => Ok(Task::Embed),
68 "extract" => Ok(Task::Extract),
69 "complete" => Ok(Task::Complete),
70 "summarize" => Ok(Task::Summarize),
71 "translate" => Ok(Task::Translate),
72 "gen" => Ok(Task::Gen),
73 other => Err(RegistryError::UnknownTask(other.to_string())),
74 }
75 }
76}
77
/// Field-less summary of which `ModelBackend` variant an entry uses.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum BackendKind {
    /// Corresponds to `ModelBackend::Local`.
    Local,
    /// Corresponds to `ModelBackend::Remote`.
    Remote,
}
86
/// How a registered model is backed, together with backend-specific settings.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum ModelBackend {
    /// A locally backed model.
    Local {
        /// Where the model comes from. The tests use a value like
        /// `hf:onnx-community/finbert`; the scheme is interpreted outside this
        /// module — NOTE(review): confirm the accepted formats.
        source: String,
        /// Optional output label names, exposed through `ModelEntry::labels`.
        labels: Option<Vec<String>>,
    },
    /// A model reached through an external provider.
    Remote {
        /// Provider name (the tests use `anthropic`).
        provider: String,
        /// Provider-specific model identifier.
        model: String,
    },
}
107
108#[derive(Debug, Clone, PartialEq, Eq)]
110pub struct ModelEntry {
111 pub id: String,
113 pub tasks: Vec<Task>,
115 pub backend: ModelBackend,
117}
118
119impl ModelEntry {
120 #[must_use]
122 pub fn kind(&self) -> BackendKind {
123 match self.backend {
124 ModelBackend::Local { .. } => BackendKind::Local,
125 ModelBackend::Remote { .. } => BackendKind::Remote,
126 }
127 }
128
129 #[must_use]
131 pub fn supports(&self, task: Task) -> bool {
132 self.tasks.contains(&task)
133 }
134
135 #[must_use]
138 pub fn is_deterministic(&self) -> bool {
139 matches!(self.kind(), BackendKind::Local)
140 }
141
142 #[must_use]
145 pub fn is_costed(&self) -> bool {
146 matches!(self.kind(), BackendKind::Remote)
147 }
148
149 #[must_use]
151 pub fn labels(&self) -> Option<&[String]> {
152 match &self.backend {
153 ModelBackend::Local { labels, .. } => labels.as_deref(),
154 ModelBackend::Remote { .. } => None,
155 }
156 }
157}
158
159#[derive(Debug, Default)]
161pub struct ModelRegistry {
162 models: HashMap<String, ModelEntry>,
163 defaults: HashMap<Task, String>,
164}
165
166impl ModelRegistry {
167 #[must_use]
169 pub fn new() -> Self {
170 Self::default()
171 }
172
173 pub fn register(&mut self, entry: ModelEntry) -> Result<(), RegistryError> {
180 if self.models.contains_key(&entry.id) {
181 return Err(RegistryError::DuplicateModel(entry.id.clone()));
182 }
183 self.models.insert(entry.id.clone(), entry);
184 Ok(())
185 }
186
187 pub fn set_default(&mut self, task: Task, model: impl Into<String>) {
189 self.defaults.insert(task, model.into());
190 }
191
192 #[must_use]
194 pub fn default_for(&self, task: Task) -> Option<&str> {
195 self.defaults.get(&task).map(String::as_str)
196 }
197
198 pub fn resolve(&self, name: &str) -> Result<&ModelEntry, RegistryError> {
205 self.models
206 .get(name)
207 .ok_or_else(|| RegistryError::UnknownModel(name.to_string()))
208 }
209
210 pub fn validate(&self, name: &str, task: Task) -> Result<&ModelEntry, RegistryError> {
218 let entry = self.resolve(name)?;
219 if entry.supports(task) {
220 Ok(entry)
221 } else {
222 Err(RegistryError::TaskUnsupported {
223 model: name.to_string(),
224 task,
225 supported: entry.tasks.clone(),
226 })
227 }
228 }
229
230 #[must_use]
232 pub fn len(&self) -> usize {
233 self.models.len()
234 }
235
236 #[must_use]
238 pub fn is_empty(&self) -> bool {
239 self.models.is_empty()
240 }
241
242 pub fn iter(&self) -> impl Iterator<Item = &ModelEntry> {
244 self.models.values()
245 }
246}
247
248#[derive(Debug, Error, PartialEq, Eq)]
250pub enum RegistryError {
251 #[error("unknown model '{0}'")]
253 UnknownModel(String),
254
255 #[error("model '{model}' does not support task '{task}' (supports: {supported:?})")]
257 TaskUnsupported {
258 model: String,
260 task: Task,
262 supported: Vec<Task>,
264 },
265
266 #[error("model '{0}' is already registered")]
268 DuplicateModel(String),
269
270 #[error("unknown task '{0}'")]
272 UnknownTask(String),
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    fn local_classifier() -> ModelEntry {
        ModelEntry {
            id: "finbert".to_string(),
            tasks: vec![Task::Classify, Task::Sentiment],
            backend: ModelBackend::Local {
                source: "hf:onnx-community/finbert".to_string(),
                labels: Some(vec![
                    "positive".to_string(),
                    "negative".to_string(),
                    "neutral".to_string(),
                ]),
            },
        }
    }

    fn remote_llm() -> ModelEntry {
        ModelEntry {
            id: "haiku".to_string(),
            tasks: vec![Task::Classify, Task::Extract, Task::Complete],
            backend: ModelBackend::Remote {
                provider: "anthropic".to_string(),
                model: "claude-haiku-4-5-20251001".to_string(),
            },
        }
    }

    #[test]
    fn resolve_and_validate() {
        let mut reg = ModelRegistry::new();
        assert!(reg.is_empty());
        reg.register(local_classifier()).unwrap();
        reg.register(remote_llm()).unwrap();
        reg.set_default(Task::Classify, "finbert");
        assert_eq!(reg.len(), 2);

        assert_eq!(
            reg.resolve("missing").unwrap_err(),
            RegistryError::UnknownModel("missing".to_string())
        );

        assert_eq!(
            reg.validate("finbert", Task::Sentiment).unwrap().id,
            "finbert"
        );

        match reg.validate("finbert", Task::Complete).unwrap_err() {
            RegistryError::TaskUnsupported {
                model,
                task,
                supported,
            } => {
                assert_eq!(model, "finbert");
                assert_eq!(task, Task::Complete);
                assert_eq!(supported, vec![Task::Classify, Task::Sentiment]);
            }
            other => panic!("unexpected error: {other}"),
        }

        assert_eq!(reg.default_for(Task::Classify), Some("finbert"));
        assert_eq!(reg.default_for(Task::Embed), None);
    }

    #[test]
    fn duplicate_registration_rejected() {
        let mut reg = ModelRegistry::new();
        reg.register(local_classifier()).unwrap();
        assert_eq!(
            reg.register(local_classifier()).unwrap_err(),
            RegistryError::DuplicateModel("finbert".to_string())
        );
    }

    /// A rejected duplicate must not overwrite the entry that was registered first.
    #[test]
    fn duplicate_registration_keeps_first_entry() {
        let mut reg = ModelRegistry::new();
        reg.register(local_classifier()).unwrap();
        let mut impostor = remote_llm();
        impostor.id = "finbert".to_string();
        assert!(reg.register(impostor).is_err());
        assert_eq!(reg.len(), 1);
        assert_eq!(reg.resolve("finbert").unwrap().kind(), BackendKind::Local);
    }

    /// `as_str`, `FromStr` and `Display` keep separate name tables; this
    /// guards them against drifting apart.
    #[test]
    fn task_names_round_trip() {
        let all = [
            Task::Classify,
            Task::Sentiment,
            Task::Embed,
            Task::Extract,
            Task::Complete,
            Task::Summarize,
            Task::Translate,
            Task::Gen,
        ];
        for task in all.iter().copied() {
            assert_eq!(task.as_str().parse::<Task>(), Ok(task));
            assert_eq!(task.to_string(), task.as_str());
        }
        assert_eq!(
            "bogus".parse::<Task>(),
            Err(RegistryError::UnknownTask("bogus".to_string()))
        );
    }

    #[test]
    fn backend_accessors() {
        let local = local_classifier();
        assert_eq!(local.kind(), BackendKind::Local);
        assert!(local.is_deterministic());
        assert!(!local.is_costed());
        assert_eq!(local.labels().map(|names| names.len()), Some(3));
        assert_eq!(
            local.labels().and_then(|names| names.first()).map(String::as_str),
            Some("positive")
        );

        let remote = remote_llm();
        assert_eq!(remote.kind(), BackendKind::Remote);
        assert!(!remote.is_deterministic());
        assert!(remote.is_costed());
        assert_eq!(remote.labels(), None);
    }

    #[test]
    fn iter_visits_every_entry() {
        let mut reg = ModelRegistry::new();
        reg.register(local_classifier()).unwrap();
        reg.register(remote_llm()).unwrap();
        // HashMap order is unspecified, so sort before comparing.
        let mut ids: Vec<&str> = reg.iter().map(|entry| entry.id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["finbert", "haiku"]);
    }
}