// Genarrative/server-rs/crates/platform-llm/src/lib.rs
use std::{error::Error, fmt, str as std_str, time::Duration};
use log::{debug, warn};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::time::sleep;
pub const DEFAULT_ARK_BASE_URL: &str = "https://ark.cn-beijing.volces.com/api/v3";
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 30_000;
pub const DEFAULT_MAX_RETRIES: u32 = 1;
pub const DEFAULT_RETRY_BACKOFF_MS: u64 = 500;
pub const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";
// Freeze the set of platform providers so upper layers stop scattering ad-hoc provider strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmProvider {
Ark,
DashScope,
OpenAiCompatible,
}
// Centralize text-model gateway configuration so api-server and business modules stop re-parsing environment variables on their own.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmConfig {
provider: LlmProvider,
base_url: String,
api_key: String,
model: String,
request_timeout_ms: u64,
max_retries: u32,
retry_backoff_ms: u64,
}
// The first version freezes only the three message roles already in stable use across the project: system/user/assistant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmMessageRole {
System,
User,
Assistant,
}
// A single message keeps the OpenAI-compatible shape so the unified request body serializes it directly.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmMessage {
pub role: LlmMessageRole,
pub content: String,
}
// A text-completion request is frozen to the minimal loop of "message list + optional model override + optional max_tokens".
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextRequest {
pub model: Option<String>,
pub messages: Vec<LlmMessage>,
pub max_tokens: Option<u32>,
}
// Streaming consumers receive "accumulated text + current delta" so no layer has to re-concatenate on its own.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmStreamDelta {
pub accumulated_text: String,
pub delta_text: String,
pub finish_reason: Option<String>,
}
// Retains token counts; downstream modules can decide whether to persist them for auditing or cost accounting.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmTokenUsage {
pub prompt_tokens: u64,
pub completion_tokens: u64,
pub total_tokens: u64,
}
// Unified text response so the business layer never re-parses choices/message/content itself.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextResponse {
pub provider: LlmProvider,
pub model: String,
pub content: String,
pub finish_reason: Option<String>,
pub response_id: Option<String>,
pub usage: Option<LlmTokenUsage>,
}
// Normalize upstream errors into a stable domain enum that api-server can map straight onto its HTTP error contract.
#[derive(Debug, PartialEq, Eq)]
pub enum LlmError {
InvalidConfig(String),
InvalidRequest(String),
Timeout { attempts: u32 },
Connectivity { attempts: u32, message: String },
Upstream { status_code: u16, message: String },
StreamUnavailable,
EmptyResponse,
Transport(String),
Deserialize(String),
}
// Unified OpenAI-compatible text gateway client.
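/// Minimal usage sketch, compile-checked only (`no_run`). The crate name
/// `platform_llm` is assumed from the crate directory; the key and model
/// values are placeholders.
///
/// ```no_run
/// use platform_llm::{LlmClient, LlmConfig};
///
/// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
/// let config = LlmConfig::ark_default("api-key".into(), "model-id".into())?;
/// let client = LlmClient::new(config)?;
/// let response = client
///     .request_single_message_text("You are a concise assistant.", "Say hi.")
///     .await?;
/// println!("{}", response.content);
/// # Ok(())
/// # }
/// ```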
#[derive(Clone, Debug)]
pub struct LlmClient {
config: LlmConfig,
http_client: Client,
}
#[derive(Serialize)]
struct ChatCompletionsRequestBody<'a> {
model: &'a str,
messages: &'a [LlmMessage],
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
}
#[derive(Deserialize)]
struct ChatCompletionsResponseEnvelope {
id: Option<String>,
model: Option<String>,
choices: Vec<ChatCompletionsChoice>,
usage: Option<LlmTokenUsage>,
}
#[derive(Deserialize)]
struct ChatCompletionsChoice {
#[serde(default)]
message: Option<ChatCompletionsMessage>,
#[serde(default)]
delta: Option<ChatCompletionsMessage>,
#[serde(default)]
finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct ChatCompletionsMessage {
#[serde(default)]
content: Option<ChatCompletionsContent>,
}
#[derive(Deserialize)]
#[serde(untagged)]
enum ChatCompletionsContent {
Text(String),
Parts(Vec<ChatCompletionsContentPart>),
}
#[derive(Deserialize)]
struct ChatCompletionsContentPart {
#[serde(rename = "type")]
#[allow(dead_code)]
part_type: Option<String>,
#[serde(default)]
text: Option<String>,
}
#[derive(Default)]
struct OpenAiCompatibleSseParser {
buffer: String,
}
#[derive(Debug)]
struct ParsedStreamEvent {
delta_text: Option<String>,
finish_reason: Option<String>,
}
impl LlmProvider {
pub fn as_str(&self) -> &'static str {
match self {
Self::Ark => "ark",
Self::DashScope => "dash_scope",
Self::OpenAiCompatible => "openai_compatible",
}
}
}
impl LlmConfig {
#[allow(clippy::too_many_arguments)]
pub fn new(
provider: LlmProvider,
base_url: String,
api_key: String,
model: String,
request_timeout_ms: u64,
max_retries: u32,
retry_backoff_ms: u64,
) -> Result<Self, LlmError> {
let base_url = normalize_non_empty(base_url, "LLM base_url must not be empty")?;
let api_key = normalize_non_empty(api_key, "LLM api_key must not be empty")?;
let model = normalize_non_empty(model, "LLM model must not be empty")?;
if request_timeout_ms == 0 {
return Err(LlmError::InvalidConfig(
"LLM request_timeout_ms 必须大于 0".to_string(),
));
}
Ok(Self {
provider,
base_url,
api_key,
model,
request_timeout_ms,
max_retries,
retry_backoff_ms,
})
}
pub fn ark_default(api_key: String, model: String) -> Result<Self, LlmError> {
Self::new(
LlmProvider::Ark,
DEFAULT_ARK_BASE_URL.to_string(),
api_key,
model,
DEFAULT_REQUEST_TIMEOUT_MS,
DEFAULT_MAX_RETRIES,
DEFAULT_RETRY_BACKOFF_MS,
)
}
pub fn provider(&self) -> LlmProvider {
self.provider
}
pub fn base_url(&self) -> &str {
&self.base_url
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn model(&self) -> &str {
&self.model
}
pub fn request_timeout_ms(&self) -> u64 {
self.request_timeout_ms
}
pub fn max_retries(&self) -> u32 {
self.max_retries
}
pub fn retry_backoff_ms(&self) -> u64 {
self.retry_backoff_ms
}
pub fn chat_completions_url(&self) -> String {
format!(
"{}/{}",
self.base_url.trim_end_matches('/'),
CHAT_COMPLETIONS_PATH.trim_start_matches('/')
)
}
}
impl LlmMessage {
pub fn new(role: LlmMessageRole, content: impl Into<String>) -> Self {
Self {
role,
content: content.into(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new(LlmMessageRole::System, content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new(LlmMessageRole::User, content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new(LlmMessageRole::Assistant, content)
}
}
impl LlmTextRequest {
pub fn new(messages: Vec<LlmMessage>) -> Self {
Self {
model: None,
messages,
max_tokens: None,
}
}
pub fn single_turn(system_prompt: impl Into<String>, user_prompt: impl Into<String>) -> Self {
Self::new(vec![
LlmMessage::system(system_prompt),
LlmMessage::user(user_prompt),
])
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
fn validate(&self) -> Result<(), LlmError> {
if self.messages.is_empty() {
return Err(LlmError::InvalidRequest(
"LLM messages 不能为空".to_string(),
));
}
for message in &self.messages {
if message.content.trim().is_empty() {
return Err(LlmError::InvalidRequest(
"LLM message.content 不能为空".to_string(),
));
}
}
if let Some(model) = &self.model
&& model.trim().is_empty()
{
return Err(LlmError::InvalidRequest(
"LLM request.model 不能为空字符串".to_string(),
));
}
Ok(())
}
fn resolved_model<'a>(&'a self, fallback_model: &'a str) -> &'a str {
self.model
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.unwrap_or(fallback_model)
}
}
impl fmt::Display for LlmError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidConfig(message)
| Self::InvalidRequest(message)
| Self::Transport(message)
| Self::Deserialize(message) => write!(f, "{message}"),
Self::Timeout { attempts } => {
write!(f, "LLM 请求超时,累计尝试 {attempts} 次")
}
Self::Connectivity { attempts, message } => {
write!(f, "LLM 连接失败,累计尝试 {attempts} 次:{message}")
}
Self::Upstream {
status_code,
message,
} => write!(f, "LLM 上游返回 {status_code}{message}"),
Self::StreamUnavailable => write!(f, "LLM streaming response body unavailable"),
Self::EmptyResponse => write!(f, "LLM returned empty content"),
}
}
}
impl Error for LlmError {}
impl LlmClient {
pub fn new(config: LlmConfig) -> Result<Self, LlmError> {
let http_client = Client::builder().build().map_err(|error| {
LlmError::InvalidConfig(format!("构建 reqwest client 失败:{error}"))
})?;
Ok(Self {
config,
http_client,
})
}
pub fn config(&self) -> &LlmConfig {
&self.config
}
pub async fn request_text(&self, request: LlmTextRequest) -> Result<LlmTextResponse, LlmError> {
request.validate()?;
let resolved_model = request.resolved_model(self.config.model()).to_string();
let response = self.execute_request(&request, false).await?;
let raw_text = response
.text()
.await
.map_err(|error| map_stream_read_error(error, 1))?;
parse_chat_completions_response(self.config.provider(), &resolved_model, raw_text.as_str())
}
pub async fn request_single_message_text(
&self,
system_prompt: impl Into<String>,
user_prompt: impl Into<String>,
) -> Result<LlmTextResponse, LlmError> {
self.request_text(LlmTextRequest::single_turn(system_prompt, user_prompt))
.await
}
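/// Streams the completion and invokes `on_delta` for every non-empty text
/// fragment, handing over both the accumulated text and the newest delta.
/// A minimal callback sketch (`no_run`; crate name `platform_llm` assumed):
///
/// ```no_run
/// # async fn demo(client: platform_llm::LlmClient) -> Result<(), platform_llm::LlmError> {
/// let request = platform_llm::LlmTextRequest::single_turn("system prompt", "user prompt");
/// let response = client
///     .stream_text(request, |delta| print!("{}", delta.delta_text))
///     .await?;
/// println!();
/// println!("final: {}", response.content);
/// # Ok(())
/// # }
/// ```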
pub async fn stream_text<F>(
&self,
request: LlmTextRequest,
mut on_delta: F,
) -> Result<LlmTextResponse, LlmError>
where
F: FnMut(&LlmStreamDelta),
{
request.validate()?;
let resolved_model = request.resolved_model(self.config.model()).to_string();
let mut response = self.execute_request(&request, true).await?;
let response_id = response
.headers()
.get("x-request-id")
.and_then(|value| value.to_str().ok())
.map(str::to_string);
let mut parser = OpenAiCompatibleSseParser::default();
let mut accumulated_text = String::new();
let mut finish_reason = None;
let mut undecoded_chunk_bytes = Vec::new();
loop {
let next_chunk = response
.chunk()
.await
.map_err(|error| map_stream_read_error(error, 1))?;
let Some(chunk) = next_chunk else {
break;
};
undecoded_chunk_bytes.extend_from_slice(chunk.as_ref());
let (chunk_text, remaining_bytes) =
decode_utf8_stream_chunk(undecoded_chunk_bytes.as_slice())?;
undecoded_chunk_bytes = remaining_bytes;
if chunk_text.is_empty() {
continue;
}
for event in parser.push_chunk(chunk_text.as_ref())? {
if let Some(delta_text) = event.delta_text
&& !delta_text.is_empty()
{
accumulated_text.push_str(delta_text.as_str());
let update = LlmStreamDelta {
accumulated_text: accumulated_text.clone(),
delta_text,
finish_reason: event.finish_reason.clone(),
};
on_delta(&update);
}
if event.finish_reason.is_some() {
finish_reason = event.finish_reason;
}
}
}
if !undecoded_chunk_bytes.is_empty() {
let trailing_text = std_str::from_utf8(undecoded_chunk_bytes.as_slice())
.map_err(|error| {
LlmError::Deserialize(format!(
"解析 LLM 流式 UTF-8 响应失败:{error}"
))
})?;
if !trailing_text.is_empty() {
for event in parser.push_chunk(trailing_text)? {
if let Some(delta_text) = event.delta_text
&& !delta_text.is_empty()
{
accumulated_text.push_str(delta_text.as_str());
let update = LlmStreamDelta {
accumulated_text: accumulated_text.clone(),
delta_text,
finish_reason: event.finish_reason.clone(),
};
on_delta(&update);
}
if event.finish_reason.is_some() {
finish_reason = event.finish_reason;
}
}
}
}
for event in parser.finish()? {
if let Some(delta_text) = event.delta_text
&& !delta_text.is_empty()
{
accumulated_text.push_str(delta_text.as_str());
let update = LlmStreamDelta {
accumulated_text: accumulated_text.clone(),
delta_text,
finish_reason: event.finish_reason.clone(),
};
on_delta(&update);
}
if event.finish_reason.is_some() {
finish_reason = event.finish_reason;
}
}
let content = accumulated_text.trim().to_string();
if content.is_empty() {
return Err(LlmError::EmptyResponse);
}
Ok(LlmTextResponse {
provider: self.config.provider(),
model: resolved_model,
content,
finish_reason,
response_id,
usage: None,
})
}
pub async fn stream_single_message_text<F>(
&self,
system_prompt: impl Into<String>,
user_prompt: impl Into<String>,
on_delta: F,
) -> Result<LlmTextResponse, LlmError>
where
F: FnMut(&LlmStreamDelta),
{
self.stream_text(
LlmTextRequest::single_turn(system_prompt, user_prompt),
on_delta,
)
.await
}
async fn execute_request(
&self,
request: &LlmTextRequest,
stream: bool,
) -> Result<reqwest::Response, LlmError> {
let request_body = ChatCompletionsRequestBody {
model: request.resolved_model(self.config.model()),
messages: request.messages.as_slice(),
stream,
max_tokens: request.max_tokens,
};
let max_attempts = self.config.max_retries().saturating_add(1);
for attempt in 1..=max_attempts {
debug!(
"platform-llm request started: provider={}, stream={}, attempt={}, model={}",
self.config.provider().as_str(),
stream,
attempt,
request_body.model
);
let send_result = self
.http_client
.post(self.config.chat_completions_url())
.bearer_auth(self.config.api_key())
.json(&request_body)
.timeout(Duration::from_millis(self.config.request_timeout_ms()))
.send()
.await;
match send_result {
Ok(response) if response.status().is_success() => {
debug!(
"platform-llm request succeeded: provider={}, stream={}, attempt={}, status={}",
self.config.provider().as_str(),
stream,
attempt,
response.status().as_u16()
);
return Ok(response);
}
Ok(response) => {
let status = response.status();
let raw_text = response.text().await.unwrap_or_default();
let message = extract_api_error_message(&raw_text, "LLM upstream request failed");
if should_retry_status(status) && attempt < max_attempts {
warn!(
"platform-llm request retrying after upstream status: provider={}, attempt={}, status={}, message={}",
self.config.provider().as_str(),
attempt,
status.as_u16(),
message
);
self.sleep_before_retry(attempt).await;
continue;
}
return Err(LlmError::Upstream {
status_code: status.as_u16(),
message,
});
}
Err(error) if error.is_timeout() => {
if attempt < max_attempts {
warn!(
"platform-llm request retrying after timeout: provider={}, attempt={}",
self.config.provider().as_str(),
attempt
);
self.sleep_before_retry(attempt).await;
continue;
}
return Err(LlmError::Timeout { attempts: attempt });
}
Err(error) if error.is_connect() => {
let message = error.to_string();
if attempt < max_attempts {
warn!(
"platform-llm request retrying after connectivity failure: provider={}, attempt={}, error={}",
self.config.provider().as_str(),
attempt,
message
);
self.sleep_before_retry(attempt).await;
continue;
}
return Err(LlmError::Connectivity {
attempts: attempt,
message,
});
}
Err(error) => {
return Err(LlmError::Transport(error.to_string()));
}
}
}
Err(LlmError::Transport(
"LLM 请求在重试循环后仍未返回结果".to_string(),
))
}
async fn sleep_before_retry(&self, attempt: u32) {
let backoff_ms = self
.config
.retry_backoff_ms()
.saturating_mul(u64::from(attempt));
if backoff_ms > 0 {
sleep(Duration::from_millis(backoff_ms)).await;
}
}
}
impl OpenAiCompatibleSseParser {
fn push_chunk(&mut self, chunk: &str) -> Result<Vec<ParsedStreamEvent>, LlmError> {
self.buffer.push_str(chunk);
self.buffer = self.buffer.replace("\r\n", "\n");
self.drain_complete_events()
}
fn finish(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
if self.buffer.trim().is_empty() {
return Ok(Vec::new());
}
self.buffer.push_str("\n\n");
self.drain_complete_events()
}
fn drain_complete_events(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
let mut events = Vec::new();
while let Some(boundary) = self.buffer.find("\n\n") {
let block = self.buffer[..boundary].to_string();
self.buffer = self.buffer[(boundary + 2)..].to_string();
if let Some(event) = parse_sse_event_block(block.as_str())? {
events.push(event);
}
}
Ok(events)
}
}
fn normalize_non_empty(value: String, error_message: &str) -> Result<String, LlmError> {
let trimmed = value.trim().to_string();
if trimmed.is_empty() {
return Err(LlmError::InvalidConfig(error_message.to_string()));
}
Ok(trimmed)
}
fn parse_chat_completions_response(
provider: LlmProvider,
fallback_model: &str,
raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(raw_text)
.map_err(|error| LlmError::Deserialize(format!("failed to parse LLM JSON response: {error}")))?;
let first_choice = parsed
.choices
.first()
.ok_or_else(|| LlmError::Deserialize("LLM response missing choices[0]".to_string()))?;
let content = extract_message_text(first_choice)
.ok_or(LlmError::EmptyResponse)?
.trim()
.to_string();
if content.is_empty() {
return Err(LlmError::EmptyResponse);
}
Ok(LlmTextResponse {
provider,
model: parsed.model.unwrap_or_else(|| fallback_model.to_string()),
content,
finish_reason: first_choice.finish_reason.clone(),
response_id: parsed.id,
usage: parsed.usage,
})
}
fn extract_message_text(choice: &ChatCompletionsChoice) -> Option<String> {
choice
.message
.as_ref()
.and_then(|message| message.content.as_ref())
.and_then(extract_content_text)
.or_else(|| {
choice
.delta
.as_ref()
.and_then(|message| message.content.as_ref())
.and_then(extract_content_text)
})
}
fn extract_content_text(content: &ChatCompletionsContent) -> Option<String> {
match content {
ChatCompletionsContent::Text(text) => Some(text.clone()),
ChatCompletionsContent::Parts(parts) => {
let text = parts
.iter()
.filter_map(|part| part.text.as_deref())
.collect::<Vec<_>>()
.join("");
if text.is_empty() { None } else { Some(text) }
}
}
}
fn decode_utf8_stream_chunk(bytes: &[u8]) -> Result<(String, Vec<u8>), LlmError> {
match std_str::from_utf8(bytes) {
Ok(text) => Ok((text.to_string(), Vec::new())),
Err(error) => {
let valid_up_to = error.valid_up_to();
let Some(_) = error.error_len() else {
let decoded = std_str::from_utf8(&bytes[..valid_up_to]).map_err(|inner_error| {
LlmError::Deserialize(format!(
"解析 LLM 流式 UTF-8 响应失败:{inner_error}"
))
})?;
return Ok((decoded.to_string(), bytes[valid_up_to..].to_vec()));
};
Err(LlmError::Deserialize(format!(
"解析 LLM 流式 UTF-8 响应失败:{error}"
)))
}
}
}
fn parse_sse_event_block(block: &str) -> Result<Option<ParsedStreamEvent>, LlmError> {
let data_lines = block
.lines()
.filter_map(|line| line.trim().strip_prefix("data:"))
.map(str::trim_start)
.collect::<Vec<_>>();
if data_lines.is_empty() {
return Ok(None);
}
let data = data_lines.join("\n");
if data.trim().is_empty() || data.trim() == "[DONE]" {
return Ok(None);
}
let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(data.as_str())
.map_err(|error| LlmError::Deserialize(format!("failed to parse LLM SSE event: {error}")))?;
let first_choice = parsed
.choices
.first()
.ok_or_else(|| LlmError::Deserialize("LLM SSE response missing choices[0]".to_string()))?;
Ok(Some(ParsedStreamEvent {
delta_text: extract_message_text(first_choice),
finish_reason: first_choice.finish_reason.clone(),
}))
}
fn should_retry_status(status: StatusCode) -> bool {
status == StatusCode::REQUEST_TIMEOUT
|| status == StatusCode::TOO_MANY_REQUESTS
|| status.is_server_error()
}
fn extract_api_error_message(raw_text: &str, fallback_message: &str) -> String {
let trimmed = raw_text.trim();
if trimmed.is_empty() {
return fallback_message.to_string();
}
let parsed = serde_json::from_str::<serde_json::Value>(trimmed);
if let Ok(value) = parsed {
if let Some(message) = value
.get("error")
.and_then(|error| error.get("message"))
.and_then(serde_json::Value::as_str)
.map(str::trim)
.filter(|message| !message.is_empty())
{
return message.to_string();
}
if let Some(message) = value
.get("message")
.and_then(serde_json::Value::as_str)
.map(str::trim)
.filter(|message| !message.is_empty())
{
return message.to_string();
}
}
trimmed.to_string()
}
fn map_stream_read_error(error: reqwest::Error, attempts: u32) -> LlmError {
if error.is_timeout() {
return LlmError::Timeout { attempts };
}
if error.is_connect() {
return LlmError::Connectivity {
attempts,
message: error.to_string(),
};
}
LlmError::Transport(error.to_string())
}
#[cfg(test)]
mod tests {
use std::{
io::{Read, Write},
net::TcpListener,
thread,
time::Duration as StdDuration,
};
use super::*;
struct MockResponse {
status_line: &'static str,
content_type: &'static str,
body: String,
extra_headers: Vec<(&'static str, &'static str)>,
}
#[test]
fn llm_config_rejects_blank_api_key() {
let error = LlmConfig::new(
LlmProvider::Ark,
DEFAULT_ARK_BASE_URL.to_string(),
" ".to_string(),
"model-a".to_string(),
DEFAULT_REQUEST_TIMEOUT_MS,
DEFAULT_MAX_RETRIES,
DEFAULT_RETRY_BACKOFF_MS,
)
.expect_err("blank api key should be rejected");
assert_eq!(
error,
LlmError::InvalidConfig("LLM api_key 不能为空".to_string())
);
}
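// Usage sketch for the builder-style request assembly; the literal values are
// placeholders chosen only to satisfy the assertions below.
#[test]
fn llm_text_request_builder_overrides_model_and_max_tokens() {
let request = LlmTextRequest::single_turn("system prompt", "user prompt")
.with_model("model-b")
.with_max_tokens(128);
assert!(request.validate().is_ok());
assert_eq!(request.resolved_model("fallback-model"), "model-b");
assert_eq!(request.max_tokens, Some(128));
}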
#[test]
fn llm_chat_completion_url_normalizes_trailing_slash() {
let config = LlmConfig::new(
LlmProvider::OpenAiCompatible,
"https://example.com/base///".to_string(),
"secret".to_string(),
"model-a".to_string(),
DEFAULT_REQUEST_TIMEOUT_MS,
DEFAULT_MAX_RETRIES,
DEFAULT_RETRY_BACKOFF_MS,
)
.expect("config should be valid");
assert_eq!(
config.chat_completions_url(),
"https://example.com/base/chat/completions"
);
}
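// Sketch of the provider wire format: serde's snake_case renaming is expected
// to agree with as_str for the providers asserted here.
#[test]
fn llm_provider_as_str_matches_serde_snake_case() {
assert_eq!(LlmProvider::Ark.as_str(), "ark");
assert_eq!(LlmProvider::DashScope.as_str(), "dash_scope");
assert_eq!(
serde_json::to_string(&LlmProvider::DashScope).expect("provider should serialize"),
"\"dash_scope\""
);
}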
#[test]
fn sse_parser_handles_split_chunks_and_done_marker() {
let mut parser = OpenAiCompatibleSseParser::default();
let events_a = parser
.push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"你\"}}]}\r\n\r\n")
.expect("first chunk should parse");
let events_b = parser
.push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"好\"},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n")
.expect("second chunk should parse");
assert_eq!(events_a.len(), 1);
assert_eq!(events_a[0].delta_text.as_deref(), Some("你"));
assert_eq!(events_b.len(), 1);
assert_eq!(events_b[0].delta_text.as_deref(), Some("好"));
assert_eq!(events_b[0].finish_reason.as_deref(), Some("stop"));
}
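// Companion sketch for the parser's flush path: a block that never receives
// its closing blank line is still drained by finish().
#[test]
fn sse_parser_finish_flushes_trailing_block() {
let mut parser = OpenAiCompatibleSseParser::default();
let buffered = parser
.push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"tail\"},\"finish_reason\":\"stop\"}]}")
.expect("incomplete block should buffer");
assert!(buffered.is_empty());
let flushed = parser.finish().expect("finish should flush the buffered block");
assert_eq!(flushed.len(), 1);
assert_eq!(flushed[0].delta_text.as_deref(), Some("tail"));
assert_eq!(flushed[0].finish_reason.as_deref(), Some("stop"));
}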
#[test]
fn decode_utf8_stream_chunk_preserves_incomplete_multibyte_suffix() {
let full_bytes = "你好".as_bytes();
let first_result = decode_utf8_stream_chunk(&full_bytes[..2])
.expect("incomplete utf-8 chunk should be buffered");
assert_eq!(first_result.0, "");
assert_eq!(first_result.1, full_bytes[..2].to_vec());
let mut combined = first_result.1;
combined.extend_from_slice(&full_bytes[2..]);
let second_result = decode_utf8_stream_chunk(combined.as_slice())
.expect("completed utf-8 bytes should decode");
assert_eq!(second_result.0, "你好");
assert!(second_result.1.is_empty());
}
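// Negative-path sketch for the decoder: a byte that can never start a valid
// UTF-8 sequence must surface as a Deserialize error instead of being buffered.
#[test]
fn decode_utf8_stream_chunk_rejects_invalid_byte() {
let error = decode_utf8_stream_chunk(&[0xFF, b'a'])
.expect_err("invalid utf-8 should be rejected");
assert!(matches!(error, LlmError::Deserialize(_)));
}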
#[tokio::test]
async fn request_text_parses_non_stream_response() {
let server_url = spawn_mock_server(vec![MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: r#"{"id":"resp_01","model":"ark-test-model","choices":[{"message":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":6,"total_tokens":16}}"#.to_string(),
extra_headers: Vec::new(),
}]);
let client = build_test_client(server_url, 0);
let response = client
.request_single_message_text("系统", "用户")
.await
.expect("request_text should succeed");
assert_eq!(response.provider, LlmProvider::Ark);
assert_eq!(response.model, "ark-test-model");
assert_eq!(response.content, "测试成功");
assert_eq!(response.finish_reason.as_deref(), Some("stop"));
assert_eq!(response.response_id.as_deref(), Some("resp_01"));
assert_eq!(
response.usage,
Some(LlmTokenUsage {
prompt_tokens: 10,
completion_tokens: 6,
total_tokens: 16,
})
);
}
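// Sketch of the content-parts fallback: when content arrives as an array of
// typed parts rather than a plain string, the text fields are joined in order.
#[test]
fn parse_chat_completions_response_joins_content_parts() {
let raw = r#"{"choices":[{"message":{"content":[{"type":"text","text":"Hello, "},{"type":"text","text":"world"}]},"finish_reason":"stop"}]}"#;
let response = parse_chat_completions_response(LlmProvider::Ark, "fallback-model", raw)
.expect("content parts should parse");
assert_eq!(response.content, "Hello, world");
assert_eq!(response.model, "fallback-model");
assert_eq!(response.finish_reason.as_deref(), Some("stop"));
}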
#[tokio::test]
async fn request_text_retries_after_upstream_500() {
let server_url = spawn_mock_server(vec![
MockResponse {
status_line: "500 Internal Server Error",
content_type: "application/json; charset=utf-8",
body: r#"{"error":{"message":"temporary upstream failure"}}"#.to_string(),
extra_headers: Vec::new(),
},
MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: r#"{"id":"resp_retry","choices":[{"message":{"content":""},"finish_reason":"stop"}]}"#.to_string(),
extra_headers: Vec::new(),
},
]);
let client = build_test_client(server_url, 1);
let response = client
.request_single_message_text("系统", "用户")
.await
.expect("second attempt should succeed");
assert_eq!(response.content, "第二次成功");
assert_eq!(response.response_id.as_deref(), Some("resp_retry"));
}
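// Sketch of the retry/error-surface classification behind execute_request:
// 408/429/5xx are retryable, and upstream error bodies are mined for a
// human-readable message before falling back.
#[test]
fn should_retry_status_covers_timeout_rate_limit_and_server_errors() {
assert!(should_retry_status(StatusCode::REQUEST_TIMEOUT));
assert!(should_retry_status(StatusCode::TOO_MANY_REQUESTS));
assert!(should_retry_status(StatusCode::BAD_GATEWAY));
assert!(!should_retry_status(StatusCode::BAD_REQUEST));
assert!(!should_retry_status(StatusCode::UNAUTHORIZED));
}
#[test]
fn extract_api_error_message_prefers_nested_then_flat_message() {
assert_eq!(
extract_api_error_message(r#"{"error":{"message":"quota exceeded"}}"#, "fallback"),
"quota exceeded"
);
assert_eq!(
extract_api_error_message(r#"{"message":"bad model"}"#, "fallback"),
"bad model"
);
assert_eq!(extract_api_error_message("   ", "fallback"), "fallback");
}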
#[tokio::test]
async fn stream_text_accumulates_sse_response() {
let server_url = spawn_mock_server(vec![MockResponse {
status_line: "200 OK",
content_type: "text/event-stream; charset=utf-8",
body: concat!(
"data: {\"choices\":[{\"delta\":{\"content\":\"\"}}]}\n\n",
"data: {\"choices\":[{\"delta\":{\"content\":\"\"}}]}\n\n",
"data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n",
"data: [DONE]\n\n"
)
.to_string(),
extra_headers: vec![("x-request-id", "req_stream_01")],
}]);
let client = build_test_client(server_url, 0);
let mut updates = Vec::new();
let response = client
.stream_single_message_text("系统", "用户", |delta| {
updates.push(delta.accumulated_text.clone());
})
.await
.expect("stream_text should succeed");
assert_eq!(updates, vec!["你".to_string(), "你好".to_string()]);
assert_eq!(response.content, "你好");
assert_eq!(response.finish_reason.as_deref(), Some("stop"));
assert_eq!(response.response_id.as_deref(), Some("req_stream_01"));
}
fn build_test_client(base_url: String, max_retries: u32) -> LlmClient {
let config = LlmConfig::new(
LlmProvider::Ark,
base_url,
"test-key".to_string(),
"test-model".to_string(),
DEFAULT_REQUEST_TIMEOUT_MS,
max_retries,
1,
)
.expect("config should be valid");
LlmClient::new(config).expect("client should be created")
}
fn spawn_mock_server(responses: Vec<MockResponse>) -> String {
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
let address = listener.local_addr().expect("listener should have addr");
thread::spawn(move || {
for response in responses {
let (mut stream, _) = listener.accept().expect("request should connect");
read_request(&mut stream);
write_response(&mut stream, response);
}
});
format!("http://{address}")
}
fn read_request(stream: &mut std::net::TcpStream) {
stream
.set_read_timeout(Some(StdDuration::from_secs(1)))
.expect("read timeout should be set");
let mut buffer = Vec::new();
let mut chunk = [0_u8; 1024];
let mut expected_total = None;
loop {
match stream.read(&mut chunk) {
Ok(0) => break,
Ok(bytes_read) => {
buffer.extend_from_slice(&chunk[..bytes_read]);
if expected_total.is_none()
&& let Some(header_end) = find_header_end(&buffer)
{
let content_length =
read_content_length(&buffer[..header_end]).unwrap_or(0);
expected_total = Some(header_end + content_length);
}
if let Some(total_bytes) = expected_total
&& buffer.len() >= total_bytes
{
break;
}
}
Err(error)
if error.kind() == std::io::ErrorKind::WouldBlock
|| error.kind() == std::io::ErrorKind::TimedOut =>
{
break;
}
Err(error) => panic!("mock server failed to read request: {error}"),
}
}
}
fn write_response(stream: &mut std::net::TcpStream, response: MockResponse) {
let body = response.body;
let mut raw_response = format!(
"HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n",
response.status_line,
response.content_type,
body.len()
);
for (name, value) in response.extra_headers {
raw_response.push_str(format!("{name}: {value}\r\n").as_str());
}
raw_response.push_str("\r\n");
raw_response.push_str(body.as_str());
stream
.write_all(raw_response.as_bytes())
.expect("mock response should be written");
stream.flush().expect("mock response should flush");
}
fn find_header_end(buffer: &[u8]) -> Option<usize> {
buffer
.windows(4)
.position(|window| window == b"\r\n\r\n")
.map(|index| index + 4)
}
fn read_content_length(headers: &[u8]) -> Option<usize> {
let text = String::from_utf8_lossy(headers);
text.lines().find_map(|line| {
let (name, value) = line.split_once(':')?;
if name.eq_ignore_ascii_case("content-length") {
return value.trim().parse::<usize>().ok();
}
None
})
}
}