Skip to main content

laminar_ai/backends/
anthropic.rs

1//! Anthropic Messages API provider.
2//!
//! Drives the discriminative and generative tasks, issuing one Messages API
//! call per input row with bounded concurrency. Anthropic exposes no
//! embeddings endpoint, so `ai_embed` is rejected — use the OpenAI-compatible
//! provider for embeddings.
6
7use std::time::Duration;
8
9use async_trait::async_trait;
10use futures::stream::{StreamExt, TryStreamExt};
11use serde::{Deserialize, Serialize};
12
13use crate::backends::remote::{add_usage, chat_prompt, post_json};
14use crate::provider::{
15    InferenceOutputs, InferenceProvider, InferenceRequest, InferenceResponse, ProviderError, Usage,
16};
17use crate::registry::Task;
18
/// Value sent in the `anthropic-version` header on every request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Per-request HTTP timeout in milliseconds (one minute).
const REQUEST_TIMEOUT_MS: u64 = 60 * 1_000;
/// Retry budget handed to `post_json` for each call.
const MAX_RETRIES: u32 = 2;
/// Upper bound on generated tokens for a single reply. The Messages API
/// requires `max_tokens` on every request, so a default is always sent.
const DEFAULT_MAX_TOKENS: u32 = 1_024;
24
/// Anthropic Messages provider.
///
/// Bundles the HTTP client, endpoint root, credentials and per-batch
/// concurrency limit used by [`InferenceProvider::infer_batch`].
pub struct AnthropicProvider {
    /// HTTP client built with the per-request timeout; reused for every call.
    client: reqwest::Client,
    /// API root with any trailing `/` removed by the constructor.
    base_url: String,
    /// Secret sent in the `x-api-key` header.
    api_key: String,
    /// Maximum number of in-flight requests per batch; the constructor
    /// clamps it to at least 1.
    max_concurrency: usize,
}
32
33impl AnthropicProvider {
34    /// Build a provider for `base_url` (e.g. `https://api.anthropic.com`),
35    /// authenticating with `api_key` and issuing at most `max_concurrency`
36    /// concurrent requests per batch.
37    ///
38    /// # Errors
39    ///
40    /// Returns [`ProviderError::Transport`] if the HTTP client cannot be built.
41    pub fn new(
42        base_url: impl Into<String>,
43        api_key: impl Into<String>,
44        max_concurrency: usize,
45    ) -> Result<Self, ProviderError> {
46        let client = reqwest::Client::builder()
47            .timeout(Duration::from_millis(REQUEST_TIMEOUT_MS))
48            .build()
49            .map_err(|e| ProviderError::Transport(e.to_string()))?;
50        Ok(Self {
51            client,
52            base_url: base_url.into().trim_end_matches('/').to_string(),
53            api_key: api_key.into(),
54            max_concurrency: max_concurrency.max(1),
55        })
56    }
57}
58
59#[async_trait]
60impl InferenceProvider for AnthropicProvider {
61    async fn infer_batch(
62        &self,
63        request: InferenceRequest,
64    ) -> Result<InferenceResponse, ProviderError> {
65        if request.task == Task::Embed {
66            return Err(ProviderError::UnsupportedTask(Task::Embed));
67        }
68
69        let url = format!("{}/v1/messages", self.base_url);
70        let bodies: Vec<MessageBody> = request
71            .inputs
72            .iter()
73            .map(|input| {
74                let (system, user) =
75                    chat_prompt(request.task, input, request.params.labels.as_deref());
76                MessageBody {
77                    model: request.model.clone(),
78                    max_tokens: DEFAULT_MAX_TOKENS,
79                    system,
80                    messages: vec![Message::user(user)],
81                }
82            })
83            .collect();
84
85        let url = &url;
86        let calls = bodies.into_iter().map(|body| async move {
87            let builder = self
88                .client
89                .post(url)
90                .header("x-api-key", &self.api_key)
91                .header("anthropic-version", ANTHROPIC_VERSION)
92                .json(&body);
93            let response: MessageResponse =
94                post_json(builder, MAX_RETRIES, REQUEST_TIMEOUT_MS).await?;
95            parse_message(response)
96        });
97
98        let results: Vec<(String, Usage)> = futures::stream::iter(calls)
99            .buffered(self.max_concurrency)
100            .try_collect()
101            .await?;
102
103        let mut texts = Vec::with_capacity(results.len());
104        let mut usage = Usage::ZERO;
105        for (text, call_usage) in results {
106            texts.push(text);
107            usage = add_usage(usage, call_usage);
108        }
109        Ok(InferenceResponse {
110            outputs: InferenceOutputs::Text(texts),
111            usage,
112        })
113    }
114
115    fn name(&self) -> &'static str {
116        "anthropic"
117    }
118}
119
120// --- wire shapes ---
121
/// JSON request body for `POST /v1/messages`.
#[derive(Serialize)]
struct MessageBody {
    /// Model identifier, forwarded unchanged from the inference request.
    model: String,
    /// Cap on generated tokens; the Messages API requires this field.
    max_tokens: u32,
    /// System prompt, sent as the top-level `system` field.
    system: String,
    /// Conversation turns; this provider sends a single user turn.
    messages: Vec<Message>,
}
129
/// One conversation turn inside a [`MessageBody`].
#[derive(Serialize)]
struct Message {
    /// Speaker role, e.g. `"user"`.
    role: &'static str,
    /// Plain-text content of the turn.
    content: String,
}
135
136impl Message {
137    fn user(content: String) -> Self {
138        Self {
139            role: "user",
140            content,
141        }
142    }
143}
144
/// The parts of a Messages API response body that this provider reads;
/// other top-level fields are ignored during deserialization.
#[derive(Deserialize)]
struct MessageResponse {
    /// Ordered content blocks of the reply; may include non-text kinds.
    content: Vec<ContentBlock>,
    /// Token accounting; optional so a response without it still parses.
    usage: Option<AnthropicUsage>,
}
150
/// One entry of a response's `content` array.
#[derive(Deserialize)]
struct ContentBlock {
    /// Block type tag read from the JSON `type` field (e.g. `"text"`).
    #[serde(rename = "type")]
    kind: String,
    /// Block text; empty when the block carries no `text` field.
    #[serde(default)]
    text: String,
}
158
159#[derive(Deserialize)]
160struct AnthropicUsage {
161    #[serde(default)]
162    input_tokens: u64,
163    #[serde(default)]
164    output_tokens: u64,
165}
166
167/// Take the first text content block and token usage from a messages response.
168fn parse_message(response: MessageResponse) -> Result<(String, Usage), ProviderError> {
169    let text = response
170        .content
171        .into_iter()
172        .find(|b| b.kind == "text")
173        .map(|b| b.text)
174        .ok_or_else(|| {
175            ProviderError::BadResponse("messages response had no text block".to_string())
176        })?;
177    let usage = response.usage.map_or(Usage::ZERO, |u| Usage {
178        input_tokens: u.input_tokens,
179        output_tokens: u.output_tokens,
180        cost_micros: 0,
181    });
182    Ok((text, usage))
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn message_body_has_system_and_max_tokens() {
        let body = MessageBody {
            model: "claude-x".to_string(),
            max_tokens: 256,
            system: "be terse".to_string(),
            messages: vec![Message::user("hello".to_string())],
        };
        let value = serde_json::to_value(&body).unwrap();
        assert_eq!(value["model"], "claude-x");
        assert_eq!(value["max_tokens"], 256);
        assert_eq!(value["system"], "be terse");
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], "hello");
    }

    #[test]
    fn parse_message_takes_first_text_block() {
        let json = r#"{
            "content": [{"type": "text", "text": "positive"}],
            "usage": {"input_tokens": 9, "output_tokens": 1}
        }"#;
        let response: MessageResponse = serde_json::from_str(json).unwrap();
        let (text, usage) = parse_message(response).unwrap();
        assert_eq!(text, "positive");
        assert_eq!(usage.input_tokens, 9);
        assert_eq!(usage.output_tokens, 1);
    }

    #[test]
    fn parse_message_skips_non_text_blocks() {
        let json = r#"{
            "content": [
                {"type": "thinking", "thinking": "weighing the labels"},
                {"type": "text", "text": "negative"},
                {"type": "text", "text": "ignored"}
            ]
        }"#;
        let response: MessageResponse = serde_json::from_str(json).unwrap();
        let (text, usage) = parse_message(response).unwrap();
        assert_eq!(text, "negative");
        // No `usage` object in the response: counts fall back to zero.
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
    }

    #[test]
    fn parse_message_errors_without_text() {
        let response: MessageResponse = serde_json::from_str(r#"{"content": []}"#).unwrap();
        assert!(parse_message(response).is_err());
    }

    #[test]
    fn parse_message_errors_when_only_non_text_blocks() {
        let json = r#"{
            "content": [{"type": "tool_use", "id": "t1", "name": "lookup", "input": {}}]
        }"#;
        let response: MessageResponse = serde_json::from_str(json).unwrap();
        assert!(matches!(
            parse_message(response),
            Err(ProviderError::BadResponse(_))
        ));
    }
}