Skip to main content

laminar_ai/
adapter.rs

//! Resolve a raw provider response into a task's per-row output, given the
//! `(task, backend)` pair. Local classification takes the argmax over the
//! logits (which equals the argmax of the softmax; classify returns only the
//! label, so no softmax is computed). Remote classification coerces the reply
//! text toward the candidate labels. Sentiment is numeric: local softmaxes the
//! logits to `P(pos) − P(neg)`, remote parses a number from the reply — either
//! way a continuous score in `[-1, 1]`. Stays Arrow-free: returns
//! [`InferenceOutputs`]; the operator builds the column.
7
8use thiserror::Error;
9
10use crate::provider::InferenceOutputs;
11use crate::registry::{BackendKind, Task};
12
13/// Errors from adapting a response to a task output.
14#[derive(Debug, Error, PartialEq, Eq)]
15pub enum AdapterError {
16    /// A local classifier produced logits but no labels were supplied to map
17    /// them to.
18    #[error("classification requires the model's labels but none were provided")]
19    MissingLabels,
20
21    /// argmax landed on an index with no corresponding label.
22    #[error("classifier chose index {index} but only {len} labels are defined")]
23    LabelIndexOutOfRange {
24        /// The chosen index.
25        index: usize,
26        /// The number of labels available.
27        len: usize,
28    },
29
30    /// A classifier returned an empty logit vector for a row.
31    #[error("classifier returned no logits for a row")]
32    EmptyLogits,
33
34    /// Sentiment scoring needs both a `positive` and a `negative` label among
35    /// the model's labels to map logits to a signed score.
36    #[error(
37        "sentiment scoring requires 'positive' and 'negative' among the labels, got {labels:?}"
38    )]
39    SentimentLabelsUnusable {
40        /// The labels that were available.
41        labels: Vec<String>,
42    },
43
44    /// A remote sentiment reply contained no parseable number.
45    #[error("remote sentiment reply had no parseable score: {reply:?}")]
46    UnparseableScore {
47        /// The reply that could not be parsed.
48        reply: String,
49    },
50
51    /// The raw output shape did not match what the task/backend produces.
52    #[error("task '{task}' on the {kind:?} backend expected {expected} output, got {got}")]
53    UnexpectedOutputShape {
54        /// The task.
55        task: Task,
56        /// The backend kind.
57        kind: BackendKind,
58        /// The shape that was expected (`text` or `vectors`).
59        expected: &'static str,
60        /// The shape actually produced.
61        got: &'static str,
62    },
63
64    /// The task cannot run on this backend kind (e.g. generation on local).
65    #[error("task '{task}' is not supported on the {kind:?} backend")]
66    UnsupportedCombination {
67        /// The task.
68        task: Task,
69        /// The backend kind.
70        kind: BackendKind,
71    },
72}
73
74/// Resolve a raw provider response into the task's per-row output column.
75///
76/// `labels` carries the candidate/intrinsic label set for classification — the
77/// model's `id2label` for a local classifier (required), ignored otherwise.
78///
79/// # Errors
80///
81/// Returns [`AdapterError`] if the output shape is wrong for the task, a local
82/// classifier has no labels or argmaxes out of range, or the task cannot run on
83/// the given backend kind.
84pub fn parse_response(
85    task: Task,
86    kind: BackendKind,
87    raw: InferenceOutputs,
88    labels: Option<&[String]>,
89) -> Result<InferenceOutputs, AdapterError> {
90    match task {
91        Task::Classify => match kind {
92            BackendKind::Local => classify_from_logits(kind, raw, labels),
93            BackendKind::Remote => coerce_remote_classification(raw, labels),
94        },
95        Task::Sentiment => match kind {
96            BackendKind::Local => sentiment_from_logits(raw, labels),
97            BackendKind::Remote => sentiment_from_text(raw),
98        },
99        Task::Embed => expect_vectors(task, kind, raw),
100        Task::Complete | Task::Summarize | Task::Translate | Task::Gen | Task::Extract => {
101            if kind != BackendKind::Remote {
102                // Narrow ONNX token-classification extraction is out of v0.1
103                // scope; generation is remote by definition.
104                return Err(AdapterError::UnsupportedCombination { task, kind });
105            }
106            expect_text(task, kind, raw)
107        }
108    }
109}
110
111/// argmax over each row's logits, mapped to the model's labels.
112fn classify_from_logits(
113    kind: BackendKind,
114    raw: InferenceOutputs,
115    labels: Option<&[String]>,
116) -> Result<InferenceOutputs, AdapterError> {
117    let rows = match raw {
118        InferenceOutputs::Vectors(rows) => rows,
119        other => {
120            return Err(AdapterError::UnexpectedOutputShape {
121                task: Task::Classify,
122                kind,
123                expected: "vectors",
124                got: shape_name(&other),
125            });
126        }
127    };
128    let labels = labels.ok_or(AdapterError::MissingLabels)?;
129    let mut out = Vec::with_capacity(rows.len());
130    for logits in rows {
131        let index = argmax(&logits).ok_or(AdapterError::EmptyLogits)?;
132        let label = labels
133            .get(index)
134            .ok_or(AdapterError::LabelIndexOutOfRange {
135                index,
136                len: labels.len(),
137            })?;
138        out.push(label.clone());
139    }
140    Ok(InferenceOutputs::Text(out))
141}
142
143/// Coerce a remote model's free text toward the candidate label set.
144///
145/// LLMs sometimes wrap the answer ("The sentiment is Positive."). When labels
146/// are known, normalize each reply to a canonical label by exact (case-
147/// insensitive) match, then by containment; otherwise fall back to the trimmed
148/// raw text rather than dropping the row.
149fn coerce_remote_classification(
150    raw: InferenceOutputs,
151    labels: Option<&[String]>,
152) -> Result<InferenceOutputs, AdapterError> {
153    let texts = match raw {
154        InferenceOutputs::Text(texts) => texts,
155        other => {
156            return Err(AdapterError::UnexpectedOutputShape {
157                task: Task::Classify,
158                kind: BackendKind::Remote,
159                expected: "text",
160                got: shape_name(&other),
161            });
162        }
163    };
164    let coerced = texts
165        .into_iter()
166        .map(|text| match labels {
167            Some(labels) => coerce_label(&text, labels),
168            None => text.trim().to_string(),
169        })
170        .collect();
171    Ok(InferenceOutputs::Text(coerced))
172}
173
/// Map free text to a canonical label, or the trimmed text if none matches.
///
/// Matching is Unicode case-insensitive: both sides go through
/// `str::to_lowercase`. An exact match wins. Otherwise the longest label
/// contained in the reply is chosen, taking the first in `labels` order on
/// ties, so `"unhappy"` beats `"happy"` for "The customer is unhappy.".
/// Empty labels never match by containment, because every string contains "".
fn coerce_label(text: &str, labels: &[String]) -> String {
    let trimmed = text.trim();
    let needle = trimmed.to_lowercase();
    // Case-fold every label once; both passes below compare folded forms.
    let folded: Vec<(&String, String)> = labels.iter().map(|l| (l, l.to_lowercase())).collect();

    if let Some((label, _)) = folded.iter().find(|(_, f)| *f == needle) {
        return (*label).clone();
    }

    folded
        .iter()
        .filter(|(_, f)| !f.is_empty() && needle.contains(f.as_str()))
        // `min_by_key` keeps the first of equal keys, so `Reverse(len)` picks
        // the longest label and, among equals, the earliest one.
        .min_by_key(|(_, f)| std::cmp::Reverse(f.len()))
        .map_or_else(|| trimmed.to_string(), |(label, _)| (*label).clone())
}
189
190/// Pass text outputs through, rejecting a vector shape.
191fn expect_text(
192    task: Task,
193    kind: BackendKind,
194    raw: InferenceOutputs,
195) -> Result<InferenceOutputs, AdapterError> {
196    match raw {
197        InferenceOutputs::Text(text) => Ok(InferenceOutputs::Text(text)),
198        other => Err(AdapterError::UnexpectedOutputShape {
199            task,
200            kind,
201            expected: "text",
202            got: shape_name(&other),
203        }),
204    }
205}
206
207/// Pass vector outputs through, rejecting a text shape.
208fn expect_vectors(
209    task: Task,
210    kind: BackendKind,
211    raw: InferenceOutputs,
212) -> Result<InferenceOutputs, AdapterError> {
213    match raw {
214        InferenceOutputs::Vectors(vectors) => Ok(InferenceOutputs::Vectors(vectors)),
215        other => Err(AdapterError::UnexpectedOutputShape {
216            task,
217            kind,
218            expected: "vectors",
219            got: shape_name(&other),
220        }),
221    }
222}
223
224/// Static name of an output shape, for error messages.
225fn shape_name(raw: &InferenceOutputs) -> &'static str {
226    match raw {
227        InferenceOutputs::Text(_) => "text",
228        InferenceOutputs::Vectors(_) => "vectors",
229        InferenceOutputs::Scores(_) => "scores",
230    }
231}
232
233/// Local sentiment: softmax each row's logits, then `P(positive) − P(negative)`,
234/// a signed score in `[-1, 1]`. The model's labels locate the positive and
235/// negative classes; a neutral class (if present) pulls the score toward 0.
236fn sentiment_from_logits(
237    raw: InferenceOutputs,
238    labels: Option<&[String]>,
239) -> Result<InferenceOutputs, AdapterError> {
240    let rows = match raw {
241        InferenceOutputs::Vectors(rows) => rows,
242        other => {
243            return Err(AdapterError::UnexpectedOutputShape {
244                task: Task::Sentiment,
245                kind: BackendKind::Local,
246                expected: "vectors",
247                got: shape_name(&other),
248            });
249        }
250    };
251    let labels = labels.ok_or(AdapterError::MissingLabels)?;
252    let pos = labels
253        .iter()
254        .position(|l| l.eq_ignore_ascii_case("positive"));
255    let neg = labels
256        .iter()
257        .position(|l| l.eq_ignore_ascii_case("negative"));
258    let (Some(pos), Some(neg)) = (pos, neg) else {
259        return Err(AdapterError::SentimentLabelsUnusable {
260            labels: labels.to_vec(),
261        });
262    };
263    let mut scores = Vec::with_capacity(rows.len());
264    for logits in rows {
265        if logits.is_empty() {
266            return Err(AdapterError::EmptyLogits);
267        }
268        let probs = softmax(&logits);
269        let p_pos = probs.get(pos).copied().unwrap_or(0.0);
270        let p_neg = probs.get(neg).copied().unwrap_or(0.0);
271        scores.push(f64::from(p_pos - p_neg));
272    }
273    Ok(InferenceOutputs::Scores(scores))
274}
275
276/// Remote sentiment: parse a number from each reply and clamp to `[-1, 1]`.
277fn sentiment_from_text(raw: InferenceOutputs) -> Result<InferenceOutputs, AdapterError> {
278    let texts = match raw {
279        InferenceOutputs::Text(texts) => texts,
280        other => {
281            return Err(AdapterError::UnexpectedOutputShape {
282                task: Task::Sentiment,
283                kind: BackendKind::Remote,
284                expected: "text",
285                got: shape_name(&other),
286            });
287        }
288    };
289    let mut scores = Vec::with_capacity(texts.len());
290    for reply in texts {
291        let value = parse_score(&reply).ok_or(AdapterError::UnparseableScore { reply })?;
292        scores.push(value.clamp(-1.0, 1.0));
293    }
294    Ok(InferenceOutputs::Scores(scores))
295}
296
/// Softmax of one row of logits, shifted by the row maximum so that `exp`
/// cannot overflow for finite inputs.
///
/// An empty row yields an empty vector. If the exponentials do not sum to a
/// positive number (e.g. a NaN logit), they are returned unnormalized.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let shifted: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = shifted.iter().sum();
    if total > 0.0 {
        shifted.into_iter().map(|e| e / total).collect()
    } else {
        shifted
    }
}
309
/// Parse the first number out of a reply (`"0.8"`, `"The sentiment is -0.5."`,
/// `"Score: .5"`), or `None` if there is none.
///
/// A reply that is exactly one number is parsed whole, so exponents and `inf`
/// work. `NaN` is rejected, because it would pass through the caller's `clamp`
/// unchanged. Otherwise the first whitespace-separated token that holds a
/// number wins (see [`score_in_token`]).
fn parse_score(reply: &str) -> Option<f64> {
    let trimmed = reply.trim();
    if let Ok(v) = trimmed.parse::<f64>() {
        if !v.is_nan() {
            return Some(v);
        }
    }
    trimmed.split_whitespace().find_map(score_in_token)
}

