use std::{
    env,
    error::Error,
    fmt, fs,
    path::PathBuf,
    str as std_str,
    sync::atomic::{AtomicU64, Ordering},
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use log::{debug, warn};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::time::sleep;

pub const DEFAULT_ARK_BASE_URL: &str = "https://ark.cn-beijing.volces.com/api/v3";
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 30_000;
pub const DEFAULT_MAX_RETRIES: u32 = 1;
pub const DEFAULT_RETRY_BACKOFF_MS: u64 = 500;
pub const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";
pub const RESPONSES_PATH: &str = "/responses";
const DEFAULT_LLM_RAW_LOG_DIR: &str = "logs/llm-raw";

static LLM_RAW_LOG_SEQUENCE: AtomicU64 = AtomicU64::new(1);

// Freeze the set of providers so upper layers stop scattering raw provider strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmProvider {
    Ark,
    DashScope,
    OpenAiCompatible,
}

// Single choke point for text-model gateway configuration, so api-server and
// business modules do not each re-parse environment variables.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmConfig {
    provider: LlmProvider,
    base_url: String,
    api_key: String,
    model: String,
    request_timeout_ms: u64,
    max_retries: u32,
    retry_backoff_ms: u64,
    official_fallback: bool,
}

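// Illustrative sketch (not called anywhere in this module) of how a caller might
// assemble a config; the model id below is a made-up placeholder.
#[allow(dead_code)]
fn example_build_config(api_key: String) -> Result<LlmConfig, LlmError> {
    // `ark_default` fills in the Ark base URL and the default timeout/retry knobs;
    // official fallback stays opt-in via the builder.
    Ok(LlmConfig::ark_default(api_key, "example-model".to_string())?.with_official_fallback(true))
}
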
// The first version freezes only the three message roles this project already
// uses reliably: system/user/assistant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmMessageRole {
    System,
    User,
    Assistant,
}

// A single message keeps the OpenAI-compatible shape so the unified request body
// can serialize it directly.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmMessage {
    pub role: LlmMessageRole,
    // Keep the plain-text field for Chat Completions and existing callers;
    // multimodal Responses requests read content_parts instead.
    pub content: String,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub content_parts: Vec<LlmMessageContentPart>,
}

// Multimodal content part for the Responses API. Field names stay snake_case to
// match the upstream OpenAI-compatible protocol.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LlmMessageContentPart {
    InputText { text: String },
    InputImage { image_url: String },
}

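// Illustrative sketch: building a multimodal user message. The data URL is a
// placeholder; `with_image_url` lifts existing plain text into content_parts first.
#[allow(dead_code)]
fn example_multimodal_message() -> LlmMessage {
    LlmMessage::user("Describe this image.").with_image_url("data:image/png;base64,AAAA")
}
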
// A text completion request is frozen to the minimal loop of "message list +
// optional model override + optional max_tokens".
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextRequest {
    pub model: Option<String>,
    pub messages: Vec<LlmMessage>,
    pub max_tokens: Option<u32>,
    pub enable_web_search: bool,
    pub protocol: LlmTextProtocol,
    pub request_timeout_ms: Option<u64>,
}

// The text protocol must be chosen explicitly by each business request, so a
// global default model does not funnel different scenarios into one upstream shape.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LlmTextProtocol {
    ChatCompletions,
    Responses,
}

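// Illustrative sketch: a Responses-protocol request with web search and a
// per-request timeout override. All concrete values are placeholders.
#[allow(dead_code)]
fn example_responses_request() -> LlmTextRequest {
    LlmTextRequest::single_turn("You are a helpful assistant.", "Summarize today's changes.")
        .with_model("example-model")
        .with_responses_api()
        .with_web_search(true)
        .with_max_tokens(512)
        .with_request_timeout_ms(15_000)
}
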
// Streaming consumers receive "accumulated text + current delta", so no layer
// has to re-concatenate fragments on its own.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmStreamDelta {
    pub accumulated_text: String,
    pub delta_text: String,
    pub finish_reason: Option<String>,
}

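// Illustrative sketch: consuming a stream. The callback gets both the running
// total and the newest fragment, so forwarding `delta_text` alone is enough.
#[allow(dead_code)]
async fn example_stream(client: &LlmClient, request: LlmTextRequest) -> Result<String, LlmError> {
    let response = client
        .stream_text(request, |delta| {
            // `delta.accumulated_text` is also available when the full text so far is needed.
            print!("{}", delta.delta_text);
        })
        .await?;
    Ok(response.content)
}
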
// Preserves token counts; downstream modules can decide whether to feed them
// into auditing or cost accounting.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmTokenUsage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

// Unified text response, so business layers never parse choices/message/content
// themselves.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextResponse {
    pub provider: LlmProvider,
    pub model: String,
    pub content: String,
    pub finish_reason: Option<String>,
    pub response_id: Option<String>,
    pub usage: Option<LlmTokenUsage>,
}

// Normalize upstream errors into a stable domain enum; api-server can map it
// straight onto its HTTP error contract.
#[derive(Debug, PartialEq, Eq)]
pub enum LlmError {
    InvalidConfig(String),
    InvalidRequest(String),
    Timeout { attempts: u32 },
    Connectivity { attempts: u32, message: String },
    Upstream { status_code: u16, message: String },
    StreamUnavailable,
    EmptyResponse,
    Transport(String),
    Deserialize(String),
}

// The platform layer exposes only stable error categories; HTTP status codes and
// user-facing copy are mapped later by api-server.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LlmErrorKind {
    InvalidConfig,
    InvalidRequest,
    Timeout,
    Connectivity,
    Upstream,
    StreamUnavailable,
    EmptyResponse,
    Transport,
    Deserialize,
}

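// Illustrative sketch of the HTTP mapping api-server might apply. These status
// choices are assumptions for demonstration, not part of this crate's contract.
#[allow(dead_code)]
fn example_http_status_for(kind: LlmErrorKind) -> u16 {
    match kind {
        LlmErrorKind::InvalidRequest => 400,
        LlmErrorKind::InvalidConfig => 500,
        LlmErrorKind::Timeout => 504,
        _ => 502,
    }
}
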
// Unified OpenAI-compatible text gateway client.
#[derive(Clone, Debug)]
pub struct LlmClient {
    config: LlmConfig,
    http_client: Client,
}

#[derive(Serialize)]
#[serde(untagged)]
enum LlmRequestBody {
    ChatCompletions(ChatCompletionsRequestBody),
    Responses(ResponsesRequestBody),
}

#[derive(Serialize)]
struct ChatCompletionsRequestBody {
    model: String,
    messages: Vec<ChatCompletionsInputMessage>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    official_fallback: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<ChatCompletionsWebSearchOptions>,
}

#[derive(Serialize)]
struct ChatCompletionsWebSearchOptions {}

#[derive(Serialize)]
struct ChatCompletionsInputMessage {
    role: &'static str,
    content: ChatCompletionsInputContent,
}

#[derive(Serialize)]
#[serde(untagged)]
enum ChatCompletionsInputContent {
    Text(String),
    Parts(Vec<ChatCompletionsInputContentPart>),
}

#[derive(Serialize)]
#[serde(tag = "type")]
enum ChatCompletionsInputContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ChatCompletionsImageUrl },
}

#[derive(Serialize)]
struct ChatCompletionsImageUrl {
    url: String,
}

#[derive(Serialize)]
struct ResponsesRequestBody {
    model: String,
    stream: bool,
    input: Vec<ResponsesInputMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    official_fallback: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ResponsesWebSearchTool>>,
}

#[derive(Serialize)]
struct ResponsesInputMessage {
    role: &'static str,
    content: Vec<ResponsesInputContentPart>,
}

#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ResponsesInputContentPart {
    InputText { text: String },
    InputImage { image_url: String },
}

#[derive(Serialize)]
struct ResponsesWebSearchTool {
    #[serde(rename = "type")]
    tool_type: &'static str,
    max_keyword: u8,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct LlmRawFailureInputLog<'a> {
    provider: &'static str,
    protocol: &'static str,
    model: &'a str,
    stream: bool,
    attempt: u32,
    max_tokens: Option<u32>,
    messages: &'a [LlmMessage],
}

#[derive(Deserialize)]
#[serde(untagged)]
enum ChatCompletionsResponsePayload {
    Direct(ChatCompletionsResponseEnvelope),
    Wrapped {
        data: ChatCompletionsResponseEnvelope,
    },
}

#[derive(Deserialize)]
struct ChatCompletionsResponseEnvelope {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<ChatCompletionsChoice>,
    usage: Option<LlmTokenUsage>,
}

#[derive(Deserialize)]
struct ChatCompletionsChoice {
    #[serde(default)]
    message: Option<ChatCompletionsMessage>,
    #[serde(default)]
    delta: Option<ChatCompletionsMessage>,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct ChatCompletionsMessage {
    #[serde(default)]
    content: Option<ChatCompletionsContent>,
}

#[derive(Deserialize)]
#[serde(untagged)]
enum ChatCompletionsContent {
    Text(String),
    Parts(Vec<ChatCompletionsContentPart>),
}

#[derive(Deserialize)]
struct ChatCompletionsContentPart {
    #[serde(rename = "type")]
    #[allow(dead_code)]
    part_type: Option<String>,
    #[serde(default)]
    text: Option<String>,
}

#[derive(Deserialize)]
struct ResponsesResponseEnvelope {
    id: Option<String>,
    model: Option<String>,
    #[serde(default)]
    output_text: Option<String>,
    #[serde(default)]
    output: Vec<ResponsesOutputItem>,
    #[serde(default)]
    status: Option<String>,
    usage: Option<ResponsesUsage>,
}

#[derive(Deserialize)]
struct ResponsesOutputItem {
    #[serde(default)]
    content: Vec<ResponsesOutputContentPart>,
}

#[derive(Deserialize)]
struct ResponsesOutputContentPart {
    #[serde(rename = "type")]
    #[allow(dead_code)]
    part_type: Option<String>,
    #[serde(default)]
    text: Option<String>,
}

#[derive(Deserialize)]
struct ResponsesUsage {
    #[serde(default)]
    input_tokens: u64,
    #[serde(default)]
    output_tokens: u64,
    #[serde(default)]
    total_tokens: u64,
}

struct OpenAiCompatibleSseParser {
    buffer: String,
    raw_text: String,
    protocol: LlmTextProtocol,
}

#[derive(Debug)]
struct ParsedStreamEvent {
    delta_text: Option<String>,
    finish_reason: Option<String>,
}

impl LlmProvider {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Ark => "ark",
            Self::DashScope => "dash_scope",
            Self::OpenAiCompatible => "openai_compatible",
        }
    }
}

impl LlmConfig {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        provider: LlmProvider,
        base_url: String,
        api_key: String,
        model: String,
        request_timeout_ms: u64,
        max_retries: u32,
        retry_backoff_ms: u64,
    ) -> Result<Self, LlmError> {
        let base_url = normalize_non_empty(base_url, "LLM base_url 不能为空")?;
        let api_key = normalize_non_empty(api_key, "LLM api_key 不能为空")?;
        let model = normalize_non_empty(model, "LLM model 不能为空")?;

        if request_timeout_ms == 0 {
            return Err(LlmError::InvalidConfig(
                "LLM request_timeout_ms 必须大于 0".to_string(),
            ));
        }

        Ok(Self {
            provider,
            base_url,
            api_key,
            model,
            request_timeout_ms,
            max_retries,
            retry_backoff_ms,
            official_fallback: false,
        })
    }

    pub fn with_official_fallback(mut self, official_fallback: bool) -> Self {
        self.official_fallback = official_fallback;
        self
    }

    pub fn ark_default(api_key: String, model: String) -> Result<Self, LlmError> {
        Self::new(
            LlmProvider::Ark,
            DEFAULT_ARK_BASE_URL.to_string(),
            api_key,
            model,
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
    }

    pub fn provider(&self) -> LlmProvider {
        self.provider
    }

    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    pub fn api_key(&self) -> &str {
        &self.api_key
    }

    pub fn model(&self) -> &str {
        &self.model
    }

    pub fn request_timeout_ms(&self) -> u64 {
        self.request_timeout_ms
    }

    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    pub fn retry_backoff_ms(&self) -> u64 {
        self.retry_backoff_ms
    }

    pub fn official_fallback(&self) -> bool {
        self.official_fallback
    }

    pub fn chat_completions_url(&self) -> String {
        format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            CHAT_COMPLETIONS_PATH.trim_start_matches('/')
        )
    }

    pub fn responses_url(&self) -> String {
        format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            RESPONSES_PATH.trim_start_matches('/')
        )
    }
}

impl LlmMessage {
    pub fn new(role: LlmMessageRole, content: impl Into<String>) -> Self {
        Self {
            role,
            content: content.into(),
            content_parts: Vec::new(),
        }
    }

    pub fn system(content: impl Into<String>) -> Self {
        Self::new(LlmMessageRole::System, content)
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self::new(LlmMessageRole::User, content)
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new(LlmMessageRole::Assistant, content)
    }

    pub fn multimodal(role: LlmMessageRole, content_parts: Vec<LlmMessageContentPart>) -> Self {
        // Mirror the text parts into `content` so plain-text consumers keep working.
        let content = content_parts
            .iter()
            .filter_map(|part| match part {
                LlmMessageContentPart::InputText { text } => Some(text.as_str()),
                LlmMessageContentPart::InputImage { .. } => None,
            })
            .collect::<Vec<_>>()
            .join("\n");

        Self {
            role,
            content,
            content_parts,
        }
    }

    pub fn user_multimodal(content_parts: Vec<LlmMessageContentPart>) -> Self {
        Self::multimodal(LlmMessageRole::User, content_parts)
    }

    pub fn with_image_url(mut self, image_url: impl Into<String>) -> Self {
        if self.content_parts.is_empty() && !self.content.trim().is_empty() {
            self.content_parts.push(LlmMessageContentPart::InputText {
                text: self.content.clone(),
            });
        }
        self.content_parts.push(LlmMessageContentPart::InputImage {
            image_url: image_url.into(),
        });
        self
    }
}

impl LlmTextRequest {
    pub fn new(messages: Vec<LlmMessage>) -> Self {
        Self {
            model: None,
            messages,
            max_tokens: None,
            enable_web_search: false,
            protocol: LlmTextProtocol::ChatCompletions,
            request_timeout_ms: None,
        }
    }

    pub fn single_turn(system_prompt: impl Into<String>, user_prompt: impl Into<String>) -> Self {
        Self::new(vec![
            LlmMessage::system(system_prompt),
            LlmMessage::user(user_prompt),
        ])
    }

    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    pub fn with_web_search(mut self, enabled: bool) -> Self {
        self.enable_web_search = enabled;
        self
    }

    pub fn with_responses_api(mut self) -> Self {
        self.protocol = LlmTextProtocol::Responses;
        self
    }

    pub fn with_request_timeout_ms(mut self, request_timeout_ms: u64) -> Self {
        self.request_timeout_ms = Some(request_timeout_ms);
        self
    }

    fn validate(&self) -> Result<(), LlmError> {
        if self.messages.is_empty() {
            return Err(LlmError::InvalidRequest(
                "LLM messages 不能为空".to_string(),
            ));
        }

        for message in &self.messages {
            let has_text = !message.content.trim().is_empty()
                || message.content_parts.iter().any(|part| match part {
                    LlmMessageContentPart::InputText { text } => !text.trim().is_empty(),
                    LlmMessageContentPart::InputImage { .. } => false,
                });
            let has_image = message.content_parts.iter().any(|part| match part {
                LlmMessageContentPart::InputImage { image_url } => !image_url.trim().is_empty(),
                LlmMessageContentPart::InputText { .. } => false,
            });
            if !has_text && !has_image {
                return Err(LlmError::InvalidRequest(
                    "LLM message content 不能为空".to_string(),
                ));
            }

            if message.content_parts.iter().any(|part| match part {
                LlmMessageContentPart::InputText { text } => text.trim().is_empty(),
                LlmMessageContentPart::InputImage { image_url } => image_url.trim().is_empty(),
            }) {
                return Err(LlmError::InvalidRequest(
                    "LLM message content part 不能为空".to_string(),
                ));
            }
        }

        if let Some(model) = &self.model
            && model.trim().is_empty()
        {
            return Err(LlmError::InvalidRequest(
                "LLM request.model 不能为空字符串".to_string(),
            ));
        }

        if let Some(request_timeout_ms) = self.request_timeout_ms
            && request_timeout_ms == 0
        {
            return Err(LlmError::InvalidRequest(
                "LLM request_timeout_ms 必须大于 0".to_string(),
            ));
        }

        Ok(())
    }

    fn resolved_model<'a>(&'a self, fallback_model: &'a str) -> &'a str {
        self.model
            .as_deref()
            .map(str::trim)
            .filter(|value| !value.is_empty())
            .unwrap_or(fallback_model)
    }

    fn resolved_request_timeout_ms(&self, fallback_timeout_ms: u64) -> u64 {
        self.request_timeout_ms
            .filter(|value| *value > 0)
            .unwrap_or(fallback_timeout_ms)
    }
}

impl LlmTextProtocol {
    fn as_str(self) -> &'static str {
        match self {
            Self::ChatCompletions => "chat_completions",
            Self::Responses => "responses",
        }
    }
}

impl fmt::Display for LlmError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidConfig(message)
            | Self::InvalidRequest(message)
            | Self::Transport(message)
            | Self::Deserialize(message) => write!(f, "{message}"),
            Self::Timeout { attempts } => {
                write!(f, "LLM 请求超时,累计尝试 {attempts} 次")
            }
            Self::Connectivity { attempts, message } => {
                write!(f, "LLM 连接失败,累计尝试 {attempts} 次:{message}")
            }
            Self::Upstream {
                status_code,
                message,
            } => write!(f, "LLM 上游返回 {status_code}:{message}"),
            Self::StreamUnavailable => write!(f, "LLM 流式响应体不可用"),
            Self::EmptyResponse => write!(f, "LLM 返回内容为空"),
        }
    }
}

impl Error for LlmError {}

impl LlmError {
    pub fn kind(&self) -> LlmErrorKind {
        match self {
            Self::InvalidConfig(_) => LlmErrorKind::InvalidConfig,
            Self::InvalidRequest(_) => LlmErrorKind::InvalidRequest,
            Self::Timeout { .. } => LlmErrorKind::Timeout,
            Self::Connectivity { .. } => LlmErrorKind::Connectivity,
            Self::Upstream { .. } => LlmErrorKind::Upstream,
            Self::StreamUnavailable => LlmErrorKind::StreamUnavailable,
            Self::EmptyResponse => LlmErrorKind::EmptyResponse,
            Self::Transport(_) => LlmErrorKind::Transport,
            Self::Deserialize(_) => LlmErrorKind::Deserialize,
        }
    }
}

impl LlmClient {
    pub fn new(config: LlmConfig) -> Result<Self, LlmError> {
        let http_client = Client::builder().build().map_err(|error| {
            LlmError::InvalidConfig(format!("构建 reqwest client 失败:{error}"))
        })?;

        Ok(Self {
            config,
            http_client,
        })
    }

    pub fn config(&self) -> &LlmConfig {
        &self.config
    }

    pub async fn request_text(&self, request: LlmTextRequest) -> Result<LlmTextResponse, LlmError> {
        request.validate()?;
        let resolved_model = request.resolved_model(self.config.model()).to_string();
        let response = self.execute_request(&request, false).await?;
        let raw_text = response.text().await.map_err(|error| {
            let llm_error = map_stream_read_error(error, 1);
            log_llm_raw_failure(
                &self.config,
                &request,
                false,
                1,
                "read_response_failed",
                llm_error.to_string().as_str(),
            );
            llm_error
        })?;

        parse_text_response(
            request.protocol,
            self.config.provider(),
            &resolved_model,
            raw_text.as_str(),
        )
        .map_err(|error| {
            log_llm_raw_failure(
                &self.config,
                &request,
                false,
                1,
                "parse_response_failed",
                raw_text.as_str(),
            );
            error
        })
    }

    pub async fn request_single_message_text(
        &self,
        system_prompt: impl Into<String>,
        user_prompt: impl Into<String>,
    ) -> Result<LlmTextResponse, LlmError> {
        self.request_text(LlmTextRequest::single_turn(system_prompt, user_prompt))
            .await
    }

    pub async fn stream_text<F>(
        &self,
        request: LlmTextRequest,
        mut on_delta: F,
    ) -> Result<LlmTextResponse, LlmError>
    where
        F: FnMut(&LlmStreamDelta),
    {
        request.validate()?;
        let resolved_model = request.resolved_model(self.config.model()).to_string();
        let mut response = self.execute_request(&request, true).await?;
        let response_id = response
            .headers()
            .get("x-request-id")
            .and_then(|value| value.to_str().ok())
            .map(str::to_string);

        let mut parser = OpenAiCompatibleSseParser::new(request.protocol);
        let mut accumulated_text = String::new();
        let mut finish_reason = None;
        // Buffers bytes from a chunk that ends mid UTF-8 sequence until the rest arrives.
        let mut undecoded_chunk_bytes = Vec::new();

        loop {
            let next_chunk = response.chunk().await.map_err(|error| {
                let llm_error = map_stream_read_error(error, 1);
                log_llm_raw_failure(
                    &self.config,
                    &request,
                    true,
                    1,
                    "read_stream_failed",
                    parser.raw_text().as_str(),
                );
                llm_error
            })?;

            let Some(chunk) = next_chunk else {
                break;
            };

            undecoded_chunk_bytes.extend_from_slice(chunk.as_ref());
            let (chunk_text, remaining_bytes) =
                decode_utf8_stream_chunk(undecoded_chunk_bytes.as_slice()).map_err(|error| {
                    log_llm_raw_failure(
                        &self.config,
                        &request,
                        true,
                        1,
                        "decode_stream_failed",
                        parser.raw_text().as_str(),
                    );
                    error
                })?;
            undecoded_chunk_bytes = remaining_bytes;
            if chunk_text.is_empty() {
                continue;
            }
            let stream_events = parser.push_chunk(chunk_text.as_ref()).map_err(|error| {
                log_llm_raw_failure(
                    &self.config,
                    &request,
                    true,
                    1,
                    "parse_stream_failed",
                    parser.raw_text().as_str(),
                );
                error
            })?;
            for event in stream_events {
                if let Some(delta_text) = event.delta_text
                    && !delta_text.is_empty()
                {
                    accumulated_text.push_str(delta_text.as_str());
                    let update = LlmStreamDelta {
                        accumulated_text: accumulated_text.clone(),
                        delta_text,
                        finish_reason: event.finish_reason.clone(),
                    };
                    on_delta(&update);
                }

                if event.finish_reason.is_some() {
                    finish_reason = event.finish_reason;
                }
            }
        }

        if !undecoded_chunk_bytes.is_empty() {
            let trailing_text =
                std_str::from_utf8(undecoded_chunk_bytes.as_slice()).map_err(|error| {
                    log_llm_raw_failure(
                        &self.config,
                        &request,
                        true,
                        1,
                        "decode_stream_failed",
                        parser.raw_text().as_str(),
                    );
                    LlmError::Deserialize(format!("解析 LLM 流式 UTF-8 响应失败:{error}"))
                })?;
            if !trailing_text.is_empty() {
                let trailing_events = parser.push_chunk(trailing_text).map_err(|error| {
                    log_llm_raw_failure(
                        &self.config,
                        &request,
                        true,
                        1,
                        "parse_stream_failed",
                        parser.raw_text().as_str(),
                    );
                    error
                })?;
                for event in trailing_events {
                    if let Some(delta_text) = event.delta_text
                        && !delta_text.is_empty()
                    {
                        accumulated_text.push_str(delta_text.as_str());
                        let update = LlmStreamDelta {
                            accumulated_text: accumulated_text.clone(),
                            delta_text,
                            finish_reason: event.finish_reason.clone(),
                        };
                        on_delta(&update);
                    }

                    if event.finish_reason.is_some() {
                        finish_reason = event.finish_reason;
                    }
                }
            }
        }

        let remaining_events = parser.finish().map_err(|error| {
            log_llm_raw_failure(
                &self.config,
                &request,
                true,
                1,
                "parse_stream_failed",
                parser.raw_text().as_str(),
            );
            error
        })?;
        for event in remaining_events {
            if let Some(delta_text) = event.delta_text
                && !delta_text.is_empty()
            {
                accumulated_text.push_str(delta_text.as_str());
                let update = LlmStreamDelta {
                    accumulated_text: accumulated_text.clone(),
                    delta_text,
                    finish_reason: event.finish_reason.clone(),
                };
                on_delta(&update);
            }

            if event.finish_reason.is_some() {
                finish_reason = event.finish_reason;
            }
        }

        let content = accumulated_text.trim().to_string();
        if content.is_empty() {
            log_llm_raw_failure(
                &self.config,
                &request,
                true,
                1,
                "empty_stream_response",
                parser.raw_text().as_str(),
            );
            return Err(LlmError::EmptyResponse);
        }

        Ok(LlmTextResponse {
            provider: self.config.provider(),
            model: resolved_model,
            content,
            finish_reason,
            response_id,
            usage: None,
        })
    }

    pub async fn stream_single_message_text<F>(
        &self,
        system_prompt: impl Into<String>,
        user_prompt: impl Into<String>,
        on_delta: F,
    ) -> Result<LlmTextResponse, LlmError>
    where
        F: FnMut(&LlmStreamDelta),
    {
        self.stream_text(
            LlmTextRequest::single_turn(system_prompt, user_prompt),
            on_delta,
        )
        .await
    }

    async fn execute_request(
        &self,
        request: &LlmTextRequest,
        stream: bool,
    ) -> Result<reqwest::Response, LlmError> {
        let request_body = build_request_body(request, &self.config, stream);
        let model = request.resolved_model(self.config.model());
        let url = match request.protocol {
            LlmTextProtocol::ChatCompletions => self.config.chat_completions_url(),
            LlmTextProtocol::Responses => self.config.responses_url(),
        };
        let max_attempts = self.config.max_retries().saturating_add(1);

        for attempt in 1..=max_attempts {
            debug!(
                "platform-llm request started: provider={}, protocol={}, stream={}, attempt={}, model={}",
                self.config.provider().as_str(),
                request.protocol.as_str(),
                stream,
                attempt,
                model
            );

            let send_result = self
                .http_client
                .post(url.as_str())
                .bearer_auth(self.config.api_key())
                .json(&request_body)
                .timeout(Duration::from_millis(
                    request.resolved_request_timeout_ms(self.config.request_timeout_ms()),
                ))
                .send()
                .await;

            match send_result {
                Ok(response) if response.status().is_success() => {
                    debug!(
                        "platform-llm request succeeded: provider={}, protocol={}, stream={}, attempt={}, status={}",
                        self.config.provider().as_str(),
                        request.protocol.as_str(),
                        stream,
                        attempt,
                        response.status().as_u16()
                    );
                    return Ok(response);
                }
                Ok(response) => {
                    let status = response.status();
                    let raw_text = response.text().await.unwrap_or_default();
                    let message = extract_api_error_message(&raw_text, "LLM 上游请求失败");

                    if should_retry_status(status) && attempt < max_attempts {
                        warn!(
                            "platform-llm request retrying after upstream status: provider={}, protocol={}, attempt={}, status={}, message={}",
                            self.config.provider().as_str(),
                            request.protocol.as_str(),
                            attempt,
                            status.as_u16(),
                            message
                        );
                        self.sleep_before_retry(attempt).await;
                        continue;
                    }

                    log_llm_raw_failure(
                        &self.config,
                        request,
                        stream,
                        attempt,
                        "upstream_status_failed",
                        raw_text.as_str(),
                    );
                    return Err(LlmError::Upstream {
                        status_code: status.as_u16(),
                        message,
                    });
                }
                Err(error) if error.is_timeout() => {
                    if attempt < max_attempts {
                        warn!(
                            "platform-llm request retrying after timeout: provider={}, protocol={}, attempt={}",
                            self.config.provider().as_str(),
                            request.protocol.as_str(),
                            attempt
                        );
                        self.sleep_before_retry(attempt).await;
                        continue;
                    }

                    let error = LlmError::Timeout { attempts: attempt };
                    log_llm_raw_failure(
                        &self.config,
                        request,
                        stream,
                        attempt,
                        "request_timeout",
                        error.to_string().as_str(),
                    );
                    return Err(error);
                }
                Err(error) if error.is_connect() => {
                    let message = error.to_string();
                    if attempt < max_attempts {
                        warn!(
                            "platform-llm request retrying after connectivity failure: provider={}, protocol={}, attempt={}, error={}",
                            self.config.provider().as_str(),
                            request.protocol.as_str(),
                            attempt,
                            message
                        );
                        self.sleep_before_retry(attempt).await;
                        continue;
                    }

                    let error = LlmError::Connectivity {
                        attempts: attempt,
                        message,
                    };
                    log_llm_raw_failure(
                        &self.config,
                        request,
                        stream,
                        attempt,
                        "request_connectivity_failed",
                        error.to_string().as_str(),
                    );
                    return Err(error);
                }
                Err(error) => {
                    let error = LlmError::Transport(error.to_string());
                    log_llm_raw_failure(
                        &self.config,
                        request,
                        stream,
                        attempt,
                        "request_transport_failed",
                        error.to_string().as_str(),
                    );
                    return Err(error);
                }
            }
        }

        Err(LlmError::Transport(
            "LLM 请求在重试循环后仍未返回结果".to_string(),
        ))
    }

    async fn sleep_before_retry(&self, attempt: u32) {
        let backoff_ms = self
            .config
            .retry_backoff_ms()
            .saturating_mul(u64::from(attempt));

        if backoff_ms > 0 {
            sleep(Duration::from_millis(backoff_ms)).await;
        }
    }
}

impl OpenAiCompatibleSseParser {
    fn new(protocol: LlmTextProtocol) -> Self {
        Self {
            buffer: String::new(),
            raw_text: String::new(),
            protocol,
        }
    }

    fn push_chunk(&mut self, chunk: &str) -> Result<Vec<ParsedStreamEvent>, LlmError> {
        self.raw_text.push_str(chunk);
        self.buffer.push_str(chunk);
        self.buffer = self.buffer.replace("\r\n", "\n");
        self.drain_complete_events()
    }

    fn raw_text(&self) -> String {
        self.raw_text.clone()
    }

    fn finish(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
        if self.buffer.trim().is_empty() {
            return Ok(Vec::new());
        }

        // Terminate the final, possibly unterminated block so it parses as a
        // complete SSE event.
        self.buffer.push_str("\n\n");
        self.drain_complete_events()
    }

    fn drain_complete_events(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
        let mut events = Vec::new();

        while let Some(boundary) = self.buffer.find("\n\n") {
            let block = self.buffer[..boundary].to_string();
            self.buffer = self.buffer[(boundary + 2)..].to_string();

            if let Some(event) = parse_sse_event_block(self.protocol, block.as_str())? {
                events.push(event);
            }
        }

        Ok(events)
    }
}

fn normalize_non_empty(value: String, error_message: &str) -> Result<String, LlmError> {
    let trimmed = value.trim().to_string();
    if trimmed.is_empty() {
        return Err(LlmError::InvalidConfig(error_message.to_string()));
    }

    Ok(trimmed)
}

fn build_request_body(
    request: &LlmTextRequest,
    config: &LlmConfig,
    stream: bool,
) -> LlmRequestBody {
    let fallback_model = config.model();
    let official_fallback = config.official_fallback().then_some(true);
    match request.protocol {
        LlmTextProtocol::ChatCompletions => {
            LlmRequestBody::ChatCompletions(ChatCompletionsRequestBody {
                model: request.resolved_model(fallback_model).to_string(),
                messages: map_chat_completions_input_messages(request.messages.as_slice()),
                stream,
                official_fallback,
                max_tokens: request.max_tokens,
                web_search_options: request
                    .enable_web_search
                    .then_some(ChatCompletionsWebSearchOptions {}),
            })
        }
        LlmTextProtocol::Responses => LlmRequestBody::Responses(ResponsesRequestBody {
            model: request.resolved_model(fallback_model).to_string(),
            stream,
            input: map_responses_input_messages(request.messages.as_slice()),
            official_fallback,
            max_output_tokens: request.max_tokens,
            tools: request.enable_web_search.then(|| {
                vec![ResponsesWebSearchTool {
                    tool_type: "web_search",
                    max_keyword: 3,
                }]
            }),
        }),
    }
}

fn map_chat_completions_input_messages(
    messages: &[LlmMessage],
) -> Vec<ChatCompletionsInputMessage> {
    messages
        .iter()
        .map(|message| ChatCompletionsInputMessage {
            role: map_llm_message_role(message.role),
            content: map_chat_completions_content(message),
        })
        .collect()
}

fn map_chat_completions_content(message: &LlmMessage) -> ChatCompletionsInputContent {
    if message.content_parts.is_empty() {
        return ChatCompletionsInputContent::Text(message.content.clone());
    }

    ChatCompletionsInputContent::Parts(
        message
            .content_parts
            .iter()
            .map(|part| match part {
                LlmMessageContentPart::InputText { text } => {
                    ChatCompletionsInputContentPart::Text { text: text.clone() }
                }
                LlmMessageContentPart::InputImage { image_url } => {
                    ChatCompletionsInputContentPart::ImageUrl {
                        image_url: ChatCompletionsImageUrl {
                            url: image_url.clone(),
                        },
                    }
                }
            })
            .collect(),
    )
}

fn map_responses_input_messages(messages: &[LlmMessage]) -> Vec<ResponsesInputMessage> {
    messages
        .iter()
        .map(|message| ResponsesInputMessage {
            role: map_llm_message_role(message.role),
            content: map_responses_content_parts(message),
        })
        .collect()
}

fn map_llm_message_role(role: LlmMessageRole) -> &'static str {
    match role {
        LlmMessageRole::System => "system",
        LlmMessageRole::User => "user",
        LlmMessageRole::Assistant => "assistant",
    }
}

fn map_responses_content_parts(message: &LlmMessage) -> Vec<ResponsesInputContentPart> {
    if message.content_parts.is_empty() {
        return vec![ResponsesInputContentPart::InputText {
            text: message.content.clone(),
        }];
    }

    message
        .content_parts
        .iter()
        .map(|part| match part {
            LlmMessageContentPart::InputText { text } => {
                ResponsesInputContentPart::InputText { text: text.clone() }
            }
            LlmMessageContentPart::InputImage { image_url } => {
                ResponsesInputContentPart::InputImage {
                    image_url: image_url.clone(),
                }
            }
        })
        .collect()
}

fn log_llm_raw_failure(
    config: &LlmConfig,
    request: &LlmTextRequest,
    stream: bool,
    attempt: u32,
    failure_stage: &str,
    raw_output: &str,
) {
    if let Err(error) =
        write_llm_raw_failure(config, request, stream, attempt, failure_stage, raw_output)
    {
        warn!(
            "LLM 失败原文日志落盘失败,主错误流程继续执行: failure_stage={}, error={}",
            failure_stage, error
        );
    }
}

fn write_llm_raw_failure(
    config: &LlmConfig,
    request: &LlmTextRequest,
    stream: bool,
    attempt: u32,
    failure_stage: &str,
    raw_output: &str,
) -> Result<(), String> {
    let log_dir = env::var("LLM_RAW_LOG_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|_| PathBuf::from(DEFAULT_LLM_RAW_LOG_DIR));
    fs::create_dir_all(&log_dir).map_err(|error| format!("创建日志目录失败:{error}"))?;

    let prefix = build_llm_raw_log_prefix(failure_stage);
    let model = request.resolved_model(config.model());
    let input_log = LlmRawFailureInputLog {
        provider: config.provider().as_str(),
        protocol: request.protocol.as_str(),
        model,
        stream,
        attempt,
        max_tokens: request.max_tokens,
        messages: request.messages.as_slice(),
    };
    let input_text = serde_json::to_string_pretty(&input_log)
        .map_err(|error| format!("序列化模型输入日志失败:{error}"))?;
    fs::write(log_dir.join(format!("{prefix}.input.json")), input_text)
        .map_err(|error| format!("写入模型输入日志失败:{error}"))?;
    fs::write(log_dir.join(format!("{prefix}.output.txt")), raw_output)
        .map_err(|error| format!("写入模型输出日志失败:{error}"))?;

    Ok(())
}

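// Shape of a generated prefix (illustrative values): unix millis, process id,
// per-process sequence, then the sanitized failure stage, e.g.
// "1714390000000-4242-000001-parse_response_failed".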
fn build_llm_raw_log_prefix(failure_stage: &str) -> String {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|duration| duration.as_millis())
        .unwrap_or_default();
    let sequence = LLM_RAW_LOG_SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let safe_stage = sanitize_log_file_segment(failure_stage);

    format!("{millis}-{}-{sequence:06}-{safe_stage}", std::process::id())
}

fn sanitize_log_file_segment(value: &str) -> String {
    let sanitized = value
        .chars()
        .map(|character| {
            if character.is_ascii_alphanumeric() || character == '-' || character == '_' {
                character
            } else {
                '_'
            }
        })
        .collect::<String>();

    if sanitized.is_empty() {
        "unknown".to_string()
    } else {
        sanitized
    }
}

fn parse_text_response(
    protocol: LlmTextProtocol,
    provider: LlmProvider,
    fallback_model: &str,
    raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
    match protocol {
        LlmTextProtocol::ChatCompletions => {
            parse_chat_completions_response(provider, fallback_model, raw_text)
        }
        LlmTextProtocol::Responses => parse_responses_response(provider, fallback_model, raw_text),
    }
}

fn parse_chat_completions_response(
    provider: LlmProvider,
    fallback_model: &str,
    raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
    let parsed: ChatCompletionsResponsePayload = serde_json::from_str(raw_text)
        .map_err(|error| LlmError::Deserialize(format!("解析 LLM JSON 响应失败:{error}")))?;
    let parsed = match parsed {
        ChatCompletionsResponsePayload::Direct(envelope) => envelope,
        ChatCompletionsResponsePayload::Wrapped { data } => data,
    };

    let first_choice = parsed
        .choices
        .first()
        .ok_or_else(|| LlmError::Deserialize("LLM 响应缺少 choices[0]".to_string()))?;
    let content = extract_message_text(first_choice)
        .ok_or(LlmError::EmptyResponse)?
        .trim()
        .to_string();

    if content.is_empty() {
        return Err(LlmError::EmptyResponse);
    }

    Ok(LlmTextResponse {
        provider,
        model: parsed.model.unwrap_or_else(|| fallback_model.to_string()),
        content,
        finish_reason: first_choice.finish_reason.clone(),
        response_id: parsed.id,
        usage: parsed.usage,
    })
}

fn parse_responses_response(
    provider: LlmProvider,
    fallback_model: &str,
    raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
    let parsed: ResponsesResponseEnvelope = serde_json::from_str(raw_text).map_err(|error| {
        LlmError::Deserialize(format!("解析 LLM Responses JSON 响应失败:{error}"))
    })?;
    let content = extract_responses_text(&parsed)
        .ok_or(LlmError::EmptyResponse)?
        .trim()
        .to_string();

    if content.is_empty() {
        return Err(LlmError::EmptyResponse);
    }

    Ok(LlmTextResponse {
        provider,
        model: parsed.model.unwrap_or_else(|| fallback_model.to_string()),
        content,
        finish_reason: parsed.status,
        response_id: parsed.id,
        usage: parsed.usage.map(|usage| LlmTokenUsage {
            prompt_tokens: usage.input_tokens,
            completion_tokens: usage.output_tokens,
            total_tokens: usage.total_tokens,
        }),
    })
}

fn extract_responses_text(parsed: &ResponsesResponseEnvelope) -> Option<String> {
    parsed
        .output_text
        .as_deref()
        .map(str::to_string)
        .filter(|text| !text.is_empty())
        .or_else(|| {
            let text = parsed
                .output
                .iter()
                .flat_map(|item| item.content.iter())
                .filter_map(|part| part.text.as_deref())
                .collect::<Vec<_>>()
                .join("");

            if text.is_empty() { None } else { Some(text) }
        })
}

fn extract_message_text(choice: &ChatCompletionsChoice) -> Option<String> {
    choice
        .message
        .as_ref()
        .and_then(|message| message.content.as_ref())
        .and_then(extract_content_text)
        .or_else(|| {
            choice
                .delta
                .as_ref()
                .and_then(|message| message.content.as_ref())
                .and_then(extract_content_text)
        })
}

fn extract_content_text(content: &ChatCompletionsContent) -> Option<String> {
    match content {
        ChatCompletionsContent::Text(text) => Some(text.clone()),
        ChatCompletionsContent::Parts(parts) => {
            let text = parts
                .iter()
                .filter_map(|part| part.text.as_deref())
                .collect::<Vec<_>>()
                .join("");

            if text.is_empty() { None } else { Some(text) }
        }
    }
}

fn decode_utf8_stream_chunk(bytes: &[u8]) -> Result<(String, Vec<u8>), LlmError> {
    match std_str::from_utf8(bytes) {
        Ok(text) => Ok((text.to_string(), Vec::new())),
        Err(error) => {
            let valid_up_to = error.valid_up_to();
            // `error_len() == None` means the tail is an incomplete multi-byte
            // sequence: decode the valid prefix and hand back the rest to buffer.
            let Some(_) = error.error_len() else {
                let decoded = std_str::from_utf8(&bytes[..valid_up_to]).map_err(|inner_error| {
                    LlmError::Deserialize(format!("解析 LLM 流式 UTF-8 响应失败:{inner_error}"))
                })?;
                return Ok((decoded.to_string(), bytes[valid_up_to..].to_vec()));
            };

            Err(LlmError::Deserialize(format!(
                "解析 LLM 流式 UTF-8 响应失败:{error}"
            )))
        }
    }
}

fn parse_sse_event_block(
    protocol: LlmTextProtocol,
    block: &str,
) -> Result<Option<ParsedStreamEvent>, LlmError> {
    let data_lines = block
        .lines()
        .filter_map(|line| line.trim().strip_prefix("data:"))
        .map(str::trim_start)
        .collect::<Vec<_>>();

    if data_lines.is_empty() {
        return Ok(None);
    }

    let data = data_lines.join("\n");
    if data.trim().is_empty() || data.trim() == "[DONE]" {
        return Ok(None);
    }

    if protocol == LlmTextProtocol::Responses {
        return parse_responses_sse_event(data.as_str());
    }

    let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(data.as_str())
        .map_err(|error| LlmError::Deserialize(format!("解析 LLM SSE 事件失败:{error}")))?;
    let first_choice = parsed
        .choices
        .first()
        .ok_or_else(|| LlmError::Deserialize("LLM SSE 响应缺少 choices[0]".to_string()))?;

    Ok(Some(ParsedStreamEvent {
        delta_text: extract_message_text(first_choice),
        finish_reason: first_choice.finish_reason.clone(),
    }))
}

fn parse_responses_sse_event(data: &str) -> Result<Option<ParsedStreamEvent>, LlmError> {
    let parsed: serde_json::Value = serde_json::from_str(data).map_err(|error| {
        LlmError::Deserialize(format!("解析 LLM Responses SSE 事件失败:{error}"))
    })?;
    let event_type = parsed
        .get("type")
        .and_then(serde_json::Value::as_str)
        .unwrap_or_default();

    match event_type {
        "response.output_text.delta" => Ok(Some(ParsedStreamEvent {
            delta_text: parsed
                .get("delta")
                .and_then(serde_json::Value::as_str)
                .map(str::to_string),
            finish_reason: None,
        })),
        "response.completed" => Ok(Some(ParsedStreamEvent {
            delta_text: None,
            finish_reason: Some("completed".to_string()),
        })),
        "response.failed" | "error" => {
            let message = parsed
                .get("error")
                .and_then(|error| error.get("message"))
                .and_then(serde_json::Value::as_str)
                .or_else(|| parsed.get("message").and_then(serde_json::Value::as_str))
                .unwrap_or("LLM Responses SSE 返回失败事件")
                .to_string();
            Err(LlmError::Upstream {
                status_code: 502,
                message,
            })
        }
        _ => Ok(None),
    }
}

fn should_retry_status(status: StatusCode) -> bool {
    status == StatusCode::REQUEST_TIMEOUT
        || status == StatusCode::TOO_MANY_REQUESTS
        || status.is_server_error()
}

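// Accepts the two common error body shapes (illustrative examples):
//   {"error":{"message":"quota exceeded"}} -> "quota exceeded"
//   {"message":"bad model"}                -> "bad model"
// Anything else falls back to the trimmed raw body, or to `fallback_message`
// when the body is blank.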
fn extract_api_error_message(raw_text: &str, fallback_message: &str) -> String {
    let trimmed = raw_text.trim();
    if trimmed.is_empty() {
        return fallback_message.to_string();
    }

    let parsed = serde_json::from_str::<serde_json::Value>(trimmed);
    if let Ok(value) = parsed {
        if let Some(message) = value
            .get("error")
            .and_then(|error| error.get("message"))
            .and_then(serde_json::Value::as_str)
            .map(str::trim)
            .filter(|message| !message.is_empty())
        {
            return message.to_string();
        }

        if let Some(message) = value
            .get("message")
            .and_then(serde_json::Value::as_str)
            .map(str::trim)
            .filter(|message| !message.is_empty())
        {
            return message.to_string();
        }
    }

    trimmed.to_string()
}

fn map_stream_read_error(error: reqwest::Error, attempts: u32) -> LlmError {
    if error.is_timeout() {
        return LlmError::Timeout { attempts };
    }

    if error.is_connect() {
        return LlmError::Connectivity {
            attempts,
            message: error.to_string(),
        };
    }

    LlmError::Transport(error.to_string())
}

#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        net::TcpListener,
        thread,
        time::Duration as StdDuration,
    };

    use super::*;

    #[test]
    fn llm_error_kind_is_stable_for_adapter_mapping() {
        assert_eq!(
            LlmError::InvalidConfig("bad config".to_string()).kind(),
            LlmErrorKind::InvalidConfig
        );
        assert_eq!(
            LlmError::Upstream {
                status_code: 429,
                message: "too many requests".to_string(),
            }
            .kind(),
            LlmErrorKind::Upstream
        );
        assert_eq!(LlmError::EmptyResponse.kind(), LlmErrorKind::EmptyResponse);
    }

    struct MockResponse {
        status_line: &'static str,
        content_type: &'static str,
        body: String,
        extra_headers: Vec<(&'static str, &'static str)>,
    }

    #[test]
    fn llm_config_rejects_blank_api_key() {
        let error = LlmConfig::new(
            LlmProvider::Ark,
            DEFAULT_ARK_BASE_URL.to_string(),
            " ".to_string(),
            "model-a".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
        .expect_err("blank api key should be rejected");

        assert_eq!(
            error,
            LlmError::InvalidConfig("LLM api_key 不能为空".to_string())
        );
    }

    #[test]
    fn llm_chat_completion_url_normalizes_trailing_slash() {
        let config = LlmConfig::new(
            LlmProvider::OpenAiCompatible,
            "https://example.com/base///".to_string(),
            "secret".to_string(),
            "model-a".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
        .expect("config should be valid");

        assert_eq!(
            config.chat_completions_url(),
            "https://example.com/base/chat/completions"
        );
        assert_eq!(config.responses_url(), "https://example.com/base/responses");
    }

    #[test]
    fn llm_config_official_fallback_is_opt_in() {
        let config = LlmConfig::new(
            LlmProvider::OpenAiCompatible,
            "https://example.com/base".to_string(),
            "secret".to_string(),
            "model-a".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
        .expect("config should be valid");

        assert!(!config.official_fallback());
        assert!(config.with_official_fallback(true).official_fallback());
    }

    #[tokio::test]
    async fn request_text_sends_official_fallback_for_openai_compatible_clients() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        let server_handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("request should connect");
            let request_text = read_request(&mut stream);
            write_response(
                &mut stream,
                MockResponse {
                    status_line: "200 OK",
                    content_type: "application/json; charset=utf-8",
                    body: r#"{"id":"resp_openai_compatible","model":"gpt-5","output_text":"兼容成功","status":"completed"}"#.to_string(),
                    extra_headers: Vec::new(),
                },
            );
            request_text
        });

        let config = LlmConfig::new(
            LlmProvider::OpenAiCompatible,
            format!("http://{address}"),
            "test-key".to_string(),
            "gpt-5".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            0,
            1,
        )
        .expect("config should be valid")
        .with_official_fallback(true);
        let client = LlmClient::new(config).expect("client should be created");
        let response = client
            .request_text(LlmTextRequest::single_turn("系统", "用户").with_responses_api())
            .await
            .expect("request_text should succeed");

        let request_text = server_handle.join().expect("server thread should join");
        let request_body = request_text
            .split("\r\n\r\n")
            .nth(1)
            .expect("request body should exist");
        let request_json: serde_json::Value =
            serde_json::from_str(request_body).expect("request body should be json");

        assert_eq!(response.content, "兼容成功");
        assert_eq!(request_json["official_fallback"], serde_json::json!(true));
    }

    #[test]
    fn sse_parser_handles_split_chunks_and_done_marker() {
        let mut parser = OpenAiCompatibleSseParser::new(LlmTextProtocol::ChatCompletions);
        let events_a = parser
            .push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"你\"}}]}\r\n\r\n")
            .expect("first chunk should parse");
        let events_b = parser
            .push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"好\"},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n")
            .expect("second chunk should parse");

        assert_eq!(events_a.len(), 1);
        assert_eq!(events_a[0].delta_text.as_deref(), Some("你"));
        assert_eq!(events_b.len(), 1);
        assert_eq!(events_b[0].delta_text.as_deref(), Some("好"));
        assert_eq!(events_b[0].finish_reason.as_deref(), Some("stop"));
    }

    #[test]
    fn responses_sse_parser_only_emits_output_text_delta() {
        let mut parser = OpenAiCompatibleSseParser::new(LlmTextProtocol::Responses);
        let events = parser
            .push_chunk(concat!(
                "data: {\"type\":\"response.created\"}\n\n",
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"你\"}\n\n",
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"好\"}\n\n",
                "data: {\"type\":\"response.completed\"}\n\n",
            ))
            .expect("responses stream should parse");

        assert_eq!(events.len(), 3);
        assert_eq!(events[0].delta_text.as_deref(), Some("你"));
        assert_eq!(events[1].delta_text.as_deref(), Some("好"));
        assert_eq!(events[2].finish_reason.as_deref(), Some("completed"));
    }

    #[test]
    fn decode_utf8_stream_chunk_preserves_incomplete_multibyte_suffix() {
        let full_bytes = "你好".as_bytes();
        let first_result = decode_utf8_stream_chunk(&full_bytes[..2])
            .expect("incomplete utf-8 chunk should be buffered");
        assert_eq!(first_result.0, "");
        assert_eq!(first_result.1, full_bytes[..2].to_vec());

        let mut combined = first_result.1;
        combined.extend_from_slice(&full_bytes[2..]);
        let second_result = decode_utf8_stream_chunk(combined.as_slice())
            .expect("completed utf-8 bytes should decode");
        assert_eq!(second_result.0, "你好");
        assert!(second_result.1.is_empty());
    }

#[tokio::test]
|
||
async fn request_text_parses_non_stream_response() {
|
||
let server_url = spawn_mock_server(vec![MockResponse {
|
||
status_line: "200 OK",
|
||
content_type: "application/json; charset=utf-8",
|
||
body: r#"{"id":"resp_01","model":"ark-test-model","choices":[{"message":{"content":"测试成功"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":6,"total_tokens":16}}"#.to_string(),
|
||
extra_headers: Vec::new(),
|
||
}]);
|
||
|
||
let client = build_test_client(server_url, 0);
|
||
let response = client
|
||
.request_single_message_text("系统", "用户")
|
||
.await
|
||
.expect("request_text should succeed");
|
||
|
||
assert_eq!(response.provider, LlmProvider::Ark);
|
||
assert_eq!(response.model, "ark-test-model");
|
||
assert_eq!(response.content, "测试成功");
|
||
assert_eq!(response.finish_reason.as_deref(), Some("stop"));
|
||
assert_eq!(response.response_id.as_deref(), Some("resp_01"));
|
||
assert_eq!(
|
||
response.usage,
|
||
Some(LlmTokenUsage {
|
||
prompt_tokens: 10,
|
||
completion_tokens: 6,
|
||
total_tokens: 16,
|
||
})
|
||
);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn request_text_retries_after_upstream_500() {
|
||
let server_url = spawn_mock_server(vec![
|
||
MockResponse {
|
||
status_line: "500 Internal Server Error",
|
||
content_type: "application/json; charset=utf-8",
|
||
body: r#"{"error":{"message":"temporary upstream failure"}}"#.to_string(),
|
||
extra_headers: Vec::new(),
|
||
},
|
||
MockResponse {
|
||
status_line: "200 OK",
|
||
content_type: "application/json; charset=utf-8",
|
||
body: r#"{"id":"resp_retry","choices":[{"message":{"content":"第二次成功"},"finish_reason":"stop"}]}"#.to_string(),
|
||
extra_headers: Vec::new(),
|
||
},
|
||
]);
|
||
|
||
let client = build_test_client(server_url, 1);
|
||
let response = client
|
||
.request_single_message_text("系统", "用户")
|
||
.await
|
||
.expect("second attempt should succeed");
|
||
|
||
assert_eq!(response.content, "第二次成功");
|
||
assert_eq!(response.response_id.as_deref(), Some("resp_retry"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn request_text_uses_request_level_timeout_override() {
|
||
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
|
||
let address = listener.local_addr().expect("listener should have addr");
|
||
thread::spawn(move || {
|
||
let (mut stream, _) = listener.accept().expect("request should connect");
|
||
let _ = read_request(&mut stream);
|
||
thread::sleep(StdDuration::from_millis(200));
|
||
write_response(
|
||
&mut stream,
|
||
MockResponse {
|
||
status_line: "200 OK",
|
||
content_type: "application/json; charset=utf-8",
|
||
body:
|
||
r#"{"choices":[{"message":{"content":"too late"},"finish_reason":"stop"}]}"#
|
||
.to_string(),
|
||
extra_headers: Vec::new(),
|
||
},
|
||
);
|
||
});
|
||
|
||
let config = LlmConfig::new(
|
||
LlmProvider::Ark,
|
||
format!("http://{address}"),
|
||
"test-key".to_string(),
|
||
"test-model".to_string(),
|
||
10_000,
|
||
0,
|
||
1,
|
||
)
|
||
.expect("config should be valid");
|
||
let client = LlmClient::new(config).expect("client should be created");
|
||
|
||
let error = client
|
||
.request_text(LlmTextRequest::single_turn("系统", "用户").with_request_timeout_ms(20))
|
||
.await
|
||
.expect_err("request override should timeout before the global timeout");
|
||
|
||
assert_eq!(error, LlmError::Timeout { attempts: 1 });
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn request_text_sends_web_search_options_when_enabled() {
|
||
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
|
||
let address = listener.local_addr().expect("listener should have addr");
|
||
let server_handle = thread::spawn(move || {
|
||
let (mut stream, _) = listener.accept().expect("request should connect");
|
||
let request_text = read_request(&mut stream);
|
||
write_response(
|
||
&mut stream,
|
||
MockResponse {
|
||
status_line: "200 OK",
|
||
content_type: "application/json; charset=utf-8",
|
||
body: r#"{"id":"resp_search","model":"test-model","choices":[{"message":{"content":"搜索成功"},"finish_reason":"stop"}]}"#.to_string(),
|
||
extra_headers: Vec::new(),
|
||
},
|
||
);
|
||
request_text
|
||
});
|
||
|
||
let client = build_test_client(format!("http://{address}"), 0);
|
||
let response = client
|
||
.request_text(
|
||
LlmTextRequest::single_turn("系统", "用户")
|
||
.with_web_search(true)
|
||
.with_max_tokens(128),
|
||
)
|
||
.await
|
||
.expect("request_text should succeed");
|
||
|
||
let request_text = server_handle.join().expect("server thread should join");
|
||
let request_body = request_text
|
||
.split("\r\n\r\n")
|
||
.nth(1)
|
||
.expect("request body should exist");
|
||
let request_json: serde_json::Value =
|
||
serde_json::from_str(request_body).expect("request body should be json");
|
||
|
||
assert_eq!(response.content, "搜索成功");
|
||
assert_eq!(request_json["web_search_options"], serde_json::json!({}));
|
||
assert!(request_json.get("official_fallback").is_none());
|
||
}
|
||
|
||
    #[tokio::test]
    async fn chat_completions_multimodal_request_sends_text_and_image_url_parts() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        let server_handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("request should connect");
            let request_text = read_request(&mut stream);
            write_response(
                &mut stream,
                MockResponse {
                    status_line: "200 OK",
                    content_type: "application/json; charset=utf-8",
                    body: r#"{"id":"chat_multimodal","model":"gpt-4o-mini","choices":[{"message":{"content":"{\"levelName\":\"雨夜猫街\"}"},"finish_reason":"stop"}]}"#.to_string(),
                    extra_headers: Vec::new(),
                },
            );
            request_text
        });

        let config = LlmConfig::new(
            LlmProvider::OpenAiCompatible,
            format!("http://{address}"),
            "test-key".to_string(),
            "gpt-4o-mini".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            0,
            1,
        )
        .expect("config should be valid")
        .with_official_fallback(true);
        let client = LlmClient::new(config).expect("client should be created");
        let response = client
            .request_text(LlmTextRequest::new(vec![
                LlmMessage::system("你是拼图关卡命名编辑"),
                LlmMessage::user_multimodal(vec![
                    LlmMessageContentPart::InputText {
                        text: "画面描述:一只猫在雨夜灯牌下回头。".to_string(),
                    },
                    LlmMessageContentPart::InputImage {
                        image_url: "data:image/png;base64,abcd".to_string(),
                    },
                ]),
            ]))
            .await
            .expect("request_text should succeed");

        let request_text = server_handle.join().expect("server thread should join");
        let request_line = request_text.lines().next().unwrap_or_default();
        let request_body = request_text
            .split("\r\n\r\n")
            .nth(1)
            .expect("request body should exist");
        let request_json: serde_json::Value =
            serde_json::from_str(request_body).expect("request body should be json");

        assert!(request_line.contains("POST /chat/completions HTTP/1.1"));
        assert_eq!(response.model, "gpt-4o-mini");
        assert_eq!(response.content, r#"{"levelName":"雨夜猫街"}"#);
        assert_eq!(request_json["official_fallback"], serde_json::json!(true));
        assert_eq!(
            request_json["messages"][1]["content"],
            serde_json::json!([
                { "type": "text", "text": "画面描述:一只猫在雨夜灯牌下回头。" },
                { "type": "image_url", "image_url": { "url": "data:image/png;base64,abcd" } }
            ])
        );
    }

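    // A Responses-protocol request should POST to /responses, attach the web_search tool when
    // enabled, and map the upstream input/output token counts onto prompt/completion tokens.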
    #[tokio::test]
    async fn request_text_sends_responses_body_with_web_search_tool() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        let server_handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("request should connect");
            let request_text = read_request(&mut stream);
            write_response(
                &mut stream,
                MockResponse {
                    status_line: "200 OK",
                    content_type: "application/json; charset=utf-8",
                    body: r#"{"id":"resp_responses","model":"deepseek-v3-2-251201","output_text":"Responses 成功","status":"completed","usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13}}"#.to_string(),
                    extra_headers: Vec::new(),
                },
            );
            request_text
        });

        let client = build_test_client(format!("http://{address}"), 0);
        let response = client
            .request_text(
                LlmTextRequest::single_turn("系统", "用户")
                    .with_model("deepseek-v3-2-251201")
                    .with_responses_api()
                    .with_web_search(true)
                    .with_max_tokens(128),
            )
            .await
            .expect("responses request_text should succeed");

        let request_text = server_handle.join().expect("server thread should join");
        let request_line = request_text.lines().next().unwrap_or_default();
        let request_body = request_text
            .split("\r\n\r\n")
            .nth(1)
            .expect("request body should exist");
        let request_json: serde_json::Value =
            serde_json::from_str(request_body).expect("request body should be json");

        assert!(request_line.contains("POST /responses HTTP/1.1"));
        assert_eq!(response.content, "Responses 成功");
        assert_eq!(response.model, "deepseek-v3-2-251201");
        assert_eq!(
            response.usage,
            Some(LlmTokenUsage {
                prompt_tokens: 9,
                completion_tokens: 4,
                total_tokens: 13,
            })
        );
        assert_eq!(
            request_json["model"],
            serde_json::json!("deepseek-v3-2-251201")
        );
        assert_eq!(request_json["stream"], serde_json::json!(false));
        assert_eq!(
            request_json["tools"],
            serde_json::json!([{ "type": "web_search", "max_keyword": 3 }])
        );
        assert!(request_json.get("official_fallback").is_none());
        assert_eq!(
            request_json["input"][0]["content"][0],
            serde_json::json!({ "type": "input_text", "text": "系统" })
        );
    }

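    // Responses-protocol multimodal input should serialize content parts as input_text /
    // input_image entries with a flat image_url string, not the nested Chat Completions shape.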
    #[tokio::test]
    async fn responses_multimodal_request_sends_input_text_and_input_image() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        let server_handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("request should connect");
            let request_text = read_request(&mut stream);
            write_response(
                &mut stream,
                MockResponse {
                    status_line: "200 OK",
                    content_type: "application/json; charset=utf-8",
                    body: r#"{"id":"resp_multimodal","model":"gpt-5","output_text":"多模态成功","status":"completed"}"#.to_string(),
                    extra_headers: Vec::new(),
                },
            );
            request_text
        });

        let client = build_test_client(format!("http://{address}"), 0);
        let response = client
            .request_text(
                LlmTextRequest::new(vec![
                    LlmMessage::system("你是创意互动内容生成 Agent"),
                    LlmMessage::user_multimodal(vec![
                        LlmMessageContentPart::InputText {
                            text: "把这张图做成拼图".to_string(),
                        },
                        LlmMessageContentPart::InputImage {
                            image_url: "https://example.com/ref.png".to_string(),
                        },
                    ]),
                ])
                .with_model("gpt-5")
                .with_responses_api(),
            )
            .await
            .expect("responses multimodal request_text should succeed");

        let request_text = server_handle.join().expect("server thread should join");
        let request_body = request_text
            .split("\r\n\r\n")
            .nth(1)
            .expect("request body should exist");
        let request_json: serde_json::Value =
            serde_json::from_str(request_body).expect("request body should be json");

        assert_eq!(response.model, "gpt-5");
        assert_eq!(request_json["model"], serde_json::json!("gpt-5"));
        assert!(request_json.get("official_fallback").is_none());
        assert_eq!(
            request_json["input"][1]["content"],
            serde_json::json!([
                { "type": "input_text", "text": "把这张图做成拼图" },
                { "type": "input_image", "image_url": "https://example.com/ref.png" }
            ])
        );
    }

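    // Chat Completions SSE deltas should accumulate in arrival order, and the final response
    // should surface both the finish_reason and the upstream x-request-id as response_id.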
    #[tokio::test]
    async fn stream_text_accumulates_sse_response() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "text/event-stream; charset=utf-8",
            body: concat!(
                "data: {\"choices\":[{\"delta\":{\"content\":\"你\"}}]}\n\n",
                "data: {\"choices\":[{\"delta\":{\"content\":\"好\"}}]}\n\n",
                "data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n",
                "data: [DONE]\n\n"
            )
            .to_string(),
            extra_headers: vec![("x-request-id", "req_stream_01")],
        }]);

        let client = build_test_client(server_url, 0);
        let mut updates = Vec::new();
        let response = client
            .stream_single_message_text("系统", "用户", |delta| {
                updates.push(delta.accumulated_text.clone());
            })
            .await
            .expect("stream_text should succeed");

        assert_eq!(updates, vec!["你".to_string(), "你好".to_string()]);
        assert_eq!(response.content, "你好");
        assert_eq!(response.finish_reason.as_deref(), Some("stop"));
        assert_eq!(response.response_id.as_deref(), Some("req_stream_01"));
    }

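    // The Responses protocol streams response.output_text.delta events instead of choices
    // deltas; response.completed terminates the stream with finish_reason "completed".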
    #[tokio::test]
    async fn stream_text_accumulates_responses_sse_response() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "text/event-stream; charset=utf-8",
            body: concat!(
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"你\"}\n\n",
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"好\"}\n\n",
                "data: {\"type\":\"response.completed\"}\n\n"
            )
            .to_string(),
            extra_headers: vec![("x-request-id", "req_responses_stream_01")],
        }]);

        let client = build_test_client(server_url, 0);
        let mut updates = Vec::new();
        let response = client
            .stream_text(
                LlmTextRequest::single_turn("系统", "用户").with_responses_api(),
                |delta| {
                    updates.push(delta.accumulated_text.clone());
                },
            )
            .await
            .expect("responses stream_text should succeed");

        assert_eq!(updates, vec!["你".to_string(), "你好".to_string()]);
        assert_eq!(response.content, "你好");
        assert_eq!(response.finish_reason.as_deref(), Some("completed"));
        assert_eq!(
            response.response_id.as_deref(),
            Some("req_responses_stream_01")
        );
    }

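    // On a deserialize failure the client should write paired .input.json / .output.txt raw
    // logs under LLM_RAW_LOG_DIR, capturing the prompts and raw body but never the API key.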
    #[tokio::test]
    async fn request_text_writes_raw_failure_logs_after_parse_error() {
        let log_dir = std::env::temp_dir().join(format!(
            "platform-llm-raw-log-test-{}",
            build_llm_raw_log_prefix("parse_error")
        ));
        unsafe {
            std::env::set_var("LLM_RAW_LOG_DIR", &log_dir);
        }

        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "application/json; charset=utf-8",
            body: "不是合法 JSON".to_string(),
            extra_headers: Vec::new(),
        }]);

        let client = build_test_client(server_url, 0);
        let error = client
            .request_single_message_text("系统原文", "用户原文")
            .await
            .expect_err("invalid json should fail");

        assert!(matches!(error, LlmError::Deserialize(_)));
        let mut input_logs = Vec::new();
        let mut output_logs = Vec::new();
        for entry in fs::read_dir(&log_dir).expect("log dir should exist") {
            let path = entry.expect("log entry should be readable").path();
            let file_name = path
                .file_name()
                .and_then(|name| name.to_str())
                .unwrap_or_default()
                .to_string();
            if file_name.ends_with(".input.json") {
                input_logs.push(path);
            } else if file_name.ends_with(".output.txt") {
                output_logs.push(path);
            }
        }

        assert_eq!(input_logs.len(), 1);
        assert_eq!(output_logs.len(), 1);
        let input_text = fs::read_to_string(&input_logs[0]).expect("input log should be readable");
        let output_text =
            fs::read_to_string(&output_logs[0]).expect("output log should be readable");
        assert!(input_text.contains("系统原文"));
        assert!(input_text.contains("用户原文"));
        assert!(!input_text.contains("test-key"));
        assert_eq!(output_text, "不是合法 JSON");

        unsafe {
            std::env::remove_var("LLM_RAW_LOG_DIR");
        }
        fs::remove_dir_all(log_dir).expect("log dir should be removed");
    }

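    // Shared fixture: an Ark-flavoured client pointed at the mock server, with a fixed test
    // key/model and a caller-controlled retry budget.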
    fn build_test_client(base_url: String, max_retries: u32) -> LlmClient {
        let config = LlmConfig::new(
            LlmProvider::Ark,
            base_url,
            "test-key".to_string(),
            "test-model".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            max_retries,
            1,
        )
        .expect("config should be valid");

        LlmClient::new(config).expect("client should be created")
    }

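    // Minimal blocking HTTP mock: binds a random local port, serves the queued responses one
    // connection at a time on a background thread, and returns the base URL.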
    fn spawn_mock_server(responses: Vec<MockResponse>) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");

        thread::spawn(move || {
            for response in responses {
                let (mut stream, _) = listener.accept().expect("request should connect");
                read_request(&mut stream);
                write_response(&mut stream, response);
            }
        });

        format!("http://{address}")
    }

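    // Reads one HTTP request from the stream. Once the headers arrive, Content-Length decides
    // how many body bytes to wait for; a 1s read timeout keeps header-only requests from hanging.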
    fn read_request(stream: &mut std::net::TcpStream) -> String {
        stream
            .set_read_timeout(Some(StdDuration::from_secs(1)))
            .expect("read timeout should be set");
        let mut buffer = Vec::new();
        let mut chunk = [0_u8; 1024];
        let mut expected_total = None;

        loop {
            match stream.read(&mut chunk) {
                Ok(0) => break,
                Ok(bytes_read) => {
                    buffer.extend_from_slice(&chunk[..bytes_read]);

                    if expected_total.is_none()
                        && let Some(header_end) = find_header_end(&buffer)
                    {
                        let content_length =
                            read_content_length(&buffer[..header_end]).unwrap_or(0);
                        expected_total = Some(header_end + content_length);
                    }

                    if let Some(total_bytes) = expected_total
                        && buffer.len() >= total_bytes
                    {
                        break;
                    }
                }
                Err(error)
                    if error.kind() == std::io::ErrorKind::WouldBlock
                        || error.kind() == std::io::ErrorKind::TimedOut =>
                {
                    break;
                }
                Err(error) => panic!("mock server failed to read request: {error}"),
            }
        }

        String::from_utf8_lossy(buffer.as_slice()).to_string()
    }

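    // Writes a minimal HTTP/1.1 response with explicit Content-Length and Connection: close so
    // the client can detect the end of the body.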
    fn write_response(stream: &mut std::net::TcpStream, response: MockResponse) {
        let body = response.body;
        let mut raw_response = format!(
            "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n",
            response.status_line,
            response.content_type,
            body.len()
        );
        for (name, value) in response.extra_headers {
            raw_response.push_str(format!("{name}: {value}\r\n").as_str());
        }
        raw_response.push_str("\r\n");
        raw_response.push_str(body.as_str());

        stream
            .write_all(raw_response.as_bytes())
            .expect("mock response should be written");
        stream.flush().expect("mock response should flush");
    }

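    // Returns the byte offset just past the "\r\n\r\n" header terminator, if the headers are
    // complete in the buffer.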
    fn find_header_end(buffer: &[u8]) -> Option<usize> {
        buffer
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|index| index + 4)
    }

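    // Extracts the Content-Length value from the raw header bytes, matching the header name
    // case-insensitively.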
    fn read_content_length(headers: &[u8]) -> Option<usize> {
        let text = String::from_utf8_lossy(headers);
        text.lines().find_map(|line| {
            let (name, value) = line.split_once(':')?;
            if name.eq_ignore_ascii_case("content-length") {
                return value.trim().parse::<usize>().ok();
            }
            None
        })
    }
}