laminar_ai/backends/
anthropic.rs1use std::time::Duration;
8
9use async_trait::async_trait;
10use futures::stream::{StreamExt, TryStreamExt};
11use serde::{Deserialize, Serialize};
12
13use crate::backends::remote::{add_usage, chat_prompt, post_json};
14use crate::provider::{
15 InferenceOutputs, InferenceProvider, InferenceRequest, InferenceResponse, ProviderError, Usage,
16};
17use crate::registry::Task;
18
/// Per-request timeout in milliseconds. It is configured on the HTTP client and
/// also handed to `post_json` for every call.
const REQUEST_TIMEOUT_MS: u64 = 60 * 1_000;
/// Retry budget passed to `post_json` for each request.
const MAX_RETRIES: u32 = 2;
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// `max_tokens` value written into every request body this provider builds.
const DEFAULT_MAX_TOKENS: u32 = 1024;
24
25pub struct AnthropicProvider {
27 client: reqwest::Client,
28 base_url: String,
29 api_key: String,
30 max_concurrency: usize,
31}
32
33impl AnthropicProvider {
34 pub fn new(
42 base_url: impl Into<String>,
43 api_key: impl Into<String>,
44 max_concurrency: usize,
45 ) -> Result<Self, ProviderError> {
46 let client = reqwest::Client::builder()
47 .timeout(Duration::from_millis(REQUEST_TIMEOUT_MS))
48 .build()
49 .map_err(|e| ProviderError::Transport(e.to_string()))?;
50 Ok(Self {
51 client,
52 base_url: base_url.into().trim_end_matches('/').to_string(),
53 api_key: api_key.into(),
54 max_concurrency: max_concurrency.max(1),
55 })
56 }
57}
58
59#[async_trait]
60impl InferenceProvider for AnthropicProvider {
61 async fn infer_batch(
62 &self,
63 request: InferenceRequest,
64 ) -> Result<InferenceResponse, ProviderError> {
65 if request.task == Task::Embed {
66 return Err(ProviderError::UnsupportedTask(Task::Embed));
67 }
68
69 let url = format!("{}/v1/messages", self.base_url);
70 let bodies: Vec<MessageBody> = request
71 .inputs
72 .iter()
73 .map(|input| {
74 let (system, user) =
75 chat_prompt(request.task, input, request.params.labels.as_deref());
76 MessageBody {
77 model: request.model.clone(),
78 max_tokens: DEFAULT_MAX_TOKENS,
79 system,
80 messages: vec![Message::user(user)],
81 }
82 })
83 .collect();
84
85 let url = &url;
86 let calls = bodies.into_iter().map(|body| async move {
87 let builder = self
88 .client
89 .post(url)
90 .header("x-api-key", &self.api_key)
91 .header("anthropic-version", ANTHROPIC_VERSION)
92 .json(&body);
93 let response: MessageResponse =
94 post_json(builder, MAX_RETRIES, REQUEST_TIMEOUT_MS).await?;
95 parse_message(response)
96 });
97
98 let results: Vec<(String, Usage)> = futures::stream::iter(calls)
99 .buffered(self.max_concurrency)
100 .try_collect()
101 .await?;
102
103 let mut texts = Vec::with_capacity(results.len());
104 let mut usage = Usage::ZERO;
105 for (text, call_usage) in results {
106 texts.push(text);
107 usage = add_usage(usage, call_usage);
108 }
109 Ok(InferenceResponse {
110 outputs: InferenceOutputs::Text(texts),
111 usage,
112 })
113 }
114
115 fn name(&self) -> &'static str {
116 "anthropic"
117 }
118}
119
120#[derive(Serialize)]
123struct MessageBody {
124 model: String,
125 max_tokens: u32,
126 system: String,
127 messages: Vec<Message>,
128}
129
130#[derive(Serialize)]
131struct Message {
132 role: &'static str,
133 content: String,
134}
135
136impl Message {
137 fn user(content: String) -> Self {
138 Self {
139 role: "user",
140 content,
141 }
142 }
143}
144
145#[derive(Deserialize)]
146struct MessageResponse {
147 content: Vec<ContentBlock>,
148 usage: Option<AnthropicUsage>,
149}
150
151#[derive(Deserialize)]
152struct ContentBlock {
153 #[serde(rename = "type")]
154 kind: String,
155 #[serde(default)]
156 text: String,
157}
158
159#[derive(Deserialize)]
160struct AnthropicUsage {
161 #[serde(default)]
162 input_tokens: u64,
163 #[serde(default)]
164 output_tokens: u64,
165}
166
167fn parse_message(response: MessageResponse) -> Result<(String, Usage), ProviderError> {
169 let text = response
170 .content
171 .into_iter()
172 .find(|b| b.kind == "text")
173 .map(|b| b.text)
174 .ok_or_else(|| {
175 ProviderError::BadResponse("messages response had no text block".to_string())
176 })?;
177 let usage = response.usage.map_or(Usage::ZERO, |u| Usage {
178 input_tokens: u.input_tokens,
179 output_tokens: u.output_tokens,
180 cost_micros: 0,
181 });
182 Ok((text, usage))
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized request body exposes the model, token limit, system prompt and user turn.
    #[test]
    fn message_body_has_system_and_max_tokens() {
        let messages = vec![Message::user(String::from("hello"))];
        let body = MessageBody {
            model: String::from("claude-x"),
            max_tokens: 256,
            system: String::from("be terse"),
            messages,
        };

        let encoded = serde_json::to_value(&body).unwrap();

        assert_eq!(encoded["model"], "claude-x");
        assert_eq!(encoded["max_tokens"], 256);
        assert_eq!(encoded["system"], "be terse");
        assert_eq!(encoded["messages"][0]["role"], "user");
    }

    /// A well-formed response yields its text block and the reported token counts.
    #[test]
    fn parse_message_takes_first_text_block() {
        let raw = r#"{
 "content": [{"type": "text", "text": "positive"}],
 "usage": {"input_tokens": 9, "output_tokens": 1}
 }"#;
        let decoded: MessageResponse = serde_json::from_str(raw).unwrap();

        let (text, usage) = parse_message(decoded).unwrap();

        assert_eq!(text, "positive");
        assert_eq!(usage.input_tokens, 9);
        assert_eq!(usage.output_tokens, 1);
    }

    /// A response with no content blocks is rejected.
    #[test]
    fn parse_message_errors_without_text() {
        let decoded: MessageResponse = serde_json::from_str(r#"{"content": []}"#).unwrap();
        let outcome = parse_message(decoded);
        assert!(outcome.is_err());
    }
}