/// The number inside one whitespace-free token, if any.
///
/// The number starts at the first digit or `-` and ends at the last digit, so
/// a sentence-final `.` or closing bracket is dropped. A `.` directly before
/// the first digit is kept when that forms a valid number (`".5"` → 0.5).
fn score_in_token(token: &str) -> Option<f64> {
    let start = token.find(|c: char| c.is_ascii_digit() || c == '-')?;
    let end = token.rfind(|c: char| c.is_ascii_digit())? + 1;
    let core = &token[start..end];
    // Without this, ".5" would be read as 5 and clamped to 1. If the dotted
    // form does not parse (e.g. "...0.5"), fall back to the dot-less core.
    if token[..start].ends_with('.') && core.starts_with(|c: char| c.is_ascii_digit()) {
        if let Ok(v) = token[start - 1..end].parse::<f64>() {
            return Some(v);
        }
    }
    core.parse::<f64>().ok()
}
327
/// Index of the maximum value, or `None` for an empty slice.
///
/// Ties go to the first occurrence. A NaN never beats a real number; it is
/// only chosen when every value is NaN, in which case index 0 is returned. A
/// stray NaN logit therefore cannot hijack the result.
fn argmax(values: &[f32]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (index, &value) in values.iter().enumerate() {
        let better = match best {
            None => true,
            // Strict `>` keeps the first of equal maxima. The NaN clause lets
            // a real number displace a NaN held from an earlier slot; the bare
            // `!(value <= current)` test would instead let a NaN (and whatever
            // follows it) win.
            Some((_, current)) => value > current || (current.is_nan() && !value.is_nan()),
        };
        if better {
            best = Some((index, value));
        }
    }
    best.map(|(index, _)| index)
}
339
#[cfg(test)]
mod tests {
    use super::*;

