Files
Genarrative/server-rs/crates/platform-llm/src/lib.rs
2026-04-30 17:49:07 +08:00

1885 lines
61 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::{
env,
error::Error,
fmt, fs,
path::PathBuf,
str as std_str,
sync::atomic::{AtomicU64, Ordering},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use log::{debug, warn};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::time::sleep;
// Default base URL for the Volcengine Ark OpenAI-compatible endpoint.
pub const DEFAULT_ARK_BASE_URL: &str = "https://ark.cn-beijing.volces.com/api/v3";
// Default per-attempt request timeout, in milliseconds.
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 30_000;
// Default number of extra attempts after the first one.
pub const DEFAULT_MAX_RETRIES: u32 = 1;
// Base backoff between retries; scaled linearly by attempt number (see `sleep_before_retry`).
pub const DEFAULT_RETRY_BACKOFF_MS: u64 = 500;
// OpenAI-compatible Chat Completions endpoint path.
pub const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";
// OpenAI-compatible Responses endpoint path.
pub const RESPONSES_PATH: &str = "/responses";
// Fallback directory for raw-failure dump files when LLM_RAW_LOG_DIR is unset.
const DEFAULT_LLM_RAW_LOG_DIR: &str = "logs/llm-raw";
// Process-wide increasing sequence embedded in raw-failure log file names.
static LLM_RAW_LOG_SEQUENCE: AtomicU64 = AtomicU64::new(1);
// Frozen set of upstream platforms, so higher layers stop scattering
// ad-hoc provider strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmProvider {
    /// Volcengine Ark (default gateway, see `DEFAULT_ARK_BASE_URL`).
    Ark,
    /// DashScope platform.
    DashScope,
    /// Any other OpenAI-compatible endpoint.
    OpenAiCompatible,
}
// Single owner of text-model gateway settings, so api-server and business
// modules do not each re-parse environment variables.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmConfig {
    provider: LlmProvider,
    // Endpoint root; trailing slashes are trimmed when URLs are built.
    base_url: String,
    api_key: String,
    // Default model id, used when a request carries no override.
    model: String,
    // Per-attempt timeout in milliseconds; validated to be > 0.
    request_timeout_ms: u64,
    // Extra attempts after the first one.
    max_retries: u32,
    // Base backoff in milliseconds, multiplied by the attempt number.
    retry_backoff_ms: u64,
}
// First version only freezes the three message roles the project already
// uses in production: system/user/assistant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmMessageRole {
    System,
    User,
    Assistant,
}
// A single message kept in OpenAI-compatible shape so the unified request
// body can serialize it directly.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmMessage {
    pub role: LlmMessageRole,
    pub content: String,
}
// Text-completion request frozen to a minimal closed loop:
// message list + optional model override + optional max_tokens.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextRequest {
    // When `Some` and non-blank, overrides the configured default model.
    pub model: Option<String>,
    pub messages: Vec<LlmMessage>,
    pub max_tokens: Option<u32>,
    // When true, the upstream web-search tool/options are attached.
    pub enable_web_search: bool,
    pub protocol: LlmTextProtocol,
}
// The text protocol must be chosen explicitly per request, so a global
// default model cannot mix different scenarios into one upstream shape.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LlmTextProtocol {
    ChatCompletions,
    Responses,
}
// Streaming consumers receive "accumulated text + current delta" so each
// layer does not have to re-concatenate on its own.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmStreamDelta {
    pub accumulated_text: String,
    pub delta_text: String,
    pub finish_reason: Option<String>,
}
// Preserves token counts; later modules may decide whether to feed them
// into auditing or cost accounting.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmTokenUsage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
// Unified text response so business code never parses choices/message/content.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextResponse {
    pub provider: LlmProvider,
    pub model: String,
    pub content: String,
    pub finish_reason: Option<String>,
    // For streaming this comes from the `x-request-id` header; otherwise the body `id`.
    pub response_id: Option<String>,
    // Not populated on the streaming path.
    pub usage: Option<LlmTokenUsage>,
}
// Normalizes upstream failures into a stable domain enum, so api-server can
// map them directly onto its HTTP error contract.
#[derive(Debug, PartialEq, Eq)]
pub enum LlmError {
    /// Gateway configuration rejected (blank base_url/api_key/model, zero timeout).
    InvalidConfig(String),
    /// Caller-supplied request failed validation.
    InvalidRequest(String),
    /// Timed out after `attempts` total tries.
    Timeout { attempts: u32 },
    /// Connection failure after `attempts` total tries.
    Connectivity { attempts: u32, message: String },
    /// Upstream answered with a non-success HTTP status.
    Upstream { status_code: u16, message: String },
    /// Streaming response body is unusable.
    StreamUnavailable,
    /// Upstream answered but produced no usable text.
    EmptyResponse,
    /// Other transport-level failure.
    Transport(String),
    /// Response body could not be decoded or parsed.
    Deserialize(String),
}
// Unified OpenAI-compatible text-gateway client: resolved config plus a
// shared reqwest client.
#[derive(Clone, Debug)]
pub struct LlmClient {
    config: LlmConfig,
    http_client: Client,
}
// Untagged so the serialized JSON is exactly the protocol-specific body.
#[derive(Serialize)]
#[serde(untagged)]
enum LlmRequestBody {
    ChatCompletions(ChatCompletionsRequestBody),
    Responses(ResponsesRequestBody),
}
// Wire body for POST {base_url}/chat/completions.
#[derive(Serialize)]
struct ChatCompletionsRequestBody {
    model: String,
    messages: Vec<LlmMessage>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    // Serialized as an empty object when web search is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<ChatCompletionsWebSearchOptions>,
}
// Marker object: presence alone turns web search on for Chat Completions.
#[derive(Serialize)]
struct ChatCompletionsWebSearchOptions {}
// Wire body for POST {base_url}/responses.
#[derive(Serialize)]
struct ResponsesRequestBody {
    model: String,
    stream: bool,
    input: Vec<ResponsesInputMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ResponsesWebSearchTool>>,
}
// Responses-API input message: role string plus typed content parts.
#[derive(Serialize)]
struct ResponsesInputMessage {
    role: &'static str,
    content: Vec<ResponsesInputContentPart>,
}
// One content part; `part_type` is always "input_text" in this crate.
#[derive(Serialize)]
struct ResponsesInputContentPart {
    #[serde(rename = "type")]
    part_type: &'static str,
    text: String,
}
// Web-search tool entry for the Responses API.
// NOTE(review): `max_keyword` appears to be a provider-specific knob — confirm
// against the upstream tool schema.
#[derive(Serialize)]
struct ResponsesWebSearchTool {
    #[serde(rename = "type")]
    tool_type: &'static str,
    max_keyword: u8,
}
// Snapshot of the model input, written to disk alongside the raw output
// whenever a request fails (see `write_llm_raw_failure`).
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct LlmRawFailureInputLog<'a> {
    provider: &'static str,
    protocol: &'static str,
    model: &'a str,
    stream: bool,
    attempt: u32,
    max_tokens: Option<u32>,
    messages: &'a [LlmMessage],
}
// Minimal Chat Completions envelope: only the fields this crate reads.
#[derive(Deserialize)]
struct ChatCompletionsResponseEnvelope {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<ChatCompletionsChoice>,
    usage: Option<LlmTokenUsage>,
}
// A choice carries `message` in non-stream responses and `delta` in SSE chunks.
#[derive(Deserialize)]
struct ChatCompletionsChoice {
    #[serde(default)]
    message: Option<ChatCompletionsMessage>,
    #[serde(default)]
    delta: Option<ChatCompletionsMessage>,
    #[serde(default)]
    finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct ChatCompletionsMessage {
    #[serde(default)]
    content: Option<ChatCompletionsContent>,
}
// Content may be a plain string or a list of typed parts; untagged handles both.
#[derive(Deserialize)]
#[serde(untagged)]
enum ChatCompletionsContent {
    Text(String),
    Parts(Vec<ChatCompletionsContentPart>),
}
#[derive(Deserialize)]
struct ChatCompletionsContentPart {
    // Kept only so untagged matching tolerates typed parts; never read.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    part_type: Option<String>,
    #[serde(default)]
    text: Option<String>,
}
// Minimal Responses-API envelope: only the fields this crate reads.
#[derive(Deserialize)]
struct ResponsesResponseEnvelope {
    id: Option<String>,
    model: Option<String>,
    // Aggregated text, preferred when present and non-empty.
    #[serde(default)]
    output_text: Option<String>,
    #[serde(default)]
    output: Vec<ResponsesOutputItem>,
    // Mapped to `finish_reason` in the unified response.
    #[serde(default)]
    status: Option<String>,
    usage: Option<ResponsesUsage>,
}
#[derive(Deserialize)]
struct ResponsesOutputItem {
    #[serde(default)]
    content: Vec<ResponsesOutputContentPart>,
}
#[derive(Deserialize)]
struct ResponsesOutputContentPart {
    // Kept only for schema tolerance; never read.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    part_type: Option<String>,
    #[serde(default)]
    text: Option<String>,
}
// Responses-API usage naming (input/output) — converted to LlmTokenUsage.
#[derive(Deserialize)]
struct ResponsesUsage {
    #[serde(default)]
    input_tokens: u64,
    #[serde(default)]
    output_tokens: u64,
    #[serde(default)]
    total_tokens: u64,
}
// Incremental SSE parser state: `buffer` holds bytes not yet forming a full
// event; `raw_text` keeps the full transcript for failure logging.
struct OpenAiCompatibleSseParser {
    buffer: String,
    raw_text: String,
    protocol: LlmTextProtocol,
}
// One parsed SSE event: optional text increment and optional finish reason.
#[derive(Debug)]
struct ParsedStreamEvent {
    delta_text: Option<String>,
    finish_reason: Option<String>,
}
impl LlmProvider {
pub fn as_str(&self) -> &'static str {
match self {
Self::Ark => "ark",
Self::DashScope => "dash_scope",
Self::OpenAiCompatible => "openai_compatible",
}
}
}
impl LlmConfig {
    /// Validates and normalizes all gateway settings.
    ///
    /// `base_url`, `api_key` and `model` are trimmed and must be non-blank;
    /// `request_timeout_ms` must be greater than zero. `max_retries` and
    /// `retry_backoff_ms` are accepted as-is (zero disables retry/backoff).
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        provider: LlmProvider,
        base_url: String,
        api_key: String,
        model: String,
        request_timeout_ms: u64,
        max_retries: u32,
        retry_backoff_ms: u64,
    ) -> Result<Self, LlmError> {
        let base_url = normalize_non_empty(base_url, "LLM base_url 不能为空")?;
        let api_key = normalize_non_empty(api_key, "LLM api_key 不能为空")?;
        let model = normalize_non_empty(model, "LLM model 不能为空")?;
        if request_timeout_ms == 0 {
            return Err(LlmError::InvalidConfig(
                "LLM request_timeout_ms 必须大于 0".to_string(),
            ));
        }
        Ok(Self {
            provider,
            base_url,
            api_key,
            model,
            request_timeout_ms,
            max_retries,
            retry_backoff_ms,
        })
    }
    /// Ark configuration using the crate-level default URL/timeout/retry values.
    pub fn ark_default(api_key: String, model: String) -> Result<Self, LlmError> {
        Self::new(
            LlmProvider::Ark,
            DEFAULT_ARK_BASE_URL.to_string(),
            api_key,
            model,
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
    }
    /// Configured provider.
    pub fn provider(&self) -> LlmProvider {
        self.provider
    }
    /// Normalized (trimmed, non-blank) base URL.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }
    /// API key used as the bearer token.
    pub fn api_key(&self) -> &str {
        &self.api_key
    }
    /// Default model id for requests without an override.
    pub fn model(&self) -> &str {
        &self.model
    }
    /// Per-attempt timeout in milliseconds.
    pub fn request_timeout_ms(&self) -> u64 {
        self.request_timeout_ms
    }
    /// Extra attempts after the first one.
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }
    /// Base retry backoff in milliseconds.
    pub fn retry_backoff_ms(&self) -> u64 {
        self.retry_backoff_ms
    }
    /// Joins base URL and Chat Completions path, collapsing duplicate slashes.
    pub fn chat_completions_url(&self) -> String {
        format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            CHAT_COMPLETIONS_PATH.trim_start_matches('/')
        )
    }
    /// Joins base URL and Responses path, collapsing duplicate slashes.
    pub fn responses_url(&self) -> String {
        format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            RESPONSES_PATH.trim_start_matches('/')
        )
    }
}
impl LlmMessage {
pub fn new(role: LlmMessageRole, content: impl Into<String>) -> Self {
Self {
role,
content: content.into(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new(LlmMessageRole::System, content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new(LlmMessageRole::User, content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new(LlmMessageRole::Assistant, content)
}
}
impl LlmTextRequest {
pub fn new(messages: Vec<LlmMessage>) -> Self {
Self {
model: None,
messages,
max_tokens: None,
enable_web_search: false,
protocol: LlmTextProtocol::ChatCompletions,
}
}
pub fn single_turn(system_prompt: impl Into<String>, user_prompt: impl Into<String>) -> Self {
Self::new(vec![
LlmMessage::system(system_prompt),
LlmMessage::user(user_prompt),
])
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_web_search(mut self, enabled: bool) -> Self {
self.enable_web_search = enabled;
self
}
pub fn with_responses_api(mut self) -> Self {
self.protocol = LlmTextProtocol::Responses;
self
}
fn validate(&self) -> Result<(), LlmError> {
if self.messages.is_empty() {
return Err(LlmError::InvalidRequest(
"LLM messages 不能为空".to_string(),
));
}
for message in &self.messages {
if message.content.trim().is_empty() {
return Err(LlmError::InvalidRequest(
"LLM message.content 不能为空".to_string(),
));
}
}
if let Some(model) = &self.model
&& model.trim().is_empty()
{
return Err(LlmError::InvalidRequest(
"LLM request.model 不能为空字符串".to_string(),
));
}
Ok(())
}
fn resolved_model<'a>(&'a self, fallback_model: &'a str) -> &'a str {
self.model
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.unwrap_or(fallback_model)
}
}
impl LlmTextProtocol {
    // Snake_case tag used in logs and raw-failure records.
    fn as_str(self) -> &'static str {
        if self == Self::ChatCompletions {
            "chat_completions"
        } else {
            "responses"
        }
    }
}
impl fmt::Display for LlmError {
    // Human-readable (Chinese) rendering of each error variant; the
    // message-carrying variants print their message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidConfig(message)
            | Self::InvalidRequest(message)
            | Self::Transport(message)
            | Self::Deserialize(message) => write!(f, "{message}"),
            Self::Timeout { attempts } => {
                write!(f, "LLM 请求超时,累计尝试 {attempts} 次")
            }
            Self::Connectivity { attempts, message } => {
                write!(f, "LLM 连接失败,累计尝试 {attempts} 次:{message}")
            }
            Self::Upstream {
                status_code,
                message,
            } => write!(f, "LLM 上游返回 {status_code}{message}"),
            Self::StreamUnavailable => write!(f, "LLM 流式响应体不可用"),
            Self::EmptyResponse => write!(f, "LLM 返回内容为空"),
        }
    }
}
// Marker impl so LlmError composes with `Box<dyn Error>` and `?`.
impl Error for LlmError {}
impl LlmClient {
pub fn new(config: LlmConfig) -> Result<Self, LlmError> {
let http_client = Client::builder().build().map_err(|error| {
LlmError::InvalidConfig(format!("构建 reqwest client 失败:{error}"))
})?;
Ok(Self {
config,
http_client,
})
}
/// Read-only access to the resolved gateway configuration.
pub fn config(&self) -> &LlmConfig {
    &self.config
}
/// Sends a non-streaming text request and parses the body into
/// [`LlmTextResponse`].
///
/// Body-read and parse failures dump the raw input/output via
/// `log_llm_raw_failure` before returning. The attempt number recorded
/// there is fixed at 1: retries happen inside `execute_request`, before
/// the body is read.
pub async fn request_text(&self, request: LlmTextRequest) -> Result<LlmTextResponse, LlmError> {
    request.validate()?;
    // Resolve the effective model up front so the response can echo it even
    // when the upstream payload omits a model field.
    let resolved_model = request.resolved_model(self.config.model()).to_string();
    let response = self.execute_request(&request, false).await?;
    let raw_text = response.text().await.map_err(|error| {
        let llm_error = map_stream_read_error(error, 1);
        log_llm_raw_failure(
            &self.config,
            &request,
            false,
            1,
            "read_response_failed",
            llm_error.to_string().as_str(),
        );
        llm_error
    })?;
    parse_text_response(
        request.protocol,
        self.config.provider(),
        &resolved_model,
        raw_text.as_str(),
    )
    .map_err(|error| {
        // Keep the raw upstream body on parse failures for offline debugging.
        log_llm_raw_failure(
            &self.config,
            &request,
            false,
            1,
            "parse_response_failed",
            raw_text.as_str(),
        );
        error
    })
}
/// Shorthand: one system prompt plus one user prompt, non-streaming.
pub async fn request_single_message_text(
    &self,
    system_prompt: impl Into<String>,
    user_prompt: impl Into<String>,
) -> Result<LlmTextResponse, LlmError> {
    let request = LlmTextRequest::single_turn(system_prompt, user_prompt);
    self.request_text(request).await
}
/// Sends a streaming request, invoking `on_delta` for every non-empty text
/// increment, and returns the final aggregated response.
///
/// `response_id` comes from the `x-request-id` header when present; token
/// usage is not collected on this path (`usage: None`). Every failure path
/// dumps the raw SSE transcript via `log_llm_raw_failure` (attempt fixed
/// at 1 — transport retries happen in `execute_request` before streaming).
pub async fn stream_text<F>(
    &self,
    request: LlmTextRequest,
    mut on_delta: F,
) -> Result<LlmTextResponse, LlmError>
where
    F: FnMut(&LlmStreamDelta),
{
    request.validate()?;
    let resolved_model = request.resolved_model(self.config.model()).to_string();
    let mut response = self.execute_request(&request, true).await?;
    let response_id = response
        .headers()
        .get("x-request-id")
        .and_then(|value| value.to_str().ok())
        .map(str::to_string);
    let mut parser = OpenAiCompatibleSseParser::new(request.protocol);
    let mut accumulated_text = String::new();
    let mut finish_reason = None;
    // Bytes held back because a chunk ended mid-way through a UTF-8 sequence.
    let mut undecoded_chunk_bytes = Vec::new();
    loop {
        let next_chunk = response.chunk().await.map_err(|error| {
            let llm_error = map_stream_read_error(error, 1);
            log_llm_raw_failure(
                &self.config,
                &request,
                true,
                1,
                "read_stream_failed",
                parser.raw_text().as_str(),
            );
            llm_error
        })?;
        // `None` means the body is exhausted.
        let Some(chunk) = next_chunk else {
            break;
        };
        undecoded_chunk_bytes.extend_from_slice(chunk.as_ref());
        // Decode only the complete UTF-8 prefix; keep any incomplete
        // trailing sequence buffered for the next chunk.
        let (chunk_text, remaining_bytes) =
            decode_utf8_stream_chunk(undecoded_chunk_bytes.as_slice()).map_err(|error| {
                log_llm_raw_failure(
                    &self.config,
                    &request,
                    true,
                    1,
                    "decode_stream_failed",
                    parser.raw_text().as_str(),
                );
                error
            })?;
        undecoded_chunk_bytes = remaining_bytes;
        if chunk_text.is_empty() {
            continue;
        }
        let stream_events = parser.push_chunk(chunk_text.as_ref()).map_err(|error| {
            log_llm_raw_failure(
                &self.config,
                &request,
                true,
                1,
                "parse_stream_failed",
                parser.raw_text().as_str(),
            );
            error
        })?;
        // Forward non-empty deltas and remember the latest finish reason.
        for event in stream_events {
            if let Some(delta_text) = event.delta_text
                && !delta_text.is_empty()
            {
                accumulated_text.push_str(delta_text.as_str());
                let update = LlmStreamDelta {
                    accumulated_text: accumulated_text.clone(),
                    delta_text,
                    finish_reason: event.finish_reason.clone(),
                };
                on_delta(&update);
            }
            if event.finish_reason.is_some() {
                finish_reason = event.finish_reason;
            }
        }
    }
    // Leftover bytes after EOF must themselves be valid UTF-8; a partial
    // sequence at end-of-stream is a hard decode error.
    if !undecoded_chunk_bytes.is_empty() {
        let trailing_text =
            std_str::from_utf8(undecoded_chunk_bytes.as_slice()).map_err(|error| {
                log_llm_raw_failure(
                    &self.config,
                    &request,
                    true,
                    1,
                    "decode_stream_failed",
                    parser.raw_text().as_str(),
                );
                LlmError::Deserialize(format!("解析 LLM 流式 UTF-8 响应失败:{error}"))
            })?;
        if !trailing_text.is_empty() {
            let trailing_events = parser.push_chunk(trailing_text).map_err(|error| {
                log_llm_raw_failure(
                    &self.config,
                    &request,
                    true,
                    1,
                    "parse_stream_failed",
                    parser.raw_text().as_str(),
                );
                error
            })?;
            for event in trailing_events {
                if let Some(delta_text) = event.delta_text
                    && !delta_text.is_empty()
                {
                    accumulated_text.push_str(delta_text.as_str());
                    let update = LlmStreamDelta {
                        accumulated_text: accumulated_text.clone(),
                        delta_text,
                        finish_reason: event.finish_reason.clone(),
                    };
                    on_delta(&update);
                }
                if event.finish_reason.is_some() {
                    finish_reason = event.finish_reason;
                }
            }
        }
    }
    // Flush any final event that was never terminated by a blank line.
    let remaining_events = parser.finish().map_err(|error| {
        log_llm_raw_failure(
            &self.config,
            &request,
            true,
            1,
            "parse_stream_failed",
            parser.raw_text().as_str(),
        );
        error
    })?;
    for event in remaining_events {
        if let Some(delta_text) = event.delta_text
            && !delta_text.is_empty()
        {
            accumulated_text.push_str(delta_text.as_str());
            let update = LlmStreamDelta {
                accumulated_text: accumulated_text.clone(),
                delta_text,
                finish_reason: event.finish_reason.clone(),
            };
            on_delta(&update);
        }
        if event.finish_reason.is_some() {
            finish_reason = event.finish_reason;
        }
    }
    let content = accumulated_text.trim().to_string();
    if content.is_empty() {
        log_llm_raw_failure(
            &self.config,
            &request,
            true,
            1,
            "empty_stream_response",
            parser.raw_text().as_str(),
        );
        return Err(LlmError::EmptyResponse);
    }
    Ok(LlmTextResponse {
        provider: self.config.provider(),
        model: resolved_model,
        content,
        finish_reason,
        response_id,
        usage: None,
    })
}
/// Shorthand streaming variant for one system prompt plus one user prompt.
pub async fn stream_single_message_text<F>(
    &self,
    system_prompt: impl Into<String>,
    user_prompt: impl Into<String>,
    on_delta: F,
) -> Result<LlmTextResponse, LlmError>
where
    F: FnMut(&LlmStreamDelta),
{
    let request = LlmTextRequest::single_turn(system_prompt, user_prompt);
    self.stream_text(request, on_delta).await
}
// Sends the request with retry. Retries apply to timeouts, connection
// failures, and retryable HTTP statuses (408/429/5xx); other transport
// errors fail immediately. Returns the raw reqwest response on success;
// body reading/parsing is the caller's job.
async fn execute_request(
    &self,
    request: &LlmTextRequest,
    stream: bool,
) -> Result<reqwest::Response, LlmError> {
    let request_body = build_request_body(request, self.config.model(), stream);
    let model = request.resolved_model(self.config.model());
    let url = match request.protocol {
        LlmTextProtocol::ChatCompletions => self.config.chat_completions_url(),
        LlmTextProtocol::Responses => self.config.responses_url(),
    };
    // max_retries extra attempts on top of the first one.
    let max_attempts = self.config.max_retries().saturating_add(1);
    for attempt in 1..=max_attempts {
        debug!(
            "platform-llm request started: provider={}, protocol={}, stream={}, attempt={}, model={}",
            self.config.provider().as_str(),
            request.protocol.as_str(),
            stream,
            attempt,
            model
        );
        let send_result = self
            .http_client
            .post(url.as_str())
            .bearer_auth(self.config.api_key())
            .json(&request_body)
            .timeout(Duration::from_millis(self.config.request_timeout_ms()))
            .send()
            .await;
        match send_result {
            // 2xx: hand the response back untouched.
            Ok(response) if response.status().is_success() => {
                debug!(
                    "platform-llm request succeeded: provider={}, protocol={}, stream={}, attempt={}, status={}",
                    self.config.provider().as_str(),
                    request.protocol.as_str(),
                    stream,
                    attempt,
                    response.status().as_u16()
                );
                return Ok(response);
            }
            // Non-success status: retry 408/429/5xx, otherwise log raw body and fail.
            Ok(response) => {
                let status = response.status();
                let raw_text = response.text().await.unwrap_or_default();
                let message = extract_api_error_message(&raw_text, "LLM 上游请求失败");
                if should_retry_status(status) && attempt < max_attempts {
                    warn!(
                        "platform-llm request retrying after upstream status: provider={}, protocol={}, attempt={}, status={}, message={}",
                        self.config.provider().as_str(),
                        request.protocol.as_str(),
                        attempt,
                        status.as_u16(),
                        message
                    );
                    self.sleep_before_retry(attempt).await;
                    continue;
                }
                log_llm_raw_failure(
                    &self.config,
                    request,
                    stream,
                    attempt,
                    "upstream_status_failed",
                    raw_text.as_str(),
                );
                return Err(LlmError::Upstream {
                    status_code: status.as_u16(),
                    message,
                });
            }
            // Timed out: retry until attempts are exhausted.
            Err(error) if error.is_timeout() => {
                if attempt < max_attempts {
                    warn!(
                        "platform-llm request retrying after timeout: provider={}, protocol={}, attempt={}",
                        self.config.provider().as_str(),
                        request.protocol.as_str(),
                        attempt
                    );
                    self.sleep_before_retry(attempt).await;
                    continue;
                }
                let error = LlmError::Timeout { attempts: attempt };
                log_llm_raw_failure(
                    &self.config,
                    request,
                    stream,
                    attempt,
                    "request_timeout",
                    error.to_string().as_str(),
                );
                return Err(error);
            }
            // Connection failure: retry until attempts are exhausted.
            Err(error) if error.is_connect() => {
                let message = error.to_string();
                if attempt < max_attempts {
                    warn!(
                        "platform-llm request retrying after connectivity failure: provider={}, protocol={}, attempt={}, error={}",
                        self.config.provider().as_str(),
                        request.protocol.as_str(),
                        attempt,
                        message
                    );
                    self.sleep_before_retry(attempt).await;
                    continue;
                }
                let error = LlmError::Connectivity {
                    attempts: attempt,
                    message,
                };
                log_llm_raw_failure(
                    &self.config,
                    request,
                    stream,
                    attempt,
                    "request_connectivity_failed",
                    error.to_string().as_str(),
                );
                return Err(error);
            }
            // Any other transport error is not retried.
            Err(error) => {
                let error = LlmError::Transport(error.to_string());
                log_llm_raw_failure(
                    &self.config,
                    request,
                    stream,
                    attempt,
                    "request_transport_failed",
                    error.to_string().as_str(),
                );
                return Err(error);
            }
        }
    }
    // Unreachable in practice: every loop iteration returns or continues,
    // and the last attempt always returns.
    Err(LlmError::Transport(
        "LLM 请求在重试循环后仍未返回结果".to_string(),
    ))
}
// Linear backoff: base backoff multiplied by the attempt number just finished.
async fn sleep_before_retry(&self, attempt: u32) {
    let multiplier = u64::from(attempt);
    let backoff_ms = self.config.retry_backoff_ms().saturating_mul(multiplier);
    if backoff_ms == 0 {
        return;
    }
    sleep(Duration::from_millis(backoff_ms)).await;
}
}
impl OpenAiCompatibleSseParser {
    fn new(protocol: LlmTextProtocol) -> Self {
        Self {
            buffer: String::new(),
            raw_text: String::new(),
            protocol,
        }
    }
    // Appends a decoded chunk and returns every SSE event whose terminating
    // blank line has now arrived.
    fn push_chunk(&mut self, chunk: &str) -> Result<Vec<ParsedStreamEvent>, LlmError> {
        self.raw_text.push_str(chunk);
        self.buffer.push_str(chunk);
        // Normalize CRLF so event boundaries are always "\n\n".
        self.buffer = self.buffer.replace("\r\n", "\n");
        self.drain_complete_events()
    }
    // Full transcript of everything pushed so far (used for failure logging).
    fn raw_text(&self) -> String {
        self.raw_text.clone()
    }
    // Flushes a trailing event that was never terminated by a blank line.
    fn finish(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
        if self.buffer.trim().is_empty() {
            return Ok(Vec::new());
        }
        self.buffer.push_str("\n\n");
        self.drain_complete_events()
    }
    // Splits the buffer on "\n\n" boundaries and parses each complete block;
    // incomplete tail data stays buffered.
    fn drain_complete_events(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
        let mut events = Vec::new();
        while let Some(boundary) = self.buffer.find("\n\n") {
            let block = self.buffer[..boundary].to_string();
            self.buffer = self.buffer[(boundary + 2)..].to_string();
            if let Some(event) = parse_sse_event_block(self.protocol, block.as_str())? {
                events.push(event);
            }
        }
        Ok(events)
    }
}
// Trims the value; a blank result is rejected with InvalidConfig carrying
// the supplied message.
fn normalize_non_empty(value: String, error_message: &str) -> Result<String, LlmError> {
    let trimmed = value.trim();
    if trimmed.is_empty() {
        Err(LlmError::InvalidConfig(error_message.to_string()))
    } else {
        Ok(trimmed.to_string())
    }
}
// Maps the unified request onto the protocol-specific wire body.
// Chat Completions uses `max_tokens` + `web_search_options`; Responses uses
// `max_output_tokens` + a `tools` entry of type "web_search".
fn build_request_body(
    request: &LlmTextRequest,
    fallback_model: &str,
    stream: bool,
) -> LlmRequestBody {
    match request.protocol {
        LlmTextProtocol::ChatCompletions => {
            LlmRequestBody::ChatCompletions(ChatCompletionsRequestBody {
                model: request.resolved_model(fallback_model).to_string(),
                messages: request.messages.clone(),
                stream,
                max_tokens: request.max_tokens,
                // Presence of the empty options object enables web search.
                web_search_options: request
                    .enable_web_search
                    .then_some(ChatCompletionsWebSearchOptions {}),
            })
        }
        LlmTextProtocol::Responses => LlmRequestBody::Responses(ResponsesRequestBody {
            model: request.resolved_model(fallback_model).to_string(),
            stream,
            input: map_responses_input_messages(request.messages.as_slice()),
            max_output_tokens: request.max_tokens,
            tools: request.enable_web_search.then(|| {
                vec![ResponsesWebSearchTool {
                    tool_type: "web_search",
                    max_keyword: 3,
                }]
            }),
        }),
    }
}
// Converts internal messages into the Responses-API "input" shape: one
// role string plus a single input_text content part per message.
fn map_responses_input_messages(messages: &[LlmMessage]) -> Vec<ResponsesInputMessage> {
    let mut mapped = Vec::with_capacity(messages.len());
    for message in messages {
        let role = match message.role {
            LlmMessageRole::System => "system",
            LlmMessageRole::User => "user",
            LlmMessageRole::Assistant => "assistant",
        };
        let part = ResponsesInputContentPart {
            part_type: "input_text",
            text: message.content.clone(),
        };
        mapped.push(ResponsesInputMessage {
            role,
            content: vec![part],
        });
    }
    mapped
}
// Best-effort wrapper around `write_llm_raw_failure`: a disk-logging error
// is only warned about, never allowed to mask the primary error flow.
fn log_llm_raw_failure(
    config: &LlmConfig,
    request: &LlmTextRequest,
    stream: bool,
    attempt: u32,
    failure_stage: &str,
    raw_output: &str,
) {
    if let Err(error) =
        write_llm_raw_failure(config, request, stream, attempt, failure_stage, raw_output)
    {
        warn!(
            "LLM 失败原文日志落盘失败,主错误流程继续执行: failure_stage={}, error={}",
            failure_stage, error
        );
    }
}
// Persists a failed exchange to disk as two files sharing one prefix:
// `<prefix>.input.json` (serialized request snapshot) and
// `<prefix>.output.txt` (raw upstream output). The directory comes from
// LLM_RAW_LOG_DIR, defaulting to `logs/llm-raw`.
// NOTE: synchronous std::fs writes on an async path — acceptable here
// because this only runs on failures.
fn write_llm_raw_failure(
    config: &LlmConfig,
    request: &LlmTextRequest,
    stream: bool,
    attempt: u32,
    failure_stage: &str,
    raw_output: &str,
) -> Result<(), String> {
    let log_dir = env::var("LLM_RAW_LOG_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|_| PathBuf::from(DEFAULT_LLM_RAW_LOG_DIR));
    fs::create_dir_all(&log_dir).map_err(|error| format!("创建日志目录失败:{error}"))?;
    let prefix = build_llm_raw_log_prefix(failure_stage);
    let model = request.resolved_model(config.model());
    let input_log = LlmRawFailureInputLog {
        provider: config.provider().as_str(),
        protocol: request.protocol.as_str(),
        model,
        stream,
        attempt,
        max_tokens: request.max_tokens,
        messages: request.messages.as_slice(),
    };
    let input_text = serde_json::to_string_pretty(&input_log)
        .map_err(|error| format!("序列化模型输入日志失败:{error}"))?;
    fs::write(log_dir.join(format!("{prefix}.input.json")), input_text)
        .map_err(|error| format!("写入模型输入日志失败:{error}"))?;
    fs::write(log_dir.join(format!("{prefix}.output.txt")), raw_output)
        .map_err(|error| format!("写入模型输出日志失败:{error}"))?;
    Ok(())
}
// File-name prefix: epoch millis, pid, zero-padded global sequence, then the
// sanitized failure-stage tag. Clock errors degrade to 0 rather than panic.
fn build_llm_raw_log_prefix(failure_stage: &str) -> String {
    let millis = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(duration) => duration.as_millis(),
        Err(_) => 0,
    };
    let sequence = LLM_RAW_LOG_SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let safe_stage = sanitize_log_file_segment(failure_stage);
    let pid = std::process::id();
    format!("{millis}-{pid}-{sequence:06}-{safe_stage}")
}
// Keeps [A-Za-z0-9_-]; every other character becomes '_'. An empty input
// maps to "unknown" so the file name never has an empty segment.
fn sanitize_log_file_segment(value: &str) -> String {
    let mut sanitized = String::with_capacity(value.len());
    for character in value.chars() {
        if character.is_ascii_alphanumeric() || matches!(character, '-' | '_') {
            sanitized.push(character);
        } else {
            sanitized.push('_');
        }
    }
    if sanitized.is_empty() {
        return "unknown".to_string();
    }
    sanitized
}
// Dispatches raw-body parsing to the protocol-specific parser.
fn parse_text_response(
    protocol: LlmTextProtocol,
    provider: LlmProvider,
    fallback_model: &str,
    raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
    if protocol == LlmTextProtocol::Responses {
        parse_responses_response(provider, fallback_model, raw_text)
    } else {
        parse_chat_completions_response(provider, fallback_model, raw_text)
    }
}
// Parses a non-streaming Chat Completions body into the unified response.
// Requires choices[0]; trims the extracted text and treats blank content
// as EmptyResponse. The body's model/id override the fallbacks when present.
fn parse_chat_completions_response(
    provider: LlmProvider,
    fallback_model: &str,
    raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
    let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(raw_text)
        .map_err(|error| LlmError::Deserialize(format!("解析 LLM JSON 响应失败:{error}")))?;
    let first_choice = parsed
        .choices
        .first()
        .ok_or_else(|| LlmError::Deserialize("LLM 响应缺少 choices[0]".to_string()))?;
    let content = extract_message_text(first_choice)
        .ok_or(LlmError::EmptyResponse)?
        .trim()
        .to_string();
    if content.is_empty() {
        return Err(LlmError::EmptyResponse);
    }
    Ok(LlmTextResponse {
        provider,
        model: parsed.model.unwrap_or_else(|| fallback_model.to_string()),
        content,
        finish_reason: first_choice.finish_reason.clone(),
        response_id: parsed.id,
        usage: parsed.usage,
    })
}
// Parses a non-streaming Responses-API body into the unified response.
// The Responses `status` field is surfaced as `finish_reason`, and the
// input/output token naming is converted to LlmTokenUsage.
fn parse_responses_response(
    provider: LlmProvider,
    fallback_model: &str,
    raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
    let parsed: ResponsesResponseEnvelope = serde_json::from_str(raw_text).map_err(|error| {
        LlmError::Deserialize(format!("解析 LLM Responses JSON 响应失败:{error}"))
    })?;
    let content = extract_responses_text(&parsed)
        .ok_or(LlmError::EmptyResponse)?
        .trim()
        .to_string();
    if content.is_empty() {
        return Err(LlmError::EmptyResponse);
    }
    Ok(LlmTextResponse {
        provider,
        model: parsed.model.unwrap_or_else(|| fallback_model.to_string()),
        content,
        finish_reason: parsed.status,
        response_id: parsed.id,
        usage: parsed.usage.map(|usage| LlmTokenUsage {
            prompt_tokens: usage.input_tokens,
            completion_tokens: usage.output_tokens,
            total_tokens: usage.total_tokens,
        }),
    })
}
// Prefers the aggregated `output_text` field; otherwise concatenates all
// text parts found across output[*].content[*]. Empty text yields None.
fn extract_responses_text(parsed: &ResponsesResponseEnvelope) -> Option<String> {
    if let Some(text) = parsed.output_text.as_deref() {
        if !text.is_empty() {
            return Some(text.to_string());
        }
    }
    let mut joined = String::new();
    for item in &parsed.output {
        for part in &item.content {
            if let Some(text) = part.text.as_deref() {
                joined.push_str(text);
            }
        }
    }
    if joined.is_empty() { None } else { Some(joined) }
}
// Pulls text from `message.content` first (non-stream responses), then
// falls back to `delta.content` (streaming chunks).
fn extract_message_text(choice: &ChatCompletionsChoice) -> Option<String> {
    let from_message = choice
        .message
        .as_ref()
        .and_then(|message| message.content.as_ref())
        .and_then(extract_content_text);
    if from_message.is_some() {
        return from_message;
    }
    choice
        .delta
        .as_ref()
        .and_then(|delta| delta.content.as_ref())
        .and_then(extract_content_text)
}
// Flattens content into one string. A plain string passes through unchanged
// (even when empty); a parts list yields None when it contains no text.
fn extract_content_text(content: &ChatCompletionsContent) -> Option<String> {
    match content {
        ChatCompletionsContent::Text(text) => Some(text.clone()),
        ChatCompletionsContent::Parts(parts) => {
            let joined: String = parts
                .iter()
                .filter_map(|part| part.text.as_deref())
                .collect();
            if joined.is_empty() { None } else { Some(joined) }
        }
    }
}
fn decode_utf8_stream_chunk(bytes: &[u8]) -> Result<(String, Vec<u8>), LlmError> {
match std_str::from_utf8(bytes) {
Ok(text) => Ok((text.to_string(), Vec::new())),
Err(error) => {
let valid_up_to = error.valid_up_to();
let Some(_) = error.error_len() else {
let decoded = std_str::from_utf8(&bytes[..valid_up_to]).map_err(|inner_error| {
LlmError::Deserialize(format!("解析 LLM 流式 UTF-8 响应失败:{inner_error}"))
})?;
return Ok((decoded.to_string(), bytes[valid_up_to..].to_vec()));
};
Err(LlmError::Deserialize(format!(
"解析 LLM 流式 UTF-8 响应失败:{error}"
)))
}
}
}
// Parses one SSE event block (the text between "\n\n" boundaries).
// Joins all `data:` lines, skips blank payloads and the "[DONE]" sentinel,
// then delegates Responses-protocol events; Chat Completions events are
// decoded as envelope chunks and must contain choices[0].
// NOTE(review): some OpenAI-compatible upstreams emit a final usage-only
// chunk with empty `choices`, which this would reject — confirm Ark never
// sends one before relaxing.
fn parse_sse_event_block(
    protocol: LlmTextProtocol,
    block: &str,
) -> Result<Option<ParsedStreamEvent>, LlmError> {
    let data_lines = block
        .lines()
        .filter_map(|line| line.trim().strip_prefix("data:"))
        .map(str::trim_start)
        .collect::<Vec<_>>();
    if data_lines.is_empty() {
        return Ok(None);
    }
    let data = data_lines.join("\n");
    if data.trim().is_empty() || data.trim() == "[DONE]" {
        return Ok(None);
    }
    if protocol == LlmTextProtocol::Responses {
        return parse_responses_sse_event(data.as_str());
    }
    let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(data.as_str())
        .map_err(|error| LlmError::Deserialize(format!("解析 LLM SSE 事件失败:{error}")))?;
    let first_choice = parsed
        .choices
        .first()
        .ok_or_else(|| LlmError::Deserialize("LLM SSE 响应缺少 choices[0]".to_string()))?;
    Ok(Some(ParsedStreamEvent {
        delta_text: extract_message_text(first_choice),
        finish_reason: first_choice.finish_reason.clone(),
    }))
}
// Interprets one Responses-API SSE payload by its "type" tag:
// - response.output_text.delta  -> text increment
// - response.completed          -> finish_reason "completed"
// - response.failed / error     -> surfaced as Upstream with status 502
// Every other event type (created, in_progress, ...) is ignored.
fn parse_responses_sse_event(data: &str) -> Result<Option<ParsedStreamEvent>, LlmError> {
    let parsed: serde_json::Value = serde_json::from_str(data).map_err(|error| {
        LlmError::Deserialize(format!("解析 LLM Responses SSE 事件失败:{error}"))
    })?;
    let event_type = parsed
        .get("type")
        .and_then(serde_json::Value::as_str)
        .unwrap_or_default();
    match event_type {
        "response.output_text.delta" => Ok(Some(ParsedStreamEvent {
            delta_text: parsed
                .get("delta")
                .and_then(serde_json::Value::as_str)
                .map(str::to_string),
            finish_reason: None,
        })),
        "response.completed" => Ok(Some(ParsedStreamEvent {
            delta_text: None,
            finish_reason: Some("completed".to_string()),
        })),
        "response.failed" | "error" => {
            // Prefer error.message, fall back to top-level message, then a fixed text.
            let message = parsed
                .get("error")
                .and_then(|error| error.get("message"))
                .and_then(serde_json::Value::as_str)
                .or_else(|| parsed.get("message").and_then(serde_json::Value::as_str))
                .unwrap_or("LLM Responses SSE 返回失败事件")
                .to_string();
            Err(LlmError::Upstream {
                status_code: 502,
                message,
            })
        }
        _ => Ok(None),
    }
}
// Retryable statuses: 408 Request Timeout, 429 Too Many Requests, any 5xx.
fn should_retry_status(status: StatusCode) -> bool {
    if status.is_server_error() {
        return true;
    }
    status == StatusCode::REQUEST_TIMEOUT || status == StatusCode::TOO_MANY_REQUESTS
}
// Pulls a human-readable message out of an upstream error body.
// Priority: JSON `error.message`, then top-level `message` (each trimmed,
// non-empty); otherwise the trimmed raw body; blank bodies use the fallback.
fn extract_api_error_message(raw_text: &str, fallback_message: &str) -> String {
    let trimmed = raw_text.trim();
    if trimmed.is_empty() {
        return fallback_message.to_string();
    }
    if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) {
        let nested = value
            .get("error")
            .and_then(|error| error.get("message"))
            .and_then(serde_json::Value::as_str);
        let top_level = value.get("message").and_then(serde_json::Value::as_str);
        for candidate in [nested, top_level].into_iter().flatten() {
            let candidate = candidate.trim();
            if !candidate.is_empty() {
                return candidate.to_string();
            }
        }
    }
    trimmed.to_string()
}
// Classifies a body-read failure into the domain error taxonomy:
// timeout and connect errors keep the attempt count; everything else
// becomes a generic Transport error.
fn map_stream_read_error(error: reqwest::Error, attempts: u32) -> LlmError {
    if error.is_timeout() {
        LlmError::Timeout { attempts }
    } else if error.is_connect() {
        LlmError::Connectivity {
            attempts,
            message: error.to_string(),
        }
    } else {
        LlmError::Transport(error.to_string())
    }
}
// Integration-style tests: each test spins up a one-shot blocking TCP mock
// server on a loopback port and drives the real HTTP client path against it —
// no external network and no mocking crates.
#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        net::TcpListener,
        thread,
        time::Duration as StdDuration,
    };

    use super::*;

    // Canned HTTP response the mock server writes back verbatim.
    struct MockResponse {
        status_line: &'static str,
        content_type: &'static str,
        body: String,
        extra_headers: Vec<(&'static str, &'static str)>,
    }

    // A blank / whitespace-only API key must be rejected at config construction.
    #[test]
    fn llm_config_rejects_blank_api_key() {
        let error = LlmConfig::new(
            LlmProvider::Ark,
            DEFAULT_ARK_BASE_URL.to_string(),
            " ".to_string(),
            "model-a".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
        .expect_err("blank api key should be rejected");
        assert_eq!(
            error,
            LlmError::InvalidConfig("LLM api_key 不能为空".to_string())
        );
    }

    // Any run of trailing slashes on base_url must collapse before the
    // endpoint path is appended.
    #[test]
    fn llm_chat_completion_url_normalizes_trailing_slash() {
        let config = LlmConfig::new(
            LlmProvider::OpenAiCompatible,
            "https://example.com/base///".to_string(),
            "secret".to_string(),
            "model-a".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
        .expect("config should be valid");
        assert_eq!(
            config.chat_completions_url(),
            "https://example.com/base/chat/completions"
        );
        assert_eq!(config.responses_url(), "https://example.com/base/responses");
    }

    // SSE frames may arrive split across chunks and may mix \r\n / \n frame
    // separators; the parser must emit exactly one event per data frame and
    // treat `[DONE]` as a terminator rather than an event.
    #[test]
    fn sse_parser_handles_split_chunks_and_done_marker() {
        let mut parser = OpenAiCompatibleSseParser::new(LlmTextProtocol::ChatCompletions);
        let events_a = parser
            .push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"\"}}]}\r\n\r\n")
            .expect("first chunk should parse");
        let events_b = parser
            .push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"\"},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n")
            .expect("second chunk should parse");
        assert_eq!(events_a.len(), 1);
        assert_eq!(events_a[0].delta_text.as_deref(), Some(""));
        assert_eq!(events_b.len(), 1);
        assert_eq!(events_b[0].delta_text.as_deref(), Some(""));
        assert_eq!(events_b[0].finish_reason.as_deref(), Some("stop"));
    }

    // Responses-protocol parsing: only `response.output_text.delta` frames
    // carry text; `response.completed` maps to a finish reason and
    // `response.created` is ignored.
    #[test]
    fn responses_sse_parser_only_emits_output_text_delta() {
        let mut parser = OpenAiCompatibleSseParser::new(LlmTextProtocol::Responses);
        let events = parser
            .push_chunk(concat!(
                "data: {\"type\":\"response.created\"}\n\n",
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"\"}\n\n",
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"\"}\n\n",
                "data: {\"type\":\"response.completed\"}\n\n",
            ))
            .expect("responses stream should parse");
        assert_eq!(events.len(), 3);
        assert_eq!(events[0].delta_text.as_deref(), Some(""));
        assert_eq!(events[1].delta_text.as_deref(), Some(""));
        assert_eq!(events[2].finish_reason.as_deref(), Some("completed"));
    }

    // A multi-byte UTF-8 sequence cut mid-character must be buffered (returned
    // as the carry-over suffix), not dropped or replaced, and must decode once
    // the remaining bytes arrive.
    #[test]
    fn decode_utf8_stream_chunk_preserves_incomplete_multibyte_suffix() {
        let full_bytes = "你好".as_bytes();
        let first_result = decode_utf8_stream_chunk(&full_bytes[..2])
            .expect("incomplete utf-8 chunk should be buffered");
        assert_eq!(first_result.0, "");
        assert_eq!(first_result.1, full_bytes[..2].to_vec());
        let mut combined = first_result.1;
        combined.extend_from_slice(&full_bytes[2..]);
        let second_result = decode_utf8_stream_chunk(combined.as_slice())
            .expect("completed utf-8 bytes should decode");
        assert_eq!(second_result.0, "你好");
        assert!(second_result.1.is_empty());
    }

    // Happy path: a non-streaming chat-completions body maps into the unified
    // response — content, finish reason, response id, and token usage.
    #[tokio::test]
    async fn request_text_parses_non_stream_response() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "application/json; charset=utf-8",
            body: r#"{"id":"resp_01","model":"ark-test-model","choices":[{"message":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":6,"total_tokens":16}}"#.to_string(),
            extra_headers: Vec::new(),
        }]);
        let client = build_test_client(server_url, 0);
        let response = client
            .request_single_message_text("系统", "用户")
            .await
            .expect("request_text should succeed");
        assert_eq!(response.provider, LlmProvider::Ark);
        assert_eq!(response.model, "ark-test-model");
        assert_eq!(response.content, "测试成功");
        assert_eq!(response.finish_reason.as_deref(), Some("stop"));
        assert_eq!(response.response_id.as_deref(), Some("resp_01"));
        assert_eq!(
            response.usage,
            Some(LlmTokenUsage {
                prompt_tokens: 10,
                completion_tokens: 6,
                total_tokens: 16,
            })
        );
    }

    // A 500 from upstream is retried: the second canned response succeeds, so
    // the client (max_retries = 1) must return it.
    #[tokio::test]
    async fn request_text_retries_after_upstream_500() {
        let server_url = spawn_mock_server(vec![
            MockResponse {
                status_line: "500 Internal Server Error",
                content_type: "application/json; charset=utf-8",
                body: r#"{"error":{"message":"temporary upstream failure"}}"#.to_string(),
                extra_headers: Vec::new(),
            },
            MockResponse {
                status_line: "200 OK",
                content_type: "application/json; charset=utf-8",
                body: r#"{"id":"resp_retry","choices":[{"message":{"content":""},"finish_reason":"stop"}]}"#.to_string(),
                extra_headers: Vec::new(),
            },
        ]);
        let client = build_test_client(server_url, 1);
        let response = client
            .request_single_message_text("系统", "用户")
            .await
            .expect("second attempt should succeed");
        assert_eq!(response.content, "第二次成功");
        assert_eq!(response.response_id.as_deref(), Some("resp_retry"));
    }

    // Enabling web search on a chat-completions request must serialize an
    // empty `web_search_options` object into the outgoing JSON body. The mock
    // server thread returns the raw request text so the body can be inspected.
    #[tokio::test]
    async fn request_text_sends_web_search_options_when_enabled() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        let server_handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("request should connect");
            let request_text = read_request(&mut stream);
            write_response(
                &mut stream,
                MockResponse {
                    status_line: "200 OK",
                    content_type: "application/json; charset=utf-8",
                    body: r#"{"id":"resp_search","model":"test-model","choices":[{"message":{"content":""},"finish_reason":"stop"}]}"#.to_string(),
                    extra_headers: Vec::new(),
                },
            );
            request_text
        });
        let client = build_test_client(format!("http://{address}"), 0);
        let response = client
            .request_text(
                LlmTextRequest::single_turn("系统", "用户")
                    .with_web_search(true)
                    .with_max_tokens(128),
            )
            .await
            .expect("request_text should succeed");
        let request_text = server_handle.join().expect("server thread should join");
        // Body follows the first blank line of the captured HTTP request.
        let request_body = request_text
            .split("\r\n\r\n")
            .nth(1)
            .expect("request body should exist");
        let request_json: serde_json::Value =
            serde_json::from_str(request_body).expect("request body should be json");
        assert_eq!(response.content, "搜索成功");
        assert_eq!(request_json["web_search_options"], serde_json::json!({}));
    }

    // Responses-API mode must POST to /responses with the Responses-shaped
    // body (typed `input` blocks, `web_search` tool entry) and must map
    // input/output token usage back onto the prompt/completion fields.
    #[tokio::test]
    async fn request_text_sends_responses_body_with_web_search_tool() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        let server_handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("request should connect");
            let request_text = read_request(&mut stream);
            write_response(
                &mut stream,
                MockResponse {
                    status_line: "200 OK",
                    content_type: "application/json; charset=utf-8",
                    body: r#"{"id":"resp_responses","model":"deepseek-v3-2-251201","output_text":"Responses ","status":"completed","usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13}}"#.to_string(),
                    extra_headers: Vec::new(),
                },
            );
            request_text
        });
        let client = build_test_client(format!("http://{address}"), 0);
        let response = client
            .request_text(
                LlmTextRequest::single_turn("系统", "用户")
                    .with_model("deepseek-v3-2-251201")
                    .with_responses_api()
                    .with_web_search(true)
                    .with_max_tokens(128),
            )
            .await
            .expect("responses request_text should succeed");
        let request_text = server_handle.join().expect("server thread should join");
        let request_line = request_text.lines().next().unwrap_or_default();
        let request_body = request_text
            .split("\r\n\r\n")
            .nth(1)
            .expect("request body should exist");
        let request_json: serde_json::Value =
            serde_json::from_str(request_body).expect("request body should be json");
        assert!(request_line.contains("POST /responses HTTP/1.1"));
        assert_eq!(response.content, "Responses 成功");
        assert_eq!(response.model, "deepseek-v3-2-251201");
        assert_eq!(
            response.usage,
            Some(LlmTokenUsage {
                prompt_tokens: 9,
                completion_tokens: 4,
                total_tokens: 13,
            })
        );
        assert_eq!(
            request_json["model"],
            serde_json::json!("deepseek-v3-2-251201")
        );
        assert_eq!(request_json["stream"], serde_json::json!(false));
        assert_eq!(
            request_json["tools"],
            serde_json::json!([{ "type": "web_search", "max_keyword": 3 }])
        );
        assert_eq!(
            request_json["input"][0]["content"][0],
            serde_json::json!({ "type": "input_text", "text": "系统" })
        );
    }

    // Chat-completions streaming: deltas accumulate in order, `stop` becomes
    // the finish reason, and the `x-request-id` header backfills response_id.
    #[tokio::test]
    async fn stream_text_accumulates_sse_response() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "text/event-stream; charset=utf-8",
            body: concat!(
                "data: {\"choices\":[{\"delta\":{\"content\":\"\"}}]}\n\n",
                "data: {\"choices\":[{\"delta\":{\"content\":\"\"}}]}\n\n",
                "data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n",
                "data: [DONE]\n\n"
            )
            .to_string(),
            extra_headers: vec![("x-request-id", "req_stream_01")],
        }]);
        let client = build_test_client(server_url, 0);
        let mut updates = Vec::new();
        let response = client
            .stream_single_message_text("系统", "用户", |delta| {
                updates.push(delta.accumulated_text.clone());
            })
            .await
            .expect("stream_text should succeed");
        assert_eq!(updates, vec!["".to_string(), "你好".to_string()]);
        assert_eq!(response.content, "你好");
        assert_eq!(response.finish_reason.as_deref(), Some("stop"));
        assert_eq!(response.response_id.as_deref(), Some("req_stream_01"));
    }

    // Responses-API streaming follows the same accumulation contract, with
    // `response.completed` as the terminal event.
    #[tokio::test]
    async fn stream_text_accumulates_responses_sse_response() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "text/event-stream; charset=utf-8",
            body: concat!(
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"\"}\n\n",
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"\"}\n\n",
                "data: {\"type\":\"response.completed\"}\n\n"
            )
            .to_string(),
            extra_headers: vec![("x-request-id", "req_responses_stream_01")],
        }]);
        let client = build_test_client(server_url, 0);
        let mut updates = Vec::new();
        let response = client
            .stream_text(
                LlmTextRequest::single_turn("系统", "用户").with_responses_api(),
                |delta| {
                    updates.push(delta.accumulated_text.clone());
                },
            )
            .await
            .expect("responses stream_text should succeed");
        assert_eq!(updates, vec!["".to_string(), "你好".to_string()]);
        assert_eq!(response.content, "你好");
        assert_eq!(response.finish_reason.as_deref(), Some("completed"));
        assert_eq!(
            response.response_id.as_deref(),
            Some("req_responses_stream_01")
        );
    }

    // When the upstream body fails to deserialize, both the sanitized request
    // (API key excluded) and the raw response body must land in the directory
    // named by LLM_RAW_LOG_DIR.
    // NOTE(review): this test mutates process-wide env vars, so it must not
    // run concurrently with other tests reading LLM_RAW_LOG_DIR.
    #[tokio::test]
    async fn request_text_writes_raw_failure_logs_after_parse_error() {
        let log_dir = std::env::temp_dir().join(format!(
            "platform-llm-raw-log-test-{}",
            build_llm_raw_log_prefix("parse_error")
        ));
        // SAFETY: test-only env mutation; see the concurrency note above.
        unsafe {
            std::env::set_var("LLM_RAW_LOG_DIR", &log_dir);
        }
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "application/json; charset=utf-8",
            body: "不是合法 JSON".to_string(),
            extra_headers: Vec::new(),
        }]);
        let client = build_test_client(server_url, 0);
        let error = client
            .request_single_message_text("系统原文", "用户原文")
            .await
            .expect_err("invalid json should fail");
        assert!(matches!(error, LlmError::Deserialize(_)));
        let mut input_logs = Vec::new();
        let mut output_logs = Vec::new();
        for entry in fs::read_dir(&log_dir).expect("log dir should exist") {
            let path = entry.expect("log entry should be readable").path();
            let file_name = path
                .file_name()
                .and_then(|name| name.to_str())
                .unwrap_or_default()
                .to_string();
            if file_name.ends_with(".input.json") {
                input_logs.push(path);
            } else if file_name.ends_with(".output.txt") {
                output_logs.push(path);
            }
        }
        assert_eq!(input_logs.len(), 1);
        assert_eq!(output_logs.len(), 1);
        let input_text = fs::read_to_string(&input_logs[0]).expect("input log should be readable");
        let output_text =
            fs::read_to_string(&output_logs[0]).expect("output log should be readable");
        assert!(input_text.contains("系统原文"));
        assert!(input_text.contains("用户原文"));
        // The API key must never be persisted to disk.
        assert!(!input_text.contains("test-key"));
        assert_eq!(output_text, "不是合法 JSON");
        unsafe {
            std::env::remove_var("LLM_RAW_LOG_DIR");
        }
        fs::remove_dir_all(log_dir).expect("log dir should be removed");
    }

    // Build an Ark-flavoured client pointed at the mock server; retry backoff
    // is 1 ms to keep retry tests fast.
    fn build_test_client(base_url: String, max_retries: u32) -> LlmClient {
        let config = LlmConfig::new(
            LlmProvider::Ark,
            base_url,
            "test-key".to_string(),
            "test-model".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            max_retries,
            1,
        )
        .expect("config should be valid");
        LlmClient::new(config).expect("client should be created")
    }

    // Serve the canned responses sequentially: one TCP accept per response,
    // then the server thread exits.
    fn spawn_mock_server(responses: Vec<MockResponse>) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        thread::spawn(move || {
            for response in responses {
                let (mut stream, _) = listener.accept().expect("request should connect");
                read_request(&mut stream);
                write_response(&mut stream, response);
            }
        });
        format!("http://{address}")
    }

    // Read one full HTTP request: once the header section is complete, use its
    // Content-Length to know how many body bytes to wait for. A 1 s read
    // timeout bounds the wait for peers that never close the connection.
    fn read_request(stream: &mut std::net::TcpStream) -> String {
        stream
            .set_read_timeout(Some(StdDuration::from_secs(1)))
            .expect("read timeout should be set");
        let mut buffer = Vec::new();
        let mut chunk = [0_u8; 1024];
        let mut expected_total = None;
        loop {
            match stream.read(&mut chunk) {
                Ok(0) => break,
                Ok(bytes_read) => {
                    buffer.extend_from_slice(&chunk[..bytes_read]);
                    // Header section complete: compute the total expected size.
                    if expected_total.is_none()
                        && let Some(header_end) = find_header_end(&buffer)
                    {
                        let content_length =
                            read_content_length(&buffer[..header_end]).unwrap_or(0);
                        expected_total = Some(header_end + content_length);
                    }
                    if let Some(total_bytes) = expected_total
                        && buffer.len() >= total_bytes
                    {
                        break;
                    }
                }
                // A read timeout means the peer has sent all it will; stop.
                Err(error)
                    if error.kind() == std::io::ErrorKind::WouldBlock
                        || error.kind() == std::io::ErrorKind::TimedOut =>
                {
                    break;
                }
                Err(error) => panic!("mock server failed to read request: {error}"),
            }
        }
        String::from_utf8_lossy(buffer.as_slice()).to_string()
    }

    // Write a minimal HTTP/1.1 response (Connection: close) with optional
    // extra headers appended before the blank line.
    fn write_response(stream: &mut std::net::TcpStream, response: MockResponse) {
        let body = response.body;
        let mut raw_response = format!(
            "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n",
            response.status_line,
            response.content_type,
            body.len()
        );
        for (name, value) in response.extra_headers {
            raw_response.push_str(format!("{name}: {value}\r\n").as_str());
        }
        raw_response.push_str("\r\n");
        raw_response.push_str(body.as_str());
        stream
            .write_all(raw_response.as_bytes())
            .expect("mock response should be written");
        stream.flush().expect("mock response should flush");
    }

    // Byte offset just past the CRLFCRLF terminating the HTTP header section.
    fn find_header_end(buffer: &[u8]) -> Option<usize> {
        buffer
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|index| index + 4)
    }

    // Parse the Content-Length header value, matching the name
    // case-insensitively; None when absent or unparsable.
    fn read_content_length(headers: &[u8]) -> Option<usize> {
        let text = String::from_utf8_lossy(headers);
        text.lines().find_map(|line| {
            let (name, value) = line.split_once(':')?;
            if name.eq_ignore_ascii_case("content-length") {
                return value.trim().parse::<usize>().ok();
            }
            None
        })
    }
}