// Genarrative/server-rs/crates/platform-llm/src/lib.rs
use std::{
env,
error::Error,
fmt, fs,
path::PathBuf,
str as std_str,
sync::atomic::{AtomicU64, Ordering},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use log::{debug, warn};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::time::sleep;
pub const DEFAULT_ARK_BASE_URL: &str = "https://ark.cn-beijing.volces.com/api/v3";
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 30_000;
pub const DEFAULT_MAX_RETRIES: u32 = 1;
pub const DEFAULT_RETRY_BACKOFF_MS: u64 = 500;
pub const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";
pub const RESPONSES_PATH: &str = "/responses";
const DEFAULT_LLM_RAW_LOG_DIR: &str = "logs/llm-raw";
static LLM_RAW_LOG_SEQUENCE: AtomicU64 = AtomicU64::new(1);
// Freeze the platform providers into an enum so upper layers stop scattering raw provider strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmProvider {
Ark,
DashScope,
OpenAiCompatible,
}
// Centralize the text-model gateway configuration so api-server and business modules do not each re-parse environment variables.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmConfig {
provider: LlmProvider,
base_url: String,
api_key: String,
model: String,
request_timeout_ms: u64,
max_retries: u32,
retry_backoff_ms: u64,
official_fallback: bool,
}
// The first version freezes only the system/user/assistant message roles already in stable use by this project.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmMessageRole {
System,
User,
Assistant,
}
// Each message keeps the OpenAI-compatible shape so the unified request body can serialize it directly.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmMessage {
pub role: LlmMessageRole,
// Keep the plain-text field for Chat Completions compatibility and existing callers; Responses multimodal requests read content_parts instead.
pub content: String,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub content_parts: Vec<LlmMessageContentPart>,
}
// Responses multimodal content parts. Field names stay snake_case to match the upstream OpenAI-compatible protocol.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LlmMessageContentPart {
InputText { text: String },
InputImage { image_url: String },
}
// A text completion request is frozen to the minimal closed loop of "message list + optional model override + optional max_tokens".
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextRequest {
pub model: Option<String>,
pub messages: Vec<LlmMessage>,
pub max_tokens: Option<u32>,
pub enable_web_search: bool,
pub protocol: LlmTextProtocol,
pub request_timeout_ms: Option<u64>,
}
// The text protocol must be chosen explicitly by each business request, so a global default model cannot funnel different scenarios into the same upstream shape.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LlmTextProtocol {
ChatCompletions,
Responses,
}
// Streaming consumers receive "accumulated text + current delta", so no layer has to re-assemble the text itself.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmStreamDelta {
pub accumulated_text: String,
pub delta_text: String,
pub finish_reason: Option<String>,
}
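// Example (illustrative): for deltas "你" then "好", consumers observe
// accumulated_text "你" then "你好" alongside each delta, with no
// re-joining on their side.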
// Preserves token counts; downstream modules decide whether to feed them into auditing or cost accounting.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmTokenUsage {
pub prompt_tokens: u64,
pub completion_tokens: u64,
pub total_tokens: u64,
}
// Unified text response, so business layers never have to parse choices/message/content themselves.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextResponse {
pub provider: LlmProvider,
pub model: String,
pub content: String,
pub finish_reason: Option<String>,
pub response_id: Option<String>,
pub usage: Option<LlmTokenUsage>,
}
// Normalize upstream errors into a stable domain enum that api-server can map directly onto its HTTP error contract.
#[derive(Debug, PartialEq, Eq)]
pub enum LlmError {
InvalidConfig(String),
InvalidRequest(String),
Timeout { attempts: u32 },
Connectivity { attempts: u32, message: String },
Upstream { status_code: u16, message: String },
StreamUnavailable,
EmptyResponse,
Transport(String),
Deserialize(String),
}
// The platform layer exposes only stable error categories; HTTP status codes and user-facing copy are mapped later by api-server.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LlmErrorKind {
InvalidConfig,
InvalidRequest,
Timeout,
Connectivity,
Upstream,
StreamUnavailable,
EmptyResponse,
Transport,
Deserialize,
}
// Unified client for OpenAI-compatible text gateways.
#[derive(Clone, Debug)]
pub struct LlmClient {
config: LlmConfig,
http_client: Client,
}
#[derive(Serialize)]
#[serde(untagged)]
enum LlmRequestBody {
ChatCompletions(ChatCompletionsRequestBody),
Responses(ResponsesRequestBody),
}
#[derive(Serialize)]
struct ChatCompletionsRequestBody {
model: String,
messages: Vec<ChatCompletionsInputMessage>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
official_fallback: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
web_search_options: Option<ChatCompletionsWebSearchOptions>,
}
#[derive(Serialize)]
struct ChatCompletionsWebSearchOptions {}
#[derive(Serialize)]
struct ChatCompletionsInputMessage {
role: &'static str,
content: ChatCompletionsInputContent,
}
#[derive(Serialize)]
#[serde(untagged)]
enum ChatCompletionsInputContent {
Text(String),
Parts(Vec<ChatCompletionsInputContentPart>),
}
#[derive(Serialize)]
#[serde(tag = "type")]
enum ChatCompletionsInputContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ChatCompletionsImageUrl },
}
#[derive(Serialize)]
struct ChatCompletionsImageUrl {
url: String,
}
#[derive(Serialize)]
struct ResponsesRequestBody {
model: String,
stream: bool,
input: Vec<ResponsesInputMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
official_fallback: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
max_output_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ResponsesWebSearchTool>>,
}
#[derive(Serialize)]
struct ResponsesInputMessage {
role: &'static str,
content: Vec<ResponsesInputContentPart>,
}
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ResponsesInputContentPart {
InputText { text: String },
InputImage { image_url: String },
}
#[derive(Serialize)]
struct ResponsesWebSearchTool {
#[serde(rename = "type")]
tool_type: &'static str,
max_keyword: u8,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct LlmRawFailureInputLog<'a> {
provider: &'static str,
protocol: &'static str,
model: &'a str,
stream: bool,
attempt: u32,
max_tokens: Option<u32>,
messages: &'a [LlmMessage],
}
#[derive(Deserialize)]
#[serde(untagged)]
enum ChatCompletionsResponsePayload {
Direct(ChatCompletionsResponseEnvelope),
Wrapped {
data: ChatCompletionsResponseEnvelope,
},
}
#[derive(Deserialize)]
struct ChatCompletionsResponseEnvelope {
id: Option<String>,
model: Option<String>,
choices: Vec<ChatCompletionsChoice>,
usage: Option<LlmTokenUsage>,
}
#[derive(Deserialize)]
struct ChatCompletionsChoice {
#[serde(default)]
message: Option<ChatCompletionsMessage>,
#[serde(default)]
delta: Option<ChatCompletionsMessage>,
#[serde(default)]
finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct ChatCompletionsMessage {
#[serde(default)]
content: Option<ChatCompletionsContent>,
}
#[derive(Deserialize)]
#[serde(untagged)]
enum ChatCompletionsContent {
Text(String),
Parts(Vec<ChatCompletionsContentPart>),
}
#[derive(Deserialize)]
struct ChatCompletionsContentPart {
#[serde(rename = "type")]
#[allow(dead_code)]
part_type: Option<String>,
#[serde(default)]
text: Option<String>,
}
#[derive(Deserialize)]
struct ResponsesResponseEnvelope {
id: Option<String>,
model: Option<String>,
#[serde(default)]
output_text: Option<String>,
#[serde(default)]
output: Vec<ResponsesOutputItem>,
#[serde(default)]
status: Option<String>,
usage: Option<ResponsesUsage>,
}
#[derive(Deserialize)]
struct ResponsesOutputItem {
#[serde(default)]
content: Vec<ResponsesOutputContentPart>,
}
#[derive(Deserialize)]
struct ResponsesOutputContentPart {
#[serde(rename = "type")]
#[allow(dead_code)]
part_type: Option<String>,
#[serde(default)]
text: Option<String>,
}
#[derive(Deserialize)]
struct ResponsesUsage {
#[serde(default)]
input_tokens: u64,
#[serde(default)]
output_tokens: u64,
#[serde(default)]
total_tokens: u64,
}
struct OpenAiCompatibleSseParser {
buffer: String,
raw_text: String,
protocol: LlmTextProtocol,
}
#[derive(Debug)]
struct ParsedStreamEvent {
delta_text: Option<String>,
finish_reason: Option<String>,
}
impl LlmProvider {
pub fn as_str(&self) -> &'static str {
match self {
Self::Ark => "ark",
Self::DashScope => "dash_scope",
Self::OpenAiCompatible => "openai_compatible",
}
}
}
impl LlmConfig {
#[allow(clippy::too_many_arguments)]
pub fn new(
provider: LlmProvider,
base_url: String,
api_key: String,
model: String,
request_timeout_ms: u64,
max_retries: u32,
retry_backoff_ms: u64,
) -> Result<Self, LlmError> {
let base_url = normalize_non_empty(base_url, "LLM base_url 不能为空")?;
let api_key = normalize_non_empty(api_key, "LLM api_key 不能为空")?;
let model = normalize_non_empty(model, "LLM model 不能为空")?;
if request_timeout_ms == 0 {
return Err(LlmError::InvalidConfig(
"LLM request_timeout_ms 必须大于 0".to_string(),
));
}
Ok(Self {
provider,
base_url,
api_key,
model,
request_timeout_ms,
max_retries,
retry_backoff_ms,
official_fallback: false,
})
}
pub fn with_official_fallback(mut self, official_fallback: bool) -> Self {
self.official_fallback = official_fallback;
self
}
pub fn ark_default(api_key: String, model: String) -> Result<Self, LlmError> {
Self::new(
LlmProvider::Ark,
DEFAULT_ARK_BASE_URL.to_string(),
api_key,
model,
DEFAULT_REQUEST_TIMEOUT_MS,
DEFAULT_MAX_RETRIES,
DEFAULT_RETRY_BACKOFF_MS,
)
}
pub fn provider(&self) -> LlmProvider {
self.provider
}
pub fn base_url(&self) -> &str {
&self.base_url
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub fn model(&self) -> &str {
&self.model
}
pub fn request_timeout_ms(&self) -> u64 {
self.request_timeout_ms
}
pub fn max_retries(&self) -> u32 {
self.max_retries
}
pub fn retry_backoff_ms(&self) -> u64 {
self.retry_backoff_ms
}
pub fn official_fallback(&self) -> bool {
self.official_fallback
}
pub fn chat_completions_url(&self) -> String {
format!(
"{}/{}",
self.base_url.trim_end_matches('/'),
CHAT_COMPLETIONS_PATH.trim_start_matches('/')
)
}
pub fn responses_url(&self) -> String {
format!(
"{}/{}",
self.base_url.trim_end_matches('/'),
RESPONSES_PATH.trim_start_matches('/')
)
}
}
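// Example (illustrative): a base_url of "https://example.com/base///" joins to
// "https://example.com/base/chat/completions" and
// "https://example.com/base/responses"; trailing slashes never double up.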
impl LlmMessage {
pub fn new(role: LlmMessageRole, content: impl Into<String>) -> Self {
Self {
role,
content: content.into(),
content_parts: Vec::new(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new(LlmMessageRole::System, content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new(LlmMessageRole::User, content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new(LlmMessageRole::Assistant, content)
}
pub fn multimodal(role: LlmMessageRole, content_parts: Vec<LlmMessageContentPart>) -> Self {
let content = content_parts
.iter()
.filter_map(|part| match part {
LlmMessageContentPart::InputText { text } => Some(text.as_str()),
LlmMessageContentPart::InputImage { .. } => None,
})
.collect::<Vec<_>>()
.join("\n");
Self {
role,
content,
content_parts,
}
}
pub fn user_multimodal(content_parts: Vec<LlmMessageContentPart>) -> Self {
Self::multimodal(LlmMessageRole::User, content_parts)
}
pub fn with_image_url(mut self, image_url: impl Into<String>) -> Self {
if self.content_parts.is_empty() && !self.content.trim().is_empty() {
self.content_parts.push(LlmMessageContentPart::InputText {
text: self.content.clone(),
});
}
self.content_parts.push(LlmMessageContentPart::InputImage {
image_url: image_url.into(),
});
self
}
}
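// Usage sketch (illustrative): `with_image_url` first promotes existing plain
// text into an InputText part, then appends the image, so the plain-text and
// structured projections stay in sync.
//
//     let message = LlmMessage::user("描述这张图")
//         .with_image_url("https://example.com/ref.png");
//     // message.content_parts == [InputText { .. }, InputImage { .. }]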
impl LlmTextRequest {
pub fn new(messages: Vec<LlmMessage>) -> Self {
Self {
model: None,
messages,
max_tokens: None,
enable_web_search: false,
protocol: LlmTextProtocol::ChatCompletions,
request_timeout_ms: None,
}
}
pub fn single_turn(system_prompt: impl Into<String>, user_prompt: impl Into<String>) -> Self {
Self::new(vec![
LlmMessage::system(system_prompt),
LlmMessage::user(user_prompt),
])
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_web_search(mut self, enabled: bool) -> Self {
self.enable_web_search = enabled;
self
}
pub fn with_responses_api(mut self) -> Self {
self.protocol = LlmTextProtocol::Responses;
self
}
pub fn with_request_timeout_ms(mut self, request_timeout_ms: u64) -> Self {
self.request_timeout_ms = Some(request_timeout_ms);
self
}
fn validate(&self) -> Result<(), LlmError> {
if self.messages.is_empty() {
return Err(LlmError::InvalidRequest(
"LLM messages 不能为空".to_string(),
));
}
for message in &self.messages {
let has_text = !message.content.trim().is_empty()
|| message.content_parts.iter().any(|part| match part {
LlmMessageContentPart::InputText { text } => !text.trim().is_empty(),
LlmMessageContentPart::InputImage { .. } => false,
});
let has_image = message.content_parts.iter().any(|part| match part {
LlmMessageContentPart::InputImage { image_url } => !image_url.trim().is_empty(),
LlmMessageContentPart::InputText { .. } => false,
});
if !has_text && !has_image {
return Err(LlmError::InvalidRequest(
"LLM message content 不能为空".to_string(),
));
}
if message.content_parts.iter().any(|part| match part {
LlmMessageContentPart::InputText { text } => text.trim().is_empty(),
LlmMessageContentPart::InputImage { image_url } => image_url.trim().is_empty(),
}) {
return Err(LlmError::InvalidRequest(
"LLM message content part 不能为空".to_string(),
));
}
}
if let Some(model) = &self.model
&& model.trim().is_empty()
{
return Err(LlmError::InvalidRequest(
"LLM request.model 不能为空字符串".to_string(),
));
}
if let Some(request_timeout_ms) = self.request_timeout_ms
&& request_timeout_ms == 0
{
return Err(LlmError::InvalidRequest(
"LLM request_timeout_ms 必须大于 0".to_string(),
));
}
Ok(())
}
fn resolved_model<'a>(&'a self, fallback_model: &'a str) -> &'a str {
self.model
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.unwrap_or(fallback_model)
}
fn resolved_request_timeout_ms(&self, fallback_timeout_ms: u64) -> u64 {
self.request_timeout_ms
.filter(|value| *value > 0)
.unwrap_or(fallback_timeout_ms)
}
}
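// Builder sketch (illustrative; prompts and limits are placeholders): a
// Responses-protocol request with web search and a per-request timeout override.
//
//     let request = LlmTextRequest::single_turn("system prompt", "user prompt")
//         .with_responses_api()
//         .with_web_search(true)
//         .with_max_tokens(256)
//         .with_request_timeout_ms(15_000);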
impl LlmTextProtocol {
fn as_str(self) -> &'static str {
match self {
Self::ChatCompletions => "chat_completions",
Self::Responses => "responses",
}
}
}
impl fmt::Display for LlmError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidConfig(message)
| Self::InvalidRequest(message)
| Self::Transport(message)
| Self::Deserialize(message) => write!(f, "{message}"),
Self::Timeout { attempts } => {
write!(f, "LLM 请求超时,累计尝试 {attempts} 次")
}
Self::Connectivity { attempts, message } => {
write!(f, "LLM 连接失败,累计尝试 {attempts} 次:{message}")
}
Self::Upstream {
status_code,
message,
} => write!(f, "LLM 上游返回 {status_code}{message}"),
Self::StreamUnavailable => write!(f, "LLM 流式响应体不可用"),
Self::EmptyResponse => write!(f, "LLM 返回内容为空"),
}
}
}
impl Error for LlmError {}
impl LlmError {
pub fn kind(&self) -> LlmErrorKind {
match self {
Self::InvalidConfig(_) => LlmErrorKind::InvalidConfig,
Self::InvalidRequest(_) => LlmErrorKind::InvalidRequest,
Self::Timeout { .. } => LlmErrorKind::Timeout,
Self::Connectivity { .. } => LlmErrorKind::Connectivity,
Self::Upstream { .. } => LlmErrorKind::Upstream,
Self::StreamUnavailable => LlmErrorKind::StreamUnavailable,
Self::EmptyResponse => LlmErrorKind::EmptyResponse,
Self::Transport(_) => LlmErrorKind::Transport,
Self::Deserialize(_) => LlmErrorKind::Deserialize,
}
}
}
impl LlmClient {
pub fn new(config: LlmConfig) -> Result<Self, LlmError> {
let http_client = Client::builder().build().map_err(|error| {
LlmError::InvalidConfig(format!("构建 reqwest client 失败:{error}"))
})?;
Ok(Self {
config,
http_client,
})
}
pub fn config(&self) -> &LlmConfig {
&self.config
}
pub async fn request_text(&self, request: LlmTextRequest) -> Result<LlmTextResponse, LlmError> {
request.validate()?;
let resolved_model = request.resolved_model(self.config.model()).to_string();
let response = self.execute_request(&request, false).await?;
let raw_text = response.text().await.map_err(|error| {
let llm_error = map_stream_read_error(error, 1);
log_llm_raw_failure(
&self.config,
&request,
false,
1,
"read_response_failed",
llm_error.to_string().as_str(),
);
llm_error
})?;
parse_text_response(
request.protocol,
self.config.provider(),
&resolved_model,
raw_text.as_str(),
)
.map_err(|error| {
log_llm_raw_failure(
&self.config,
&request,
false,
1,
"parse_response_failed",
raw_text.as_str(),
);
error
})
}
pub async fn request_single_message_text(
&self,
system_prompt: impl Into<String>,
user_prompt: impl Into<String>,
) -> Result<LlmTextResponse, LlmError> {
self.request_text(LlmTextRequest::single_turn(system_prompt, user_prompt))
.await
}
pub async fn stream_text<F>(
&self,
request: LlmTextRequest,
mut on_delta: F,
) -> Result<LlmTextResponse, LlmError>
where
F: FnMut(&LlmStreamDelta),
{
request.validate()?;
let resolved_model = request.resolved_model(self.config.model()).to_string();
let mut response = self.execute_request(&request, true).await?;
let response_id = response
.headers()
.get("x-request-id")
.and_then(|value| value.to_str().ok())
.map(str::to_string);
let mut parser = OpenAiCompatibleSseParser::new(request.protocol);
let mut accumulated_text = String::new();
let mut finish_reason = None;
let mut undecoded_chunk_bytes = Vec::new();
loop {
let next_chunk = response.chunk().await.map_err(|error| {
let llm_error = map_stream_read_error(error, 1);
log_llm_raw_failure(
&self.config,
&request,
true,
1,
"read_stream_failed",
parser.raw_text().as_str(),
);
llm_error
})?;
let Some(chunk) = next_chunk else {
break;
};
undecoded_chunk_bytes.extend_from_slice(chunk.as_ref());
let (chunk_text, remaining_bytes) =
decode_utf8_stream_chunk(undecoded_chunk_bytes.as_slice()).map_err(|error| {
log_llm_raw_failure(
&self.config,
&request,
true,
1,
"decode_stream_failed",
parser.raw_text().as_str(),
);
error
})?;
undecoded_chunk_bytes = remaining_bytes;
if chunk_text.is_empty() {
continue;
}
let stream_events = parser.push_chunk(chunk_text.as_ref()).map_err(|error| {
log_llm_raw_failure(
&self.config,
&request,
true,
1,
"parse_stream_failed",
parser.raw_text().as_str(),
);
error
})?;
for event in stream_events {
if let Some(delta_text) = event.delta_text
&& !delta_text.is_empty()
{
accumulated_text.push_str(delta_text.as_str());
let update = LlmStreamDelta {
accumulated_text: accumulated_text.clone(),
delta_text,
finish_reason: event.finish_reason.clone(),
};
on_delta(&update);
}
if event.finish_reason.is_some() {
finish_reason = event.finish_reason;
}
}
}
if !undecoded_chunk_bytes.is_empty() {
let trailing_text =
std_str::from_utf8(undecoded_chunk_bytes.as_slice()).map_err(|error| {
log_llm_raw_failure(
&self.config,
&request,
true,
1,
"decode_stream_failed",
parser.raw_text().as_str(),
);
LlmError::Deserialize(format!("解析 LLM 流式 UTF-8 响应失败:{error}"))
})?;
if !trailing_text.is_empty() {
let trailing_events = parser.push_chunk(trailing_text).map_err(|error| {
log_llm_raw_failure(
&self.config,
&request,
true,
1,
"parse_stream_failed",
parser.raw_text().as_str(),
);
error
})?;
for event in trailing_events {
if let Some(delta_text) = event.delta_text
&& !delta_text.is_empty()
{
accumulated_text.push_str(delta_text.as_str());
let update = LlmStreamDelta {
accumulated_text: accumulated_text.clone(),
delta_text,
finish_reason: event.finish_reason.clone(),
};
on_delta(&update);
}
if event.finish_reason.is_some() {
finish_reason = event.finish_reason;
}
}
}
}
let remaining_events = parser.finish().map_err(|error| {
log_llm_raw_failure(
&self.config,
&request,
true,
1,
"parse_stream_failed",
parser.raw_text().as_str(),
);
error
})?;
for event in remaining_events {
if let Some(delta_text) = event.delta_text
&& !delta_text.is_empty()
{
accumulated_text.push_str(delta_text.as_str());
let update = LlmStreamDelta {
accumulated_text: accumulated_text.clone(),
delta_text,
finish_reason: event.finish_reason.clone(),
};
on_delta(&update);
}
if event.finish_reason.is_some() {
finish_reason = event.finish_reason;
}
}
let content = accumulated_text.trim().to_string();
if content.is_empty() {
log_llm_raw_failure(
&self.config,
&request,
true,
1,
"empty_stream_response",
parser.raw_text().as_str(),
);
return Err(LlmError::EmptyResponse);
}
Ok(LlmTextResponse {
provider: self.config.provider(),
model: resolved_model,
content,
finish_reason,
response_id,
usage: None,
})
}
pub async fn stream_single_message_text<F>(
&self,
system_prompt: impl Into<String>,
user_prompt: impl Into<String>,
on_delta: F,
) -> Result<LlmTextResponse, LlmError>
where
F: FnMut(&LlmStreamDelta),
{
self.stream_text(
LlmTextRequest::single_turn(system_prompt, user_prompt),
on_delta,
)
.await
}
async fn execute_request(
&self,
request: &LlmTextRequest,
stream: bool,
) -> Result<reqwest::Response, LlmError> {
let request_body = build_request_body(request, &self.config, stream);
let model = request.resolved_model(self.config.model());
let url = match request.protocol {
LlmTextProtocol::ChatCompletions => self.config.chat_completions_url(),
LlmTextProtocol::Responses => self.config.responses_url(),
};
let max_attempts = self.config.max_retries().saturating_add(1);
for attempt in 1..=max_attempts {
debug!(
"platform-llm request started: provider={}, protocol={}, stream={}, attempt={}, model={}",
self.config.provider().as_str(),
request.protocol.as_str(),
stream,
attempt,
model
);
let send_result = self
.http_client
.post(url.as_str())
.bearer_auth(self.config.api_key())
.json(&request_body)
.timeout(Duration::from_millis(
request.resolved_request_timeout_ms(self.config.request_timeout_ms()),
))
.send()
.await;
match send_result {
Ok(response) if response.status().is_success() => {
debug!(
"platform-llm request succeeded: provider={}, protocol={}, stream={}, attempt={}, status={}",
self.config.provider().as_str(),
request.protocol.as_str(),
stream,
attempt,
response.status().as_u16()
);
return Ok(response);
}
Ok(response) => {
let status = response.status();
let raw_text = response.text().await.unwrap_or_default();
let message = extract_api_error_message(&raw_text, "LLM 上游请求失败");
if should_retry_status(status) && attempt < max_attempts {
warn!(
"platform-llm request retrying after upstream status: provider={}, protocol={}, attempt={}, status={}, message={}",
self.config.provider().as_str(),
request.protocol.as_str(),
attempt,
status.as_u16(),
message
);
self.sleep_before_retry(attempt).await;
continue;
}
log_llm_raw_failure(
&self.config,
request,
stream,
attempt,
"upstream_status_failed",
raw_text.as_str(),
);
return Err(LlmError::Upstream {
status_code: status.as_u16(),
message,
});
}
Err(error) if error.is_timeout() => {
if attempt < max_attempts {
warn!(
"platform-llm request retrying after timeout: provider={}, protocol={}, attempt={}",
self.config.provider().as_str(),
request.protocol.as_str(),
attempt
);
self.sleep_before_retry(attempt).await;
continue;
}
let error = LlmError::Timeout { attempts: attempt };
log_llm_raw_failure(
&self.config,
request,
stream,
attempt,
"request_timeout",
error.to_string().as_str(),
);
return Err(error);
}
Err(error) if error.is_connect() => {
let message = error.to_string();
if attempt < max_attempts {
warn!(
"platform-llm request retrying after connectivity failure: provider={}, protocol={}, attempt={}, error={}",
self.config.provider().as_str(),
request.protocol.as_str(),
attempt,
message
);
self.sleep_before_retry(attempt).await;
continue;
}
let error = LlmError::Connectivity {
attempts: attempt,
message,
};
log_llm_raw_failure(
&self.config,
request,
stream,
attempt,
"request_connectivity_failed",
error.to_string().as_str(),
);
return Err(error);
}
Err(error) => {
let error = LlmError::Transport(error.to_string());
log_llm_raw_failure(
&self.config,
request,
stream,
attempt,
"request_transport_failed",
error.to_string().as_str(),
);
return Err(error);
}
}
}
Err(LlmError::Transport(
"LLM 请求在重试循环后仍未返回结果".to_string(),
))
}
async fn sleep_before_retry(&self, attempt: u32) {
let backoff_ms = self
.config
.retry_backoff_ms()
.saturating_mul(u64::from(attempt));
if backoff_ms > 0 {
sleep(Duration::from_millis(backoff_ms)).await;
}
}
}
impl OpenAiCompatibleSseParser {
fn new(protocol: LlmTextProtocol) -> Self {
Self {
buffer: String::new(),
raw_text: String::new(),
protocol,
}
}
fn push_chunk(&mut self, chunk: &str) -> Result<Vec<ParsedStreamEvent>, LlmError> {
self.raw_text.push_str(chunk);
self.buffer.push_str(chunk);
self.buffer = self.buffer.replace("\r\n", "\n");
self.drain_complete_events()
}
fn raw_text(&self) -> String {
self.raw_text.clone()
}
fn finish(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
if self.buffer.trim().is_empty() {
return Ok(Vec::new());
}
self.buffer.push_str("\n\n");
self.drain_complete_events()
}
fn drain_complete_events(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
let mut events = Vec::new();
while let Some(boundary) = self.buffer.find("\n\n") {
let block = self.buffer[..boundary].to_string();
self.buffer = self.buffer[(boundary + 2)..].to_string();
if let Some(event) = parse_sse_event_block(self.protocol, block.as_str())? {
events.push(event);
}
}
Ok(events)
}
}
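// SSE framing recap (wire form illustrative): events are blocks separated by a
// blank line; multiple `data:` lines inside one block are re-joined with '\n'
// before JSON parsing, and `[DONE]` terminates the stream.
//
//     data: {"choices":[{"delta":{"content":"Hi"}}]}
//
//     data: [DONE]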
fn normalize_non_empty(value: String, error_message: &str) -> Result<String, LlmError> {
let trimmed = value.trim().to_string();
if trimmed.is_empty() {
return Err(LlmError::InvalidConfig(error_message.to_string()));
}
Ok(trimmed)
}
fn build_request_body(
request: &LlmTextRequest,
config: &LlmConfig,
stream: bool,
) -> LlmRequestBody {
let fallback_model = config.model();
let official_fallback = config.official_fallback().then_some(true);
match request.protocol {
LlmTextProtocol::ChatCompletions => {
LlmRequestBody::ChatCompletions(ChatCompletionsRequestBody {
model: request.resolved_model(fallback_model).to_string(),
messages: map_chat_completions_input_messages(request.messages.as_slice()),
stream,
official_fallback,
max_tokens: request.max_tokens,
web_search_options: request
.enable_web_search
.then_some(ChatCompletionsWebSearchOptions {}),
})
}
LlmTextProtocol::Responses => LlmRequestBody::Responses(ResponsesRequestBody {
model: request.resolved_model(fallback_model).to_string(),
stream,
input: map_responses_input_messages(request.messages.as_slice()),
official_fallback,
max_output_tokens: request.max_tokens,
tools: request.enable_web_search.then(|| {
vec![ResponsesWebSearchTool {
tool_type: "web_search",
max_keyword: 3,
}]
}),
}),
}
}
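// For reference, a non-stream Chat Completions body with web search enabled
// serializes roughly as follows (values illustrative; None fields are omitted):
//
//     {"model":"model-a","messages":[{"role":"user","content":"hi"}],
//      "stream":false,"web_search_options":{}}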
fn map_chat_completions_input_messages(
messages: &[LlmMessage],
) -> Vec<ChatCompletionsInputMessage> {
messages
.iter()
.map(|message| ChatCompletionsInputMessage {
role: map_llm_message_role(message.role),
content: map_chat_completions_content(message),
})
.collect()
}
fn map_chat_completions_content(message: &LlmMessage) -> ChatCompletionsInputContent {
if message.content_parts.is_empty() {
return ChatCompletionsInputContent::Text(message.content.clone());
}
ChatCompletionsInputContent::Parts(
message
.content_parts
.iter()
.map(|part| match part {
LlmMessageContentPart::InputText { text } => {
ChatCompletionsInputContentPart::Text { text: text.clone() }
}
LlmMessageContentPart::InputImage { image_url } => {
ChatCompletionsInputContentPart::ImageUrl {
image_url: ChatCompletionsImageUrl {
url: image_url.clone(),
},
}
}
})
.collect(),
)
}
fn map_responses_input_messages(messages: &[LlmMessage]) -> Vec<ResponsesInputMessage> {
messages
.iter()
.map(|message| ResponsesInputMessage {
role: map_llm_message_role(message.role),
content: map_responses_content_parts(message),
})
.collect()
}
fn map_llm_message_role(role: LlmMessageRole) -> &'static str {
match role {
LlmMessageRole::System => "system",
LlmMessageRole::User => "user",
LlmMessageRole::Assistant => "assistant",
}
}
fn map_responses_content_parts(message: &LlmMessage) -> Vec<ResponsesInputContentPart> {
if message.content_parts.is_empty() {
return vec![ResponsesInputContentPart::InputText {
text: message.content.clone(),
}];
}
message
.content_parts
.iter()
.map(|part| match part {
LlmMessageContentPart::InputText { text } => {
ResponsesInputContentPart::InputText { text: text.clone() }
}
LlmMessageContentPart::InputImage { image_url } => {
ResponsesInputContentPart::InputImage {
image_url: image_url.clone(),
}
}
})
.collect()
}
fn log_llm_raw_failure(
config: &LlmConfig,
request: &LlmTextRequest,
stream: bool,
attempt: u32,
failure_stage: &str,
raw_output: &str,
) {
if let Err(error) =
write_llm_raw_failure(config, request, stream, attempt, failure_stage, raw_output)
{
warn!(
"LLM 失败原文日志落盘失败,主错误流程继续执行: failure_stage={}, error={}",
failure_stage, error
);
}
}
fn write_llm_raw_failure(
config: &LlmConfig,
request: &LlmTextRequest,
stream: bool,
attempt: u32,
failure_stage: &str,
raw_output: &str,
) -> Result<(), String> {
let log_dir = env::var("LLM_RAW_LOG_DIR")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from(DEFAULT_LLM_RAW_LOG_DIR));
fs::create_dir_all(&log_dir).map_err(|error| format!("创建日志目录失败:{error}"))?;
let prefix = build_llm_raw_log_prefix(failure_stage);
let model = request.resolved_model(config.model());
let input_log = LlmRawFailureInputLog {
provider: config.provider().as_str(),
protocol: request.protocol.as_str(),
model,
stream,
attempt,
max_tokens: request.max_tokens,
messages: request.messages.as_slice(),
};
let input_text = serde_json::to_string_pretty(&input_log)
.map_err(|error| format!("序列化模型输入日志失败:{error}"))?;
fs::write(log_dir.join(format!("{prefix}.input.json")), input_text)
.map_err(|error| format!("写入模型输入日志失败:{error}"))?;
fs::write(log_dir.join(format!("{prefix}.output.txt")), raw_output)
.map_err(|error| format!("写入模型输出日志失败:{error}"))?;
Ok(())
}
fn build_llm_raw_log_prefix(failure_stage: &str) -> String {
let millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.as_millis())
.unwrap_or_default();
let sequence = LLM_RAW_LOG_SEQUENCE.fetch_add(1, Ordering::Relaxed);
let safe_stage = sanitize_log_file_segment(failure_stage);
format!("{millis}-{}-{sequence:06}-{safe_stage}", std::process::id())
}
fn sanitize_log_file_segment(value: &str) -> String {
let sanitized = value
.chars()
.map(|character| {
if character.is_ascii_alphanumeric() || character == '-' || character == '_' {
character
} else {
'_'
}
})
.collect::<String>();
if sanitized.is_empty() {
"unknown".to_string()
} else {
sanitized
}
}
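// Resulting log file names look like this (timestamp, pid, and sequence
// illustrative):
//
//     1715000000000-12345-000001-parse_response_failed.input.json
//     1715000000000-12345-000001-parse_response_failed.output.txt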
fn parse_text_response(
protocol: LlmTextProtocol,
provider: LlmProvider,
fallback_model: &str,
raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
match protocol {
LlmTextProtocol::ChatCompletions => {
parse_chat_completions_response(provider, fallback_model, raw_text)
}
LlmTextProtocol::Responses => parse_responses_response(provider, fallback_model, raw_text),
}
}
fn parse_chat_completions_response(
provider: LlmProvider,
fallback_model: &str,
raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
let parsed: ChatCompletionsResponsePayload = serde_json::from_str(raw_text)
.map_err(|error| LlmError::Deserialize(format!("解析 LLM JSON 响应失败:{error}")))?;
let parsed = match parsed {
ChatCompletionsResponsePayload::Direct(envelope) => envelope,
ChatCompletionsResponsePayload::Wrapped { data } => data,
};
let first_choice = parsed
.choices
.first()
.ok_or_else(|| LlmError::Deserialize("LLM 响应缺少 choices[0]".to_string()))?;
let content = extract_message_text(first_choice)
.ok_or(LlmError::EmptyResponse)?
.trim()
.to_string();
if content.is_empty() {
return Err(LlmError::EmptyResponse);
}
Ok(LlmTextResponse {
provider,
model: parsed.model.unwrap_or_else(|| fallback_model.to_string()),
content,
finish_reason: first_choice.finish_reason.clone(),
response_id: parsed.id,
usage: parsed.usage,
})
}
fn parse_responses_response(
provider: LlmProvider,
fallback_model: &str,
raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
let parsed: ResponsesResponseEnvelope = serde_json::from_str(raw_text).map_err(|error| {
LlmError::Deserialize(format!("解析 LLM Responses JSON 响应失败:{error}"))
})?;
let content = extract_responses_text(&parsed)
.ok_or(LlmError::EmptyResponse)?
.trim()
.to_string();
if content.is_empty() {
return Err(LlmError::EmptyResponse);
}
Ok(LlmTextResponse {
provider,
model: parsed.model.unwrap_or_else(|| fallback_model.to_string()),
content,
finish_reason: parsed.status,
response_id: parsed.id,
usage: parsed.usage.map(|usage| LlmTokenUsage {
prompt_tokens: usage.input_tokens,
completion_tokens: usage.output_tokens,
total_tokens: usage.total_tokens,
}),
})
}
fn extract_responses_text(parsed: &ResponsesResponseEnvelope) -> Option<String> {
parsed
.output_text
.as_deref()
.map(str::to_string)
.filter(|text| !text.is_empty())
.or_else(|| {
let text = parsed
.output
.iter()
.flat_map(|item| item.content.iter())
.filter_map(|part| part.text.as_deref())
.collect::<Vec<_>>()
.join("");
if text.is_empty() { None } else { Some(text) }
})
}
fn extract_message_text(choice: &ChatCompletionsChoice) -> Option<String> {
choice
.message
.as_ref()
.and_then(|message| message.content.as_ref())
.and_then(extract_content_text)
.or_else(|| {
choice
.delta
.as_ref()
.and_then(|message| message.content.as_ref())
.and_then(extract_content_text)
})
}
fn extract_content_text(content: &ChatCompletionsContent) -> Option<String> {
match content {
ChatCompletionsContent::Text(text) => Some(text.clone()),
ChatCompletionsContent::Parts(parts) => {
let text = parts
.iter()
.filter_map(|part| part.text.as_deref())
.collect::<Vec<_>>()
.join("");
if text.is_empty() { None } else { Some(text) }
}
}
}
fn decode_utf8_stream_chunk(bytes: &[u8]) -> Result<(String, Vec<u8>), LlmError> {
match std_str::from_utf8(bytes) {
Ok(text) => Ok((text.to_string(), Vec::new())),
Err(error) => {
let valid_up_to = error.valid_up_to();
let Some(_) = error.error_len() else {
let decoded = std_str::from_utf8(&bytes[..valid_up_to]).map_err(|inner_error| {
LlmError::Deserialize(format!("解析 LLM 流式 UTF-8 响应失败:{inner_error}"))
})?;
return Ok((decoded.to_string(), bytes[valid_up_to..].to_vec()));
};
Err(LlmError::Deserialize(format!(
"解析 LLM 流式 UTF-8 响应失败:{error}"
)))
}
}
}
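// Example: "你" is E4 BD A0 in UTF-8. Feeding [E4, BD] yields ("", [E4, BD])
// because the suffix is an incomplete sequence; appending A0 and calling again
// yields ("你", []).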
fn parse_sse_event_block(
protocol: LlmTextProtocol,
block: &str,
) -> Result<Option<ParsedStreamEvent>, LlmError> {
let data_lines = block
.lines()
.filter_map(|line| line.trim().strip_prefix("data:"))
.map(str::trim_start)
.collect::<Vec<_>>();
if data_lines.is_empty() {
return Ok(None);
}
let data = data_lines.join("\n");
if data.trim().is_empty() || data.trim() == "[DONE]" {
return Ok(None);
}
if protocol == LlmTextProtocol::Responses {
return parse_responses_sse_event(data.as_str());
}
let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(data.as_str())
.map_err(|error| LlmError::Deserialize(format!("解析 LLM SSE 事件失败:{error}")))?;
let first_choice = parsed
.choices
.first()
.ok_or_else(|| LlmError::Deserialize("LLM SSE 响应缺少 choices[0]".to_string()))?;
Ok(Some(ParsedStreamEvent {
delta_text: extract_message_text(first_choice),
finish_reason: first_choice.finish_reason.clone(),
}))
}
fn parse_responses_sse_event(data: &str) -> Result<Option<ParsedStreamEvent>, LlmError> {
let parsed: serde_json::Value = serde_json::from_str(data).map_err(|error| {
LlmError::Deserialize(format!("解析 LLM Responses SSE 事件失败:{error}"))
})?;
let event_type = parsed
.get("type")
.and_then(serde_json::Value::as_str)
.unwrap_or_default();
match event_type {
"response.output_text.delta" => Ok(Some(ParsedStreamEvent {
delta_text: parsed
.get("delta")
.and_then(serde_json::Value::as_str)
.map(str::to_string),
finish_reason: None,
})),
"response.completed" => Ok(Some(ParsedStreamEvent {
delta_text: None,
finish_reason: Some("completed".to_string()),
})),
"response.failed" | "error" => {
let message = parsed
.get("error")
.and_then(|error| error.get("message"))
.and_then(serde_json::Value::as_str)
.or_else(|| parsed.get("message").and_then(serde_json::Value::as_str))
.unwrap_or("LLM Responses SSE 返回失败事件")
.to_string();
Err(LlmError::Upstream {
status_code: 502,
message,
})
}
_ => Ok(None),
}
}
fn should_retry_status(status: StatusCode) -> bool {
status == StatusCode::REQUEST_TIMEOUT
|| status == StatusCode::TOO_MANY_REQUESTS
|| status.is_server_error()
}
fn extract_api_error_message(raw_text: &str, fallback_message: &str) -> String {
let trimmed = raw_text.trim();
if trimmed.is_empty() {
return fallback_message.to_string();
}
let parsed = serde_json::from_str::<serde_json::Value>(trimmed);
if let Ok(value) = parsed {
if let Some(message) = value
.get("error")
.and_then(|error| error.get("message"))
.and_then(serde_json::Value::as_str)
.map(str::trim)
.filter(|message| !message.is_empty())
{
return message.to_string();
}
if let Some(message) = value
.get("message")
.and_then(serde_json::Value::as_str)
.map(str::trim)
.filter(|message| !message.is_empty())
{
return message.to_string();
}
}
trimmed.to_string()
}
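// Example: both {"error":{"message":"rate limited"}} and
// {"message":"rate limited"} resolve to "rate limited"; any other non-empty
// body falls back to its trimmed text.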
fn map_stream_read_error(error: reqwest::Error, attempts: u32) -> LlmError {
if error.is_timeout() {
return LlmError::Timeout { attempts };
}
if error.is_connect() {
return LlmError::Connectivity {
attempts,
message: error.to_string(),
};
}
LlmError::Transport(error.to_string())
}
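// End-to-end sketch (illustrative; `api_key` is a placeholder String and `?`
// assumes a caller error type convertible from LlmError):
//
//     let config = LlmConfig::ark_default(api_key, "model-a".to_string())?;
//     let client = LlmClient::new(config)?;
//     let response = client
//         .request_text(LlmTextRequest::single_turn("system", "user"))
//         .await?;
//     println!("{}", response.content);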
#[cfg(test)]
mod tests {
use std::{
io::{Read, Write},
net::TcpListener,
thread,
time::Duration as StdDuration,
};
use super::*;
#[test]
fn llm_error_kind_is_stable_for_adapter_mapping() {
assert_eq!(
LlmError::InvalidConfig("bad config".to_string()).kind(),
LlmErrorKind::InvalidConfig
);
assert_eq!(
LlmError::Upstream {
status_code: 429,
message: "too many requests".to_string(),
}
.kind(),
LlmErrorKind::Upstream
);
assert_eq!(LlmError::EmptyResponse.kind(), LlmErrorKind::EmptyResponse);
}
struct MockResponse {
status_line: &'static str,
content_type: &'static str,
body: String,
extra_headers: Vec<(&'static str, &'static str)>,
}
#[test]
fn llm_config_rejects_blank_api_key() {
let error = LlmConfig::new(
LlmProvider::Ark,
DEFAULT_ARK_BASE_URL.to_string(),
" ".to_string(),
"model-a".to_string(),
DEFAULT_REQUEST_TIMEOUT_MS,
DEFAULT_MAX_RETRIES,
DEFAULT_RETRY_BACKOFF_MS,
)
.expect_err("blank api key should be rejected");
assert_eq!(
error,
LlmError::InvalidConfig("LLM api_key 不能为空".to_string())
);
}
#[test]
fn llm_chat_completion_url_normalizes_trailing_slash() {
let config = LlmConfig::new(
LlmProvider::OpenAiCompatible,
"https://example.com/base///".to_string(),
"secret".to_string(),
"model-a".to_string(),
DEFAULT_REQUEST_TIMEOUT_MS,
DEFAULT_MAX_RETRIES,
DEFAULT_RETRY_BACKOFF_MS,
)
.expect("config should be valid");
assert_eq!(
config.chat_completions_url(),
"https://example.com/base/chat/completions"
);
assert_eq!(config.responses_url(), "https://example.com/base/responses");
}
#[test]
fn llm_config_official_fallback_is_opt_in() {
let config = LlmConfig::new(
LlmProvider::OpenAiCompatible,
"https://example.com/base".to_string(),
"secret".to_string(),
"model-a".to_string(),
DEFAULT_REQUEST_TIMEOUT_MS,
DEFAULT_MAX_RETRIES,
DEFAULT_RETRY_BACKOFF_MS,
)
.expect("config should be valid");
assert!(!config.official_fallback());
assert!(config.with_official_fallback(true).official_fallback());
}
#[tokio::test]
async fn request_text_sends_official_fallback_for_openai_compatible_clients() {
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
let address = listener.local_addr().expect("listener should have addr");
let server_handle = thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("request should connect");
let request_text = read_request(&mut stream);
write_response(
&mut stream,
MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: r#"{"id":"resp_openai_compatible","model":"gpt-5","output_text":"","status":"completed"}"#.to_string(),
extra_headers: Vec::new(),
},
);
request_text
});
let config = LlmConfig::new(
LlmProvider::OpenAiCompatible,
format!("http://{address}"),
"test-key".to_string(),
"gpt-5".to_string(),
DEFAULT_REQUEST_TIMEOUT_MS,
0,
1,
)
.expect("config should be valid")
.with_official_fallback(true);
let client = LlmClient::new(config).expect("client should be created");
let response = client
.request_text(LlmTextRequest::single_turn("系统", "用户").with_responses_api())
.await
.expect("request_text should succeed");
let request_text = server_handle.join().expect("server thread should join");
let request_body = request_text
.split("\r\n\r\n")
.nth(1)
.expect("request body should exist");
let request_json: serde_json::Value =
serde_json::from_str(request_body).expect("request body should be json");
assert_eq!(response.content, "兼容成功");
assert_eq!(request_json["official_fallback"], serde_json::json!(true));
}
#[test]
fn sse_parser_handles_split_chunks_and_done_marker() {
let mut parser = OpenAiCompatibleSseParser::new(LlmTextProtocol::ChatCompletions);
let events_a = parser
.push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"\"}}]}\r\n\r\n")
.expect("first chunk should parse");
let events_b = parser
.push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"\"},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n")
.expect("second chunk should parse");
assert_eq!(events_a.len(), 1);
assert_eq!(events_a[0].delta_text.as_deref(), Some("你"));
assert_eq!(events_b.len(), 1);
assert_eq!(events_b[0].delta_text.as_deref(), Some("好"));
assert_eq!(events_b[0].finish_reason.as_deref(), Some("stop"));
}
#[test]
fn responses_sse_parser_only_emits_output_text_delta() {
let mut parser = OpenAiCompatibleSseParser::new(LlmTextProtocol::Responses);
let events = parser
.push_chunk(concat!(
"data: {\"type\":\"response.created\"}\n\n",
"data: {\"type\":\"response.output_text.delta\",\"delta\":\"\"}\n\n",
"data: {\"type\":\"response.output_text.delta\",\"delta\":\"\"}\n\n",
"data: {\"type\":\"response.completed\"}\n\n",
))
.expect("responses stream should parse");
assert_eq!(events.len(), 3);
assert_eq!(events[0].delta_text.as_deref(), Some("你"));
assert_eq!(events[1].delta_text.as_deref(), Some("好"));
assert_eq!(events[2].finish_reason.as_deref(), Some("completed"));
}
#[test]
fn decode_utf8_stream_chunk_preserves_incomplete_multibyte_suffix() {
let full_bytes = "你好".as_bytes();
let first_result = decode_utf8_stream_chunk(&full_bytes[..2])
.expect("incomplete utf-8 chunk should be buffered");
assert_eq!(first_result.0, "");
assert_eq!(first_result.1, full_bytes[..2].to_vec());
let mut combined = first_result.1;
combined.extend_from_slice(&full_bytes[2..]);
let second_result = decode_utf8_stream_chunk(combined.as_slice())
.expect("completed utf-8 bytes should decode");
assert_eq!(second_result.0, "你好");
assert!(second_result.1.is_empty());
}
#[tokio::test]
async fn request_text_parses_non_stream_response() {
let server_url = spawn_mock_server(vec![MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: r#"{"id":"resp_01","model":"ark-test-model","choices":[{"message":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":6,"total_tokens":16}}"#.to_string(),
extra_headers: Vec::new(),
}]);
let client = build_test_client(server_url, 0);
let response = client
.request_single_message_text("系统", "用户")
.await
.expect("request_text should succeed");
assert_eq!(response.provider, LlmProvider::Ark);
assert_eq!(response.model, "ark-test-model");
assert_eq!(response.content, "测试成功");
assert_eq!(response.finish_reason.as_deref(), Some("stop"));
assert_eq!(response.response_id.as_deref(), Some("resp_01"));
assert_eq!(
response.usage,
Some(LlmTokenUsage {
prompt_tokens: 10,
completion_tokens: 6,
total_tokens: 16,
})
);
}
#[tokio::test]
async fn request_text_retries_after_upstream_500() {
let server_url = spawn_mock_server(vec![
MockResponse {
status_line: "500 Internal Server Error",
content_type: "application/json; charset=utf-8",
body: r#"{"error":{"message":"temporary upstream failure"}}"#.to_string(),
extra_headers: Vec::new(),
},
MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: r#"{"id":"resp_retry","choices":[{"message":{"content":""},"finish_reason":"stop"}]}"#.to_string(),
extra_headers: Vec::new(),
},
]);
let client = build_test_client(server_url, 1);
let response = client
.request_single_message_text("系统", "用户")
.await
.expect("second attempt should succeed");
assert_eq!(response.content, "第二次成功");
assert_eq!(response.response_id.as_deref(), Some("resp_retry"));
}
#[tokio::test]
async fn request_text_uses_request_level_timeout_override() {
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
let address = listener.local_addr().expect("listener should have addr");
thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("request should connect");
let _ = read_request(&mut stream);
thread::sleep(StdDuration::from_millis(200));
write_response(
&mut stream,
MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body:
r#"{"choices":[{"message":{"content":"too late"},"finish_reason":"stop"}]}"#
.to_string(),
extra_headers: Vec::new(),
},
);
});
let config = LlmConfig::new(
LlmProvider::Ark,
format!("http://{address}"),
"test-key".to_string(),
"test-model".to_string(),
10_000,
0,
1,
)
.expect("config should be valid");
let client = LlmClient::new(config).expect("client should be created");
let error = client
.request_text(LlmTextRequest::single_turn("系统", "用户").with_request_timeout_ms(20))
.await
.expect_err("request override should timeout before the global timeout");
assert_eq!(error, LlmError::Timeout { attempts: 1 });
}
#[tokio::test]
async fn request_text_sends_web_search_options_when_enabled() {
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
let address = listener.local_addr().expect("listener should have addr");
let server_handle = thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("request should connect");
let request_text = read_request(&mut stream);
write_response(
&mut stream,
MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: r#"{"id":"resp_search","model":"test-model","choices":[{"message":{"content":""},"finish_reason":"stop"}]}"#.to_string(),
extra_headers: Vec::new(),
},
);
request_text
});
let client = build_test_client(format!("http://{address}"), 0);
let response = client
.request_text(
LlmTextRequest::single_turn("系统", "用户")
.with_web_search(true)
.with_max_tokens(128),
)
.await
.expect("request_text should succeed");
let request_text = server_handle.join().expect("server thread should join");
let request_body = request_text
.split("\r\n\r\n")
.nth(1)
.expect("request body should exist");
let request_json: serde_json::Value =
serde_json::from_str(request_body).expect("request body should be json");
assert_eq!(response.content, "搜索成功");
assert_eq!(request_json["web_search_options"], serde_json::json!({}));
assert!(request_json.get("official_fallback").is_none());
}
#[tokio::test]
async fn chat_completions_multimodal_request_sends_text_and_image_url_parts() {
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
let address = listener.local_addr().expect("listener should have addr");
let server_handle = thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("request should connect");
let request_text = read_request(&mut stream);
write_response(
&mut stream,
MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: r#"{"id":"chat_multimodal","model":"gpt-4o-mini","choices":[{"message":{"content":"{\"levelName\":\"雨夜猫街\"}"},"finish_reason":"stop"}]}"#.to_string(),
extra_headers: Vec::new(),
},
);
request_text
});
let config = LlmConfig::new(
LlmProvider::OpenAiCompatible,
format!("http://{address}"),
"test-key".to_string(),
"gpt-4o-mini".to_string(),
DEFAULT_REQUEST_TIMEOUT_MS,
0,
1,
)
.expect("config should be valid")
.with_official_fallback(true);
let client = LlmClient::new(config).expect("client should be created");
let response = client
.request_text(LlmTextRequest::new(vec![
LlmMessage::system("你是拼图关卡命名编辑"),
LlmMessage::user_multimodal(vec![
LlmMessageContentPart::InputText {
text: "画面描述:一只猫在雨夜灯牌下回头。".to_string(),
},
LlmMessageContentPart::InputImage {
image_url: "data:image/png;base64,abcd".to_string(),
},
]),
]))
.await
.expect("request_text should succeed");
let request_text = server_handle.join().expect("server thread should join");
let request_line = request_text.lines().next().unwrap_or_default();
let request_body = request_text
.split("\r\n\r\n")
.nth(1)
.expect("request body should exist");
let request_json: serde_json::Value =
serde_json::from_str(request_body).expect("request body should be json");
assert!(request_line.contains("POST /chat/completions HTTP/1.1"));
assert_eq!(response.model, "gpt-4o-mini");
assert_eq!(response.content, r#"{"levelName":"雨夜猫街"}"#);
assert_eq!(request_json["official_fallback"], serde_json::json!(true));
assert_eq!(
request_json["messages"][1]["content"],
serde_json::json!([
{ "type": "text", "text": "画面描述:一只猫在雨夜灯牌下回头。" },
{ "type": "image_url", "image_url": { "url": "data:image/png;base64,abcd" } }
])
);
}
#[tokio::test]
async fn request_text_sends_responses_body_with_web_search_tool() {
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
let address = listener.local_addr().expect("listener should have addr");
let server_handle = thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("request should connect");
let request_text = read_request(&mut stream);
write_response(
&mut stream,
MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: r#"{"id":"resp_responses","model":"deepseek-v3-2-251201","output_text":"Responses ","status":"completed","usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13}}"#.to_string(),
extra_headers: Vec::new(),
},
);
request_text
});
let client = build_test_client(format!("http://{address}"), 0);
let response = client
.request_text(
LlmTextRequest::single_turn("系统", "用户")
.with_model("deepseek-v3-2-251201")
.with_responses_api()
.with_web_search(true)
.with_max_tokens(128),
)
.await
.expect("responses request_text should succeed");
let request_text = server_handle.join().expect("server thread should join");
let request_line = request_text.lines().next().unwrap_or_default();
let request_body = request_text
.split("\r\n\r\n")
.nth(1)
.expect("request body should exist");
let request_json: serde_json::Value =
serde_json::from_str(request_body).expect("request body should be json");
assert!(request_line.contains("POST /responses HTTP/1.1"));
assert_eq!(response.content, "Responses 成功");
assert_eq!(response.model, "deepseek-v3-2-251201");
assert_eq!(
response.usage,
Some(LlmTokenUsage {
prompt_tokens: 9,
completion_tokens: 4,
total_tokens: 13,
})
);
assert_eq!(
request_json["model"],
serde_json::json!("deepseek-v3-2-251201")
);
assert_eq!(request_json["stream"], serde_json::json!(false));
assert_eq!(
request_json["tools"],
serde_json::json!([{ "type": "web_search", "max_keyword": 3 }])
);
assert!(request_json.get("official_fallback").is_none());
assert_eq!(
request_json["input"][0]["content"][0],
serde_json::json!({ "type": "input_text", "text": "系统" })
);
}
#[tokio::test]
async fn responses_multimodal_request_sends_input_text_and_input_image() {
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
let address = listener.local_addr().expect("listener should have addr");
let server_handle = thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("request should connect");
let request_text = read_request(&mut stream);
write_response(
&mut stream,
MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: r#"{"id":"resp_multimodal","model":"gpt-5","output_text":"","status":"completed"}"#.to_string(),
extra_headers: Vec::new(),
},
);
request_text
});
let client = build_test_client(format!("http://{address}"), 0);
let response = client
.request_text(
LlmTextRequest::new(vec![
LlmMessage::system("你是创意互动内容生成 Agent"),
LlmMessage::user_multimodal(vec![
LlmMessageContentPart::InputText {
text: "把这张图做成拼图".to_string(),
},
LlmMessageContentPart::InputImage {
image_url: "https://example.com/ref.png".to_string(),
},
]),
])
.with_model("gpt-5")
.with_responses_api(),
)
.await
.expect("responses multimodal request_text should succeed");
let request_text = server_handle.join().expect("server thread should join");
let request_body = request_text
.split("\r\n\r\n")
.nth(1)
.expect("request body should exist");
let request_json: serde_json::Value =
serde_json::from_str(request_body).expect("request body should be json");
assert_eq!(response.model, "gpt-5");
assert_eq!(request_json["model"], serde_json::json!("gpt-5"));
assert!(request_json.get("official_fallback").is_none());
assert_eq!(
request_json["input"][1]["content"],
serde_json::json!([
{ "type": "input_text", "text": "把这张图做成拼图" },
{ "type": "input_image", "image_url": "https://example.com/ref.png" }
])
);
}
#[tokio::test]
async fn stream_text_accumulates_sse_response() {
let server_url = spawn_mock_server(vec![MockResponse {
status_line: "200 OK",
content_type: "text/event-stream; charset=utf-8",
body: concat!(
"data: {\"choices\":[{\"delta\":{\"content\":\"\"}}]}\n\n",
"data: {\"choices\":[{\"delta\":{\"content\":\"\"}}]}\n\n",
"data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n",
"data: [DONE]\n\n"
)
.to_string(),
extra_headers: vec![("x-request-id", "req_stream_01")],
}]);
let client = build_test_client(server_url, 0);
let mut updates = Vec::new();
let response = client
.stream_single_message_text("系统", "用户", |delta| {
updates.push(delta.accumulated_text.clone());
})
.await
.expect("stream_text should succeed");
assert_eq!(updates, vec!["".to_string(), "你好".to_string()]);
assert_eq!(response.content, "你好");
assert_eq!(response.finish_reason.as_deref(), Some("stop"));
assert_eq!(response.response_id.as_deref(), Some("req_stream_01"));
}
#[tokio::test]
async fn stream_text_accumulates_responses_sse_response() {
let server_url = spawn_mock_server(vec![MockResponse {
status_line: "200 OK",
content_type: "text/event-stream; charset=utf-8",
body: concat!(
"data: {\"type\":\"response.output_text.delta\",\"delta\":\"\"}\n\n",
"data: {\"type\":\"response.output_text.delta\",\"delta\":\"\"}\n\n",
"data: {\"type\":\"response.completed\"}\n\n"
)
.to_string(),
extra_headers: vec![("x-request-id", "req_responses_stream_01")],
}]);
let client = build_test_client(server_url, 0);
let mut updates = Vec::new();
let response = client
.stream_text(
LlmTextRequest::single_turn("系统", "用户").with_responses_api(),
|delta| {
updates.push(delta.accumulated_text.clone());
},
)
.await
.expect("responses stream_text should succeed");
assert_eq!(updates, vec!["".to_string(), "你好".to_string()]);
assert_eq!(response.content, "你好");
assert_eq!(response.finish_reason.as_deref(), Some("completed"));
assert_eq!(
response.response_id.as_deref(),
Some("req_responses_stream_01")
);
}
#[tokio::test]
async fn request_text_writes_raw_failure_logs_after_parse_error() {
let log_dir = std::env::temp_dir().join(format!(
"platform-llm-raw-log-test-{}",
build_llm_raw_log_prefix("parse_error")
));
unsafe {
std::env::set_var("LLM_RAW_LOG_DIR", &log_dir);
}
let server_url = spawn_mock_server(vec![MockResponse {
status_line: "200 OK",
content_type: "application/json; charset=utf-8",
body: "不是合法 JSON".to_string(),
extra_headers: Vec::new(),
}]);
let client = build_test_client(server_url, 0);
let error = client
.request_single_message_text("系统原文", "用户原文")
.await
.expect_err("invalid json should fail");
assert!(matches!(error, LlmError::Deserialize(_)));
let mut input_logs = Vec::new();
let mut output_logs = Vec::new();
for entry in fs::read_dir(&log_dir).expect("log dir should exist") {
let path = entry.expect("log entry should be readable").path();
let file_name = path
.file_name()
.and_then(|name| name.to_str())
.unwrap_or_default()
.to_string();
if file_name.ends_with(".input.json") {
input_logs.push(path);
} else if file_name.ends_with(".output.txt") {
output_logs.push(path);
}
}
assert_eq!(input_logs.len(), 1);
assert_eq!(output_logs.len(), 1);
let input_text = fs::read_to_string(&input_logs[0]).expect("input log should be readable");
let output_text =
fs::read_to_string(&output_logs[0]).expect("output log should be readable");
assert!(input_text.contains("系统原文"));
assert!(input_text.contains("用户原文"));
assert!(!input_text.contains("test-key"));
assert_eq!(output_text, "不是合法 JSON");
unsafe {
std::env::remove_var("LLM_RAW_LOG_DIR");
}
fs::remove_dir_all(log_dir).expect("log dir should be removed");
}
fn build_test_client(base_url: String, max_retries: u32) -> LlmClient {
let config = LlmConfig::new(
LlmProvider::Ark,
base_url,
"test-key".to_string(),
"test-model".to_string(),
DEFAULT_REQUEST_TIMEOUT_MS,
max_retries,
1,
)
.expect("config should be valid");
LlmClient::new(config).expect("client should be created")
}
fn spawn_mock_server(responses: Vec<MockResponse>) -> String {
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
let address = listener.local_addr().expect("listener should have addr");
thread::spawn(move || {
for response in responses {
let (mut stream, _) = listener.accept().expect("request should connect");
read_request(&mut stream);
write_response(&mut stream, response);
}
});
format!("http://{address}")
}
fn read_request(stream: &mut std::net::TcpStream) -> String {
stream
.set_read_timeout(Some(StdDuration::from_secs(1)))
.expect("read timeout should be set");
let mut buffer = Vec::new();
let mut chunk = [0_u8; 1024];
let mut expected_total = None;
loop {
match stream.read(&mut chunk) {
Ok(0) => break,
Ok(bytes_read) => {
buffer.extend_from_slice(&chunk[..bytes_read]);
if expected_total.is_none()
&& let Some(header_end) = find_header_end(&buffer)
{
let content_length =
read_content_length(&buffer[..header_end]).unwrap_or(0);
expected_total = Some(header_end + content_length);
}
if let Some(total_bytes) = expected_total
&& buffer.len() >= total_bytes
{
break;
}
}
Err(error)
if error.kind() == std::io::ErrorKind::WouldBlock
|| error.kind() == std::io::ErrorKind::TimedOut =>
{
break;
}
Err(error) => panic!("mock server failed to read request: {error}"),
}
}
String::from_utf8_lossy(buffer.as_slice()).to_string()
}
fn write_response(stream: &mut std::net::TcpStream, response: MockResponse) {
let body = response.body;
let mut raw_response = format!(
"HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n",
response.status_line,
response.content_type,
body.len()
);
for (name, value) in response.extra_headers {
raw_response.push_str(format!("{name}: {value}\r\n").as_str());
}
raw_response.push_str("\r\n");
raw_response.push_str(body.as_str());
stream
.write_all(raw_response.as_bytes())
.expect("mock response should be written");
stream.flush().expect("mock response should flush");
}
fn find_header_end(buffer: &[u8]) -> Option<usize> {
buffer
.windows(4)
.position(|window| window == b"\r\n\r\n")
.map(|index| index + 4)
}
fn read_content_length(headers: &[u8]) -> Option<usize> {
let text = String::from_utf8_lossy(headers);
text.lines().find_map(|line| {
let (name, value) = line.split_once(':')?;
if name.eq_ignore_ascii_case("content-length") {
return value.trim().parse::<usize>().ok();
}
None
})
}
}