    fn labels() -> Vec<String> {
        vec!["negative".into(), "positive".into(), "neutral".into()]
    }

    #[test]
    fn local_classify_argmaxes_logits_to_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.1, 0.9, 0.2], vec![5.0, 0.0, 0.0]]);
        let out = parse_response(Task::Classify, BackendKind::Local, raw, Some(&labels())).unwrap();
        assert_eq!(
            out,
            InferenceOutputs::Text(vec!["positive".into(), "negative".into()])
        );
    }

    #[test]
    fn local_classify_requires_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.1, 0.9]]);
        assert_eq!(
            parse_response(Task::Classify, BackendKind::Local, raw, None),
            Err(AdapterError::MissingLabels)
        );
    }

    #[test]
    fn local_sentiment_requires_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.1, 0.9]]);
        assert_eq!(
            parse_response(Task::Sentiment, BackendKind::Local, raw, None),
            Err(AdapterError::MissingLabels)
        );
    }

    #[test]
    fn local_classify_rejects_empty_logits() {
        let raw = InferenceOutputs::Vectors(vec![vec![]]);
        assert_eq!(
            parse_response(Task::Classify, BackendKind::Local, raw, Some(&labels())),
            Err(AdapterError::EmptyLogits)
        );
    }

    #[test]
    fn local_classify_rejects_index_past_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.0, 1.0, 5.0]]);
        let two = vec!["a".to_string(), "b".to_string()];
        assert_eq!(
            parse_response(Task::Classify, BackendKind::Local, raw, Some(&two)),
            Err(AdapterError::LabelIndexOutOfRange { index: 2, len: 2 })
        );
    }

    #[test]
    fn remote_classify_coerces_to_canonical_label() {
        // Exact (case-insensitive), containment in a wrapped reply, and the
        // no-match fallback to trimmed raw text.
        let raw = InferenceOutputs::Text(vec![
            "The sentiment is Positive.".into(),
            "NEGATIVE".into(),
            "totally unrelated".into(),
        ]);
        let out =
            parse_response(Task::Classify, BackendKind::Remote, raw, Some(&labels())).unwrap();
        assert_eq!(
            out,
            InferenceOutputs::Text(vec![
                "positive".into(),
                "negative".into(),
                "totally unrelated".into(),
            ])
        );
    }

    #[test]
    fn local_sentiment_softmaxes_logits_to_a_signed_score() {
        // labels = [negative, positive, neutral]. Row 1 favours positive, row 2
        // favours negative. Score = P(pos) − P(neg) ∈ [-1, 1].
        let raw = InferenceOutputs::Vectors(vec![vec![0.0, 2.0, 0.0], vec![3.0, 0.0, 0.0]]);
        let out =
            parse_response(Task::Sentiment, BackendKind::Local, raw, Some(&labels())).unwrap();
        let InferenceOutputs::Scores(scores) = out else {
            panic!("sentiment is numeric");
        };
        // Row 1: softmax([0,2,0]) → P(neg)=P(neu)=e^0/(2+e^2), P(pos)=e^2/(2+e^2).
        let denom = 2.0 + std::f32::consts::E.powi(2);
        let expected0 = f64::from((std::f32::consts::E.powi(2) - 1.0) / denom);
        assert!((scores[0] - expected0).abs() < 1e-6, "got {}", scores[0]);
        assert!(scores[0] > 0.0 && scores[1] < 0.0);
        assert!((-1.0..=1.0).contains(&scores[0]) && (-1.0..=1.0).contains(&scores[1]));
    }

    #[test]
    fn local_sentiment_needs_positive_and_negative_labels() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.1, 0.9]]);
        let stars = vec!["one_star".to_string(), "five_star".to_string()];
        assert!(matches!(
            parse_response(Task::Sentiment, BackendKind::Local, raw, Some(&stars)),
            Err(AdapterError::SentimentLabelsUnusable { .. })
        ));
    }

    #[test]
    fn remote_sentiment_parses_and_clamps_a_number() {
        let raw = InferenceOutputs::Text(vec![
            "0.8".into(),
            "The sentiment is -0.5.".into(),
            "2.0".into(), // out of range → clamped
        ]);
        let out = parse_response(Task::Sentiment, BackendKind::Remote, raw, None).unwrap();
        let InferenceOutputs::Scores(scores) = out else {
            panic!("sentiment is numeric");
        };
        assert!((scores[0] - 0.8).abs() < 1e-12);
        assert!((scores[1] + 0.5).abs() < 1e-12);
        assert!((scores[2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn remote_sentiment_rejects_a_reply_without_a_number() {
        let raw = InferenceOutputs::Text(vec!["no idea".into()]);
        assert_eq!(
            parse_response(Task::Sentiment, BackendKind::Remote, raw, None),
            Err(AdapterError::UnparseableScore {
                reply: "no idea".into(),
            })
        );
    }

    #[test]
    fn embed_passes_vectors_through_and_rejects_text() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.5, -0.5]]);
        assert_eq!(
            parse_response(Task::Embed, BackendKind::Local, raw, None),
            Ok(InferenceOutputs::Vectors(vec![vec![0.5, -0.5]]))
        );
        let text = InferenceOutputs::Text(vec!["hello".into()]);
        assert_eq!(
            parse_response(Task::Embed, BackendKind::Remote, text, None),
            Err(AdapterError::UnexpectedOutputShape {
                task: Task::Embed,
                kind: BackendKind::Remote,
                expected: "vectors",
                got: "text",
            })
        );
    }

    #[test]
    fn generation_is_remote_only() {
        let raw = InferenceOutputs::Vectors(vec![vec![0.0]]);
        assert_eq!(
            parse_response(Task::Complete, BackendKind::Local, raw, None),
            Err(AdapterError::UnsupportedCombination {
                task: Task::Complete,
                kind: BackendKind::Local,
            })
        );
    }
}