1use thiserror::Error;
9
10use crate::provider::InferenceOutputs;
11use crate::registry::{BackendKind, Task};
12
/// Errors returned while adapting raw backend outputs into the shape a task expects.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum AdapterError {
    /// A logits-based path (local classification or local sentiment) was given
    /// `labels == None`, so logit indices cannot be mapped to label names.
    #[error("classification requires the model's labels but none were provided")]
    MissingLabels,

    /// The argmax of a logits row is not a valid index into the supplied labels.
    #[error("classifier chose index {index} but only {len} labels are defined")]
    LabelIndexOutOfRange {
        /// Index selected by argmax.
        index: usize,
        /// Number of labels that were supplied.
        len: usize,
    },

    /// A row of logits was empty, so no class can be selected from it.
    #[error("classifier returned no logits for a row")]
    EmptyLogits,

    /// Local sentiment scoring needs labels named "positive" and "negative"
    /// (ASCII case-insensitive); at least one of them was missing.
    #[error(
        "sentiment scoring requires 'positive' and 'negative' among the labels, got {labels:?}"
    )]
    SentimentLabelsUnusable {
        /// The full label list that was supplied.
        labels: Vec<String>,
    },

    /// No number could be extracted from a remote sentiment reply.
    #[error("remote sentiment reply had no parseable score: {reply:?}")]
    UnparseableScore {
        /// The raw reply text; it is echoed in the error message.
        reply: String,
    },

    /// The backend produced a different `InferenceOutputs` variant than the task needs.
    #[error("task '{task}' on the {kind:?} backend expected {expected} output, got {got}")]
    UnexpectedOutputShape {
        /// Task being adapted.
        task: Task,
        /// Backend that produced the output.
        kind: BackendKind,
        /// Shape name the task required ("text", "vectors" or "scores").
        expected: &'static str,
        /// Shape name actually received ("text", "vectors" or "scores").
        got: &'static str,
    },

    /// The task cannot be served by this backend. In this module that is the
    /// text-producing tasks (`Complete`, `Summarize`, `Translate`, `Gen`,
    /// `Extract`) requested on a non-remote backend.
    #[error("task '{task}' is not supported on the {kind:?} backend")]
    UnsupportedCombination {
        /// Requested task.
        task: Task,
        /// Backend it was requested on.
        kind: BackendKind,
    },
}
73
74pub fn parse_response(
85 task: Task,
86 kind: BackendKind,
87 raw: InferenceOutputs,
88 labels: Option<&[String]>,
89) -> Result<InferenceOutputs, AdapterError> {
90 match task {
91 Task::Classify => match kind {
92 BackendKind::Local => classify_from_logits(kind, raw, labels),
93 BackendKind::Remote => coerce_remote_classification(raw, labels),
94 },
95 Task::Sentiment => match kind {
96 BackendKind::Local => sentiment_from_logits(raw, labels),
97 BackendKind::Remote => sentiment_from_text(raw),
98 },
99 Task::Embed => expect_vectors(task, kind, raw),
100 Task::Complete | Task::Summarize | Task::Translate | Task::Gen | Task::Extract => {
101 if kind != BackendKind::Remote {
102 return Err(AdapterError::UnsupportedCombination { task, kind });
105 }
106 expect_text(task, kind, raw)
107 }
108 }
109}
110
111fn classify_from_logits(
113 kind: BackendKind,
114 raw: InferenceOutputs,
115 labels: Option<&[String]>,
116) -> Result<InferenceOutputs, AdapterError> {
117 let rows = match raw {
118 InferenceOutputs::Vectors(rows) => rows,
119 other => {
120 return Err(AdapterError::UnexpectedOutputShape {
121 task: Task::Classify,
122 kind,
123 expected: "vectors",
124 got: shape_name(&other),
125 });
126 }
127 };
128 let labels = labels.ok_or(AdapterError::MissingLabels)?;
129 let mut out = Vec::with_capacity(rows.len());
130 for logits in rows {
131 let index = argmax(&logits).ok_or(AdapterError::EmptyLogits)?;
132 let label = labels
133 .get(index)
134 .ok_or(AdapterError::LabelIndexOutOfRange {
135 index,
136 len: labels.len(),
137 })?;
138 out.push(label.clone());
139 }
140 Ok(InferenceOutputs::Text(out))
141}
142
143fn coerce_remote_classification(
150 raw: InferenceOutputs,
151 labels: Option<&[String]>,
152) -> Result<InferenceOutputs, AdapterError> {
153 let texts = match raw {
154 InferenceOutputs::Text(texts) => texts,
155 other => {
156 return Err(AdapterError::UnexpectedOutputShape {
157 task: Task::Classify,
158 kind: BackendKind::Remote,
159 expected: "text",
160 got: shape_name(&other),
161 });
162 }
163 };
164 let coerced = texts
165 .into_iter()
166 .map(|text| match labels {
167 Some(labels) => coerce_label(&text, labels),
168 None => text.trim().to_string(),
169 })
170 .collect();
171 Ok(InferenceOutputs::Text(coerced))
172}
173
/// Maps a free-form reply onto one of `labels`, falling back to the trimmed reply.
///
/// Matching is ASCII case-insensitive and runs in two passes:
/// 1. the whole trimmed reply equals a label;
/// 2. otherwise, the *longest* non-empty label contained in the reply.
///
/// Preferring the longest label means "very positive" is not reported as
/// "positive" when both are defined, regardless of label order. Among equally
/// long matches the earliest label wins. Empty labels are skipped in pass 2
/// because they are a substring of every reply.
///
/// Returns the canonical (original-case) label on a match, otherwise the
/// trimmed reply unchanged.
fn coerce_label(text: &str, labels: &[String]) -> String {
    let trimmed = text.trim();
    if let Some(label) = labels.iter().find(|l| l.eq_ignore_ascii_case(trimmed)) {
        return label.clone();
    }
    let lower = trimmed.to_ascii_lowercase();
    let mut best: Option<&String> = None;
    for label in labels {
        if label.is_empty() || !lower.contains(&label.to_ascii_lowercase()) {
            continue;
        }
        // Strictly longer only, so ties keep the earliest label.
        let longer = match best {
            None => true,
            Some(current) => label.len() > current.len(),
        };
        if longer {
            best = Some(label);
        }
    }
    match best {
        Some(label) => label.clone(),
        None => trimmed.to_string(),
    }
}
189
190fn expect_text(
192 task: Task,
193 kind: BackendKind,
194 raw: InferenceOutputs,
195) -> Result<InferenceOutputs, AdapterError> {
196 match raw {
197 InferenceOutputs::Text(text) => Ok(InferenceOutputs::Text(text)),
198 other => Err(AdapterError::UnexpectedOutputShape {
199 task,
200 kind,
201 expected: "text",
202 got: shape_name(&other),
203 }),
204 }
205}
206
207fn expect_vectors(
209 task: Task,
210 kind: BackendKind,
211 raw: InferenceOutputs,
212) -> Result<InferenceOutputs, AdapterError> {
213 match raw {
214 InferenceOutputs::Vectors(vectors) => Ok(InferenceOutputs::Vectors(vectors)),
215 other => Err(AdapterError::UnexpectedOutputShape {
216 task,
217 kind,
218 expected: "vectors",
219 got: shape_name(&other),
220 }),
221 }
222}
223
224fn shape_name(raw: &InferenceOutputs) -> &'static str {
226 match raw {
227 InferenceOutputs::Text(_) => "text",
228 InferenceOutputs::Vectors(_) => "vectors",
229 InferenceOutputs::Scores(_) => "scores",
230 }
231}
232
233fn sentiment_from_logits(
237 raw: InferenceOutputs,
238 labels: Option<&[String]>,
239) -> Result<InferenceOutputs, AdapterError> {
240 let rows = match raw {
241 InferenceOutputs::Vectors(rows) => rows,
242 other => {
243 return Err(AdapterError::UnexpectedOutputShape {
244 task: Task::Sentiment,
245 kind: BackendKind::Local,
246 expected: "vectors",
247 got: shape_name(&other),
248 });
249 }
250 };
251 let labels = labels.ok_or(AdapterError::MissingLabels)?;
252 let pos = labels
253 .iter()
254 .position(|l| l.eq_ignore_ascii_case("positive"));
255 let neg = labels
256 .iter()
257 .position(|l| l.eq_ignore_ascii_case("negative"));
258 let (Some(pos), Some(neg)) = (pos, neg) else {
259 return Err(AdapterError::SentimentLabelsUnusable {
260 labels: labels.to_vec(),
261 });
262 };
263 let mut scores = Vec::with_capacity(rows.len());
264 for logits in rows {
265 if logits.is_empty() {
266 return Err(AdapterError::EmptyLogits);
267 }
268 let probs = softmax(&logits);
269 let p_pos = probs.get(pos).copied().unwrap_or(0.0);
270 let p_neg = probs.get(neg).copied().unwrap_or(0.0);
271 scores.push(f64::from(p_pos - p_neg));
272 }
273 Ok(InferenceOutputs::Scores(scores))
274}
275
276fn sentiment_from_text(raw: InferenceOutputs) -> Result<InferenceOutputs, AdapterError> {
278 let texts = match raw {
279 InferenceOutputs::Text(texts) => texts,
280 other => {
281 return Err(AdapterError::UnexpectedOutputShape {
282 task: Task::Sentiment,
283 kind: BackendKind::Remote,
284 expected: "text",
285 got: shape_name(&other),
286 });
287 }
288 };
289 let mut scores = Vec::with_capacity(texts.len());
290 for reply in texts {
291 let value = parse_score(&reply).ok_or(AdapterError::UnparseableScore { reply })?;
292 scores.push(value.clamp(-1.0, 1.0));
293 }
294 Ok(InferenceOutputs::Scores(scores))
295}
296
/// Numerically stable softmax: subtracts the maximum logit before
/// exponentiating, then normalises. Returns an empty vector for empty input.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits
        .iter()
        .fold(f32::NEG_INFINITY, |acc, &value| acc.max(value));
    let mut weights: Vec<f32> = logits.iter().map(|value| (value - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    // Only skipped when there is nothing to normalise (empty input) or the sum is NaN.
    if total > 0.0 {
        weights.iter_mut().for_each(|w| *w /= total);
    }
    weights
}
309
/// Extracts a numeric score from a free-form remote reply.
///
/// First tries the whole trimmed reply as a number. Otherwise it scans
/// whitespace-separated tokens and returns the first one that parses after
/// stripping surrounding punctuation (e.g. `"(0.8)"`, `"-0.5."`, `".5"`).
///
/// NaN is never returned. `str::parse::<f64>` accepts the literal `"NaN"`, and
/// such a value would otherwise survive the caller's clamp as a NaN score.
fn parse_score(reply: &str) -> Option<f64> {
    let trimmed = reply.trim();
    if let Some(v) = trimmed.parse::<f64>().ok().filter(|v| !v.is_nan()) {
        return Some(v);
    }
    trimmed.split_whitespace().find_map(|token| {
        token
            // Keep '-' (sign) and '.' (a leading decimal point as in ".5").
            // Dropping the '.' would turn ".5" into 5.
            .trim_start_matches(|c: char| !(c.is_ascii_digit() || c == '-' || c == '.'))
            .trim_end_matches(|c: char| !c.is_ascii_digit())
            .parse::<f64>()
            .ok()
    })
}
327
/// Returns the index of the largest value in `values`, or `None` when it is empty.
///
/// Ties resolve to the earliest index. NaN entries never beat a real number,
/// so the result does not depend on where a NaN sits in the row. If every
/// entry is NaN, the first index is returned, so `None` still means "empty".
fn argmax(values: &[f32]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (index, &value) in values.iter().enumerate() {
        let better = match best {
            None => true,
            // Any real number replaces a NaN placeholder.
            Some((_, current)) if current.is_nan() => !value.is_nan(),
            // `>` is false for a NaN `value` and for ties, so both keep the incumbent.
            Some((_, current)) => value > current,
        };
        if better {
            best = Some((index, value));
        }
    }
    best.map(|(index, _)| index)
}
339
#[cfg(test)]
mod tests {
    use super::*;

    fn labels() -> Vec<String> {
        vec!["negative".into(), "positive".into(), "neutral".into()]
    }

    #[test]
    fn local_classify_argmaxes_logits_to_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.1, 0.9, 0.2], vec![5.0, 0.0, 0.0]]);
        let out = parse_response(Task::Classify, BackendKind::Local, raw, Some(&labels())).unwrap();
        assert_eq!(
            out,
            InferenceOutputs::Text(vec!["positive".into(), "negative".into()])
        );
    }

    // Previously this test exercised Task::Sentiment, leaving the classify
    // MissingLabels path untested despite the test's name.
    #[test]
    fn local_classify_requires_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.1, 0.9]]);
        assert_eq!(
            parse_response(Task::Classify, BackendKind::Local, raw, None),
            Err(AdapterError::MissingLabels)
        );
    }

    #[test]
    fn local_sentiment_requires_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.1, 0.9]]);
        assert_eq!(
            parse_response(Task::Sentiment, BackendKind::Local, raw, None),
            Err(AdapterError::MissingLabels)
        );
    }

    #[test]
    fn local_classify_rejects_index_past_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.0, 0.0, 5.0]]);
        let two = vec!["negative".to_string(), "positive".to_string()];
        assert_eq!(
            parse_response(Task::Classify, BackendKind::Local, raw, Some(&two)),
            Err(AdapterError::LabelIndexOutOfRange { index: 2, len: 2 })
        );
    }

    #[test]
    fn local_classify_rejects_empty_logits_row() {
        let raw = InferenceOutputs::Vectors(vec![vec![]]);
        assert_eq!(
            parse_response(Task::Classify, BackendKind::Local, raw, Some(&labels())),
            Err(AdapterError::EmptyLogits)
        );
    }

    #[test]
    fn remote_classify_coerces_to_canonical_label() {
        let raw = InferenceOutputs::Text(vec![
            "The sentiment is Positive.".into(),
            "NEGATIVE".into(),
            "totally unrelated".into(),
        ]);
        let out =
            parse_response(Task::Classify, BackendKind::Remote, raw, Some(&labels())).unwrap();
        assert_eq!(
            out,
            InferenceOutputs::Text(vec![
                "positive".into(),
                "negative".into(),
                "totally unrelated".into(),
            ])
        );
    }

    #[test]
    fn local_sentiment_softmaxes_logits_to_a_signed_score() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.0, 2.0, 0.0], vec![3.0, 0.0, 0.0]]);
        let out =
            parse_response(Task::Sentiment, BackendKind::Local, raw, Some(&labels())).unwrap();
        let InferenceOutputs::Scores(scores) = out else {
            panic!("sentiment is numeric");
        };
        // Row 0 softmax: [1, e^2, 1] / (2 + e^2); score = P(positive) - P(negative).
        let denom = 2.0 + std::f32::consts::E.powi(2);
        let expected0 = f64::from((std::f32::consts::E.powi(2) - 1.0) / denom);
        assert!((scores[0] - expected0).abs() < 1e-6, "got {}", scores[0]);
        assert!(scores[0] > 0.0 && scores[1] < 0.0);
        assert!((-1.0..=1.0).contains(&scores[0]) && (-1.0..=1.0).contains(&scores[1]));
    }

    #[test]
    fn local_sentiment_needs_positive_and_negative_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.1, 0.9]]);
        let stars = vec!["one_star".to_string(), "five_star".to_string()];
        assert!(matches!(
            parse_response(Task::Sentiment, BackendKind::Local, raw, Some(&stars)),
            Err(AdapterError::SentimentLabelsUnusable { .. })
        ));
    }

    #[test]
    fn remote_sentiment_parses_and_clamps_a_number() {
        let raw = InferenceOutputs::Text(vec![
            "0.8".into(),
            "The sentiment is -0.5.".into(),
            "2.0".into(),
        ]);
        let out = parse_response(Task::Sentiment, BackendKind::Remote, raw, None).unwrap();
        let InferenceOutputs::Scores(scores) = out else {
            panic!("sentiment is numeric");
        };
        assert!((scores[0] - 0.8).abs() < 1e-12);
        assert!((scores[1] + 0.5).abs() < 1e-12);
        assert!((scores[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn embed_rejects_text_output() {
        let raw = InferenceOutputs::Text(vec!["hello".into()]);
        assert_eq!(
            parse_response(Task::Embed, BackendKind::Remote, raw, None),
            Err(AdapterError::UnexpectedOutputShape {
                task: Task::Embed,
                kind: BackendKind::Remote,
                expected: "vectors",
                got: "text",
            })
        );
    }

    #[test]
    fn generation_is_remote_only() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.0]]);
        assert_eq!(
            parse_response(Task::Complete, BackendKind::Local, raw, None),
            Err(AdapterError::UnsupportedCombination {
                task: Task::Complete,
                kind: BackendKind::Local,
            })
        );
    }
}