use std::{
    env,
    error::Error,
    fmt, fs,
    path::PathBuf,
    str as std_str,
    sync::atomic::{AtomicU64, Ordering},
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use log::{debug, warn};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::time::sleep;

pub const DEFAULT_ARK_BASE_URL: &str = "https://ark.cn-beijing.volces.com/api/v3";
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 30_000;
pub const DEFAULT_MAX_RETRIES: u32 = 1;
pub const DEFAULT_RETRY_BACKOFF_MS: u64 = 500;
pub const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";
pub const RESPONSES_PATH: &str = "/responses";

const DEFAULT_LLM_RAW_LOG_DIR: &str = "logs/llm-raw";

static LLM_RAW_LOG_SEQUENCE: AtomicU64 = AtomicU64::new(1);

// Freeze the set of platform providers so upper layers stop scattering raw provider strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmProvider {
    Ark,
    DashScope,
    OpenAiCompatible,
}

// Centralize the text-model gateway configuration so api-server and business modules
// do not each re-parse environment variables.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmConfig {
    provider: LlmProvider,
    base_url: String,
    api_key: String,
    model: String,
    request_timeout_ms: u64,
    max_retries: u32,
    retry_backoff_ms: u64,
}

// The first version only freezes the three message roles the project already uses:
// system/user/assistant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmMessageRole {
    System,
    User,
    Assistant,
}

// A single message keeps the OpenAI-compatible shape so the unified request body can
// serialize it directly.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmMessage {
    pub role: LlmMessageRole,
    pub content: String,
}

// The text-completion request is frozen to the minimal loop of
// "message list + optional model override + optional max_tokens".
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextRequest {
    pub model: Option<String>,
    pub messages: Vec<LlmMessage>,
    pub max_tokens: Option<u32>,
    pub enable_web_search: bool,
    pub protocol: LlmTextProtocol,
}

// The text protocol must be chosen explicitly by the calling request, so a global default
// cannot funnel different scenarios into the same upstream shape.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LlmTextProtocol {
    ChatCompletions,
    Responses,
}

// During streaming, callers receive "accumulated text + current delta", so no layer has
// to re-assemble the text itself.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmStreamDelta {
    pub accumulated_text: String,
    pub delta_text: String,
    pub finish_reason: Option<String>,
}

// Preserves token counts; downstream modules decide whether to feed them into auditing
// or cost accounting.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmTokenUsage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

// Unified text response, so business layers never parse choices/message/content themselves.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmTextResponse {
    pub provider: LlmProvider,
    pub model: String,
    pub content: String,
    pub finish_reason: Option<String>,
    pub response_id: Option<String>,
    pub usage: Option<LlmTokenUsage>,
}

// Normalize upstream errors into a stable domain enum; api-server can map it directly
// onto its HTTP error contract.
#[derive(Debug, PartialEq, Eq)]
pub enum LlmError {
    InvalidConfig(String),
    InvalidRequest(String),
    Timeout { attempts: u32 },
    Connectivity { attempts: u32, message: String },
    Upstream { status_code: u16, message: String },
    StreamUnavailable,
    EmptyResponse,
    Transport(String),
    Deserialize(String),
}

// Unified OpenAI-compatible text gateway client.
#[derive(Clone, Debug)]
pub struct LlmClient {
    config: LlmConfig,
    http_client: Client,
}

#[derive(Serialize)]
#[serde(untagged)]
enum LlmRequestBody {
    ChatCompletions(ChatCompletionsRequestBody),
    Responses(ResponsesRequestBody),
}

#[derive(Serialize)]
struct ChatCompletionsRequestBody {
    model: String,
    messages: Vec<LlmMessage>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<ChatCompletionsWebSearchOptions>,
}

#[derive(Serialize)]
struct ChatCompletionsWebSearchOptions {}
#[derive(Serialize)]
struct ResponsesRequestBody {
    model: String,
    stream: bool,
    input: Vec<ResponsesInputMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ResponsesWebSearchTool>>,
}

#[derive(Serialize)]
struct ResponsesInputMessage {
    role: &'static str,
    content: Vec<ResponsesInputContentPart>,
}

#[derive(Serialize)]
struct ResponsesInputContentPart {
    #[serde(rename = "type")]
    part_type: &'static str,
    text: String,
}

#[derive(Serialize)]
struct ResponsesWebSearchTool {
    #[serde(rename = "type")]
    tool_type: &'static str,
    max_keyword: u8,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct LlmRawFailureInputLog<'a> {
    provider: &'static str,
    protocol: &'static str,
    model: &'a str,
    stream: bool,
    attempt: u32,
    max_tokens: Option<u32>,
    messages: &'a [LlmMessage],
}

#[derive(Deserialize)]
struct ChatCompletionsResponseEnvelope {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<ChatCompletionsChoice>,
    usage: Option<LlmTokenUsage>,
}

#[derive(Deserialize)]
struct ChatCompletionsChoice {
    #[serde(default)]
    message: Option<ChatCompletionsMessage>,
    #[serde(default)]
    delta: Option<ChatCompletionsMessage>,
    #[serde(default)]
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct ChatCompletionsMessage {
    #[serde(default)]
    content: Option<ChatCompletionsContent>,
}

#[derive(Deserialize)]
#[serde(untagged)]
enum ChatCompletionsContent {
    Text(String),
    Parts(Vec<ChatCompletionsContentPart>),
}

#[derive(Deserialize)]
struct ChatCompletionsContentPart {
    #[serde(rename = "type")]
    #[allow(dead_code)]
    part_type: Option<String>,
    #[serde(default)]
    text: Option<String>,
}

#[derive(Deserialize)]
struct ResponsesResponseEnvelope {
    id: Option<String>,
    model: Option<String>,
    #[serde(default)]
    output_text: Option<String>,
    #[serde(default)]
    output: Vec<ResponsesOutputItem>,
    #[serde(default)]
    status: Option<String>,
    usage: Option<ResponsesUsage>,
}

#[derive(Deserialize)]
struct ResponsesOutputItem {
    #[serde(default)]
    content: Vec<ResponsesOutputContentPart>,
}

#[derive(Deserialize)]
struct ResponsesOutputContentPart {
    #[serde(rename = "type")]
    #[allow(dead_code)]
    part_type: Option<String>,
    #[serde(default)]
    text: Option<String>,
}

#[derive(Deserialize)]
struct ResponsesUsage {
    #[serde(default)]
    input_tokens: u64,
    #[serde(default)]
    output_tokens: u64,
    #[serde(default)]
    total_tokens: u64,
}

struct OpenAiCompatibleSseParser {
    buffer: String,
    raw_text: String,
    protocol: LlmTextProtocol,
}

#[derive(Debug)]
struct ParsedStreamEvent {
    delta_text: Option<String>,
    finish_reason: Option<String>,
}

impl LlmProvider {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Ark => "ark",
            Self::DashScope => "dash_scope",
            Self::OpenAiCompatible => "openai_compatible",
        }
    }
}

impl LlmConfig {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        provider: LlmProvider,
        base_url: String,
        api_key: String,
        model: String,
        request_timeout_ms: u64,
        max_retries: u32,
        retry_backoff_ms: u64,
    ) -> Result<Self, LlmError> {
        let base_url = normalize_non_empty(base_url, "LLM base_url 不能为空")?;
        let api_key = normalize_non_empty(api_key, "LLM api_key 不能为空")?;
        let model = normalize_non_empty(model, "LLM model 不能为空")?;
        if request_timeout_ms == 0 {
            return Err(LlmError::InvalidConfig(
                "LLM request_timeout_ms 必须大于 0".to_string(),
            ));
        }
        Ok(Self {
            provider,
            base_url,
            api_key,
            model,
            request_timeout_ms,
            max_retries,
            retry_backoff_ms,
        })
    }

    pub fn ark_default(api_key: String, model: String) -> Result<Self, LlmError> {
        Self::new(
            LlmProvider::Ark,
            DEFAULT_ARK_BASE_URL.to_string(),
            api_key,
            model,
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
    }

    pub fn provider(&self) -> LlmProvider {
        self.provider
    }

    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    pub fn api_key(&self) -> &str {
        &self.api_key
    }

    pub fn model(&self) -> &str {
        &self.model
    }

    pub fn request_timeout_ms(&self) -> u64 {
        self.request_timeout_ms
    }
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    pub fn retry_backoff_ms(&self) -> u64 {
        self.retry_backoff_ms
    }

    pub fn chat_completions_url(&self) -> String {
        format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            CHAT_COMPLETIONS_PATH.trim_start_matches('/')
        )
    }

    pub fn responses_url(&self) -> String {
        format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            RESPONSES_PATH.trim_start_matches('/')
        )
    }
}

impl LlmMessage {
    pub fn new(role: LlmMessageRole, content: impl Into<String>) -> Self {
        Self {
            role,
            content: content.into(),
        }
    }

    pub fn system(content: impl Into<String>) -> Self {
        Self::new(LlmMessageRole::System, content)
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self::new(LlmMessageRole::User, content)
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new(LlmMessageRole::Assistant, content)
    }
}

impl LlmTextRequest {
    pub fn new(messages: Vec<LlmMessage>) -> Self {
        Self {
            model: None,
            messages,
            max_tokens: None,
            enable_web_search: false,
            protocol: LlmTextProtocol::ChatCompletions,
        }
    }

    pub fn single_turn(system_prompt: impl Into<String>, user_prompt: impl Into<String>) -> Self {
        Self::new(vec![
            LlmMessage::system(system_prompt),
            LlmMessage::user(user_prompt),
        ])
    }

    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    pub fn with_web_search(mut self, enabled: bool) -> Self {
        self.enable_web_search = enabled;
        self
    }

    pub fn with_responses_api(mut self) -> Self {
        self.protocol = LlmTextProtocol::Responses;
        self
    }

    fn validate(&self) -> Result<(), LlmError> {
        if self.messages.is_empty() {
            return Err(LlmError::InvalidRequest(
                "LLM messages 不能为空".to_string(),
            ));
        }
        for message in &self.messages {
            if message.content.trim().is_empty() {
                return Err(LlmError::InvalidRequest(
                    "LLM message.content 不能为空".to_string(),
                ));
            }
        }
        if let Some(model) = &self.model
            && model.trim().is_empty()
        {
            return Err(LlmError::InvalidRequest(
                "LLM request.model 不能为空字符串".to_string(),
            ));
        }
        Ok(())
    }

    fn resolved_model<'a>(&'a self, fallback_model: &'a str) -> &'a str {
        self.model
            .as_deref()
            .map(str::trim)
            .filter(|value| !value.is_empty())
            .unwrap_or(fallback_model)
    }
}

impl LlmTextProtocol {
    fn as_str(self) -> &'static str {
        match self {
            Self::ChatCompletions => "chat_completions",
            Self::Responses => "responses",
        }
    }
}

impl fmt::Display for LlmError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidConfig(message)
            | Self::InvalidRequest(message)
            | Self::Transport(message)
            | Self::Deserialize(message) => write!(f, "{message}"),
            Self::Timeout { attempts } => {
                write!(f, "LLM 请求超时,累计尝试 {attempts} 次")
            }
            Self::Connectivity { attempts, message } => {
                write!(f, "LLM 连接失败,累计尝试 {attempts} 次:{message}")
            }
            Self::Upstream {
                status_code,
                message,
            } => write!(f, "LLM 上游返回 {status_code}:{message}"),
            Self::StreamUnavailable => write!(f, "LLM 流式响应体不可用"),
            Self::EmptyResponse => write!(f, "LLM 返回内容为空"),
        }
    }
}

impl Error for LlmError {}
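// A minimal usage sketch of the config and request builders, kept out of the public
// surface. The key, model name, and prompts below are placeholders, not values this
// module defines.
#[allow(dead_code)]
fn example_build_config_and_request() -> Result<(LlmConfig, LlmTextRequest), LlmError> {
    // Config freezes provider + endpoint + timeout/retry policy in one place.
    let config = LlmConfig::ark_default("sk-demo".to_string(), "demo-model".to_string())?;
    // Requests opt into web search, token caps, and the Responses protocol explicitly.
    let request = LlmTextRequest::single_turn("你是助手。", "写一句问候。")
        .with_max_tokens(256)
        .with_web_search(true)
        .with_responses_api();
    Ok((config, request))
}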
impl LlmClient {
    pub fn new(config: LlmConfig) -> Result<Self, LlmError> {
        let http_client = Client::builder().build().map_err(|error| {
            LlmError::InvalidConfig(format!("构建 reqwest client 失败:{error}"))
        })?;
        Ok(Self {
            config,
            http_client,
        })
    }

    pub fn config(&self) -> &LlmConfig {
        &self.config
    }

    pub async fn request_text(
        &self,
        request: LlmTextRequest,
    ) -> Result<LlmTextResponse, LlmError> {
        request.validate()?;
        let resolved_model = request.resolved_model(self.config.model()).to_string();
        let response = self.execute_request(&request, false).await?;
        let raw_text = response.text().await.map_err(|error| {
            let llm_error = map_stream_read_error(error, 1);
            log_llm_raw_failure(
                &self.config,
                &request,
                false,
                1,
                "read_response_failed",
                llm_error.to_string().as_str(),
            );
            llm_error
        })?;
        parse_text_response(
            request.protocol,
            self.config.provider(),
            &resolved_model,
            raw_text.as_str(),
        )
        .map_err(|error| {
            log_llm_raw_failure(
                &self.config,
                &request,
                false,
                1,
                "parse_response_failed",
                raw_text.as_str(),
            );
            error
        })
    }

    pub async fn request_single_message_text(
        &self,
        system_prompt: impl Into<String>,
        user_prompt: impl Into<String>,
    ) -> Result<LlmTextResponse, LlmError> {
        self.request_text(LlmTextRequest::single_turn(system_prompt, user_prompt))
            .await
    }

    pub async fn stream_text<F>(
        &self,
        request: LlmTextRequest,
        mut on_delta: F,
    ) -> Result<LlmTextResponse, LlmError>
    where
        F: FnMut(&LlmStreamDelta),
    {
        request.validate()?;
        let resolved_model = request.resolved_model(self.config.model()).to_string();
        let mut response = self.execute_request(&request, true).await?;
        let response_id = response
            .headers()
            .get("x-request-id")
            .and_then(|value| value.to_str().ok())
            .map(str::to_string);
        let mut parser = OpenAiCompatibleSseParser::new(request.protocol);
        let mut accumulated_text = String::new();
        let mut finish_reason = None;
        let mut undecoded_chunk_bytes = Vec::new();
        loop {
            let next_chunk = response.chunk().await.map_err(|error| {
                let llm_error = map_stream_read_error(error, 1);
                log_llm_raw_failure(
                    &self.config,
                    &request,
                    true,
                    1,
                    "read_stream_failed",
                    parser.raw_text().as_str(),
                );
                llm_error
            })?;
            let Some(chunk) = next_chunk else {
                break;
            };
            undecoded_chunk_bytes.extend_from_slice(chunk.as_ref());
            let (chunk_text, remaining_bytes) =
                decode_utf8_stream_chunk(undecoded_chunk_bytes.as_slice()).map_err(|error| {
                    log_llm_raw_failure(
                        &self.config,
                        &request,
                        true,
                        1,
                        "decode_stream_failed",
                        parser.raw_text().as_str(),
                    );
                    error
                })?;
            undecoded_chunk_bytes = remaining_bytes;
            if chunk_text.is_empty() {
                continue;
            }
            let stream_events = parser.push_chunk(chunk_text.as_ref()).map_err(|error| {
                log_llm_raw_failure(
                    &self.config,
                    &request,
                    true,
                    1,
                    "parse_stream_failed",
                    parser.raw_text().as_str(),
                );
                error
            })?;
            for event in stream_events {
                if let Some(delta_text) = event.delta_text
                    && !delta_text.is_empty()
                {
                    accumulated_text.push_str(delta_text.as_str());
                    let update = LlmStreamDelta {
                        accumulated_text: accumulated_text.clone(),
                        delta_text,
                        finish_reason: event.finish_reason.clone(),
                    };
                    on_delta(&update);
                }
                if event.finish_reason.is_some() {
                    finish_reason = event.finish_reason;
                }
            }
        }
        if !undecoded_chunk_bytes.is_empty() {
            let trailing_text =
                std_str::from_utf8(undecoded_chunk_bytes.as_slice()).map_err(|error| {
                    log_llm_raw_failure(
                        &self.config,
                        &request,
                        true,
                        1,
                        "decode_stream_failed",
                        parser.raw_text().as_str(),
                    );
                    LlmError::Deserialize(format!("解析 LLM 流式 UTF-8 响应失败:{error}"))
                })?;
            if !trailing_text.is_empty() {
                let trailing_events = parser.push_chunk(trailing_text).map_err(|error| {
                    log_llm_raw_failure(
                        &self.config,
                        &request,
                        true,
                        1,
                        "parse_stream_failed",
                        parser.raw_text().as_str(),
                    );
                    error
                })?;
                for event in trailing_events {
                    if let Some(delta_text) = event.delta_text
                        && !delta_text.is_empty()
                    {
                        accumulated_text.push_str(delta_text.as_str());
                        let update = LlmStreamDelta {
                            accumulated_text: accumulated_text.clone(),
                            delta_text,
                            finish_reason: event.finish_reason.clone(),
                        };
                        on_delta(&update);
                    }
                    if event.finish_reason.is_some() {
                        finish_reason = event.finish_reason;
                    }
                }
            }
        }
        let remaining_events = parser.finish().map_err(|error| {
            log_llm_raw_failure(
                &self.config,
                &request,
                true,
                1,
                "parse_stream_failed",
                parser.raw_text().as_str(),
            );
            error
        })?;
        for event in remaining_events {
            if let Some(delta_text) = event.delta_text
                && !delta_text.is_empty()
            {
                accumulated_text.push_str(delta_text.as_str());
                let update = LlmStreamDelta {
                    accumulated_text: accumulated_text.clone(),
                    delta_text,
                    finish_reason: event.finish_reason.clone(),
                };
                on_delta(&update);
            }
            if event.finish_reason.is_some() {
                finish_reason = event.finish_reason;
            }
        }
        let content = accumulated_text.trim().to_string();
        if content.is_empty() {
            log_llm_raw_failure(
                &self.config,
                &request,
                true,
                1,
                "empty_stream_response",
                parser.raw_text().as_str(),
            );
            return Err(LlmError::EmptyResponse);
        }
        Ok(LlmTextResponse {
            provider: self.config.provider(),
            model: resolved_model,
            content,
            finish_reason,
            response_id,
            usage: None,
        })
    }

    pub async fn stream_single_message_text<F>(
        &self,
        system_prompt: impl Into<String>,
        user_prompt: impl Into<String>,
        on_delta: F,
    ) -> Result<LlmTextResponse, LlmError>
    where
        F: FnMut(&LlmStreamDelta),
    {
        self.stream_text(
            LlmTextRequest::single_turn(system_prompt, user_prompt),
            on_delta,
        )
        .await
    }

    async fn execute_request(
        &self,
        request: &LlmTextRequest,
        stream: bool,
    ) -> Result<reqwest::Response, LlmError> {
        let request_body = build_request_body(request, self.config.model(), stream);
        let model = request.resolved_model(self.config.model());
        let url = match request.protocol {
            LlmTextProtocol::ChatCompletions => self.config.chat_completions_url(),
            LlmTextProtocol::Responses => self.config.responses_url(),
        };
        let max_attempts = self.config.max_retries().saturating_add(1);
        for attempt in 1..=max_attempts {
            debug!(
                "platform-llm request started: provider={}, protocol={}, stream={}, attempt={}, model={}",
                self.config.provider().as_str(),
                request.protocol.as_str(),
                stream,
                attempt,
                model
            );
            let send_result = self
                .http_client
                .post(url.as_str())
                .bearer_auth(self.config.api_key())
                .json(&request_body)
                .timeout(Duration::from_millis(self.config.request_timeout_ms()))
                .send()
                .await;
            match send_result {
                Ok(response) if response.status().is_success() => {
                    debug!(
                        "platform-llm request succeeded: provider={}, protocol={}, stream={}, attempt={}, status={}",
                        self.config.provider().as_str(),
                        request.protocol.as_str(),
                        stream,
                        attempt,
                        response.status().as_u16()
                    );
                    return Ok(response);
                }
                Ok(response) => {
                    let status = response.status();
                    let raw_text = response.text().await.unwrap_or_default();
                    let message = extract_api_error_message(&raw_text, "LLM 上游请求失败");
                    if should_retry_status(status) && attempt < max_attempts {
                        warn!(
                            "platform-llm request retrying after upstream status: provider={}, protocol={}, attempt={}, status={}, message={}",
                            self.config.provider().as_str(),
                            request.protocol.as_str(),
                            attempt,
                            status.as_u16(),
                            message
                        );
                        self.sleep_before_retry(attempt).await;
                        continue;
                    }
                    log_llm_raw_failure(
                        &self.config,
                        request,
                        stream,
                        attempt,
                        "upstream_status_failed",
                        raw_text.as_str(),
                    );
                    return Err(LlmError::Upstream {
                        status_code: status.as_u16(),
                        message,
                    });
                }
                Err(error) if error.is_timeout() => {
                    if attempt < max_attempts {
                        warn!(
                            "platform-llm request retrying after timeout: provider={}, protocol={}, attempt={}",
                            self.config.provider().as_str(),
                            request.protocol.as_str(),
                            attempt
                        );
                        self.sleep_before_retry(attempt).await;
                        continue;
                    }
                    let error = LlmError::Timeout { attempts: attempt };
                    log_llm_raw_failure(
                        &self.config,
                        request,
                        stream,
                        attempt,
                        "request_timeout",
                        error.to_string().as_str(),
                    );
                    return Err(error);
                }
                Err(error) if error.is_connect() => {
                    let message = error.to_string();
                    if attempt < max_attempts {
                        warn!(
                            "platform-llm request retrying after connectivity failure: provider={}, protocol={}, attempt={}, error={}",
                            self.config.provider().as_str(),
                            request.protocol.as_str(),
                            attempt,
                            message
                        );
                        self.sleep_before_retry(attempt).await;
                        continue;
                    }
                    let error = LlmError::Connectivity {
                        attempts: attempt,
                        message,
                    };
                    log_llm_raw_failure(
                        &self.config,
                        request,
                        stream,
                        attempt,
                        "request_connectivity_failed",
                        error.to_string().as_str(),
                    );
                    return Err(error);
                }
                Err(error) => {
                    let error = LlmError::Transport(error.to_string());
                    log_llm_raw_failure(
                        &self.config,
                        request,
                        stream,
                        attempt,
                        "request_transport_failed",
                        error.to_string().as_str(),
                    );
                    return Err(error);
                }
            }
        }
        Err(LlmError::Transport(
            "LLM 请求在重试循环后仍未返回结果".to_string(),
        ))
    }

    async fn sleep_before_retry(&self, attempt: u32) {
        let backoff_ms = self
            .config
            .retry_backoff_ms()
            .saturating_mul(u64::from(attempt));
        if backoff_ms > 0 {
            sleep(Duration::from_millis(backoff_ms)).await;
        }
    }
}
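// A minimal streaming-consumption sketch, assuming the caller already holds an
// `LlmClient`; the prompts are placeholders. The closure receives both the increment
// and the full accumulated text on every callback, so no re-assembly is needed.
#[allow(dead_code)]
async fn example_consume_stream(client: &LlmClient) -> Result<String, LlmError> {
    let mut chunks_seen = 0_u32;
    let response = client
        .stream_single_message_text("你是助手。", "写一句问候。", |delta| {
            chunks_seen += 1;
            debug!("delta #{chunks_seen}: {}", delta.delta_text);
        })
        .await?;
    // The returned response already carries the trimmed, fully accumulated text.
    Ok(response.content)
}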
impl OpenAiCompatibleSseParser {
    fn new(protocol: LlmTextProtocol) -> Self {
        Self {
            buffer: String::new(),
            raw_text: String::new(),
            protocol,
        }
    }

    fn push_chunk(&mut self, chunk: &str) -> Result<Vec<ParsedStreamEvent>, LlmError> {
        self.raw_text.push_str(chunk);
        self.buffer.push_str(chunk);
        self.buffer = self.buffer.replace("\r\n", "\n");
        self.drain_complete_events()
    }

    fn raw_text(&self) -> String {
        self.raw_text.clone()
    }

    fn finish(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
        if self.buffer.trim().is_empty() {
            return Ok(Vec::new());
        }
        self.buffer.push_str("\n\n");
        self.drain_complete_events()
    }

    fn drain_complete_events(&mut self) -> Result<Vec<ParsedStreamEvent>, LlmError> {
        let mut events = Vec::new();
        while let Some(boundary) = self.buffer.find("\n\n") {
            let block = self.buffer[..boundary].to_string();
            self.buffer = self.buffer[(boundary + 2)..].to_string();
            if let Some(event) = parse_sse_event_block(self.protocol, block.as_str())? {
                events.push(event);
            }
        }
        Ok(events)
    }
}

fn normalize_non_empty(value: String, error_message: &str) -> Result<String, LlmError> {
    let trimmed = value.trim().to_string();
    if trimmed.is_empty() {
        return Err(LlmError::InvalidConfig(error_message.to_string()));
    }
    Ok(trimmed)
}

fn build_request_body(
    request: &LlmTextRequest,
    fallback_model: &str,
    stream: bool,
) -> LlmRequestBody {
    match request.protocol {
        LlmTextProtocol::ChatCompletions => {
            LlmRequestBody::ChatCompletions(ChatCompletionsRequestBody {
                model: request.resolved_model(fallback_model).to_string(),
                messages: request.messages.clone(),
                stream,
                max_tokens: request.max_tokens,
                web_search_options: request
                    .enable_web_search
                    .then_some(ChatCompletionsWebSearchOptions {}),
            })
        }
        LlmTextProtocol::Responses => LlmRequestBody::Responses(ResponsesRequestBody {
            model: request.resolved_model(fallback_model).to_string(),
            stream,
            input: map_responses_input_messages(request.messages.as_slice()),
            max_output_tokens: request.max_tokens,
            tools: request.enable_web_search.then(|| {
                vec![ResponsesWebSearchTool {
                    tool_type: "web_search",
                    max_keyword: 3,
                }]
            }),
        }),
    }
}

fn map_responses_input_messages(messages: &[LlmMessage]) -> Vec<ResponsesInputMessage> {
    messages
        .iter()
        .map(|message| ResponsesInputMessage {
            role: match message.role {
                LlmMessageRole::System => "system",
                LlmMessageRole::User => "user",
                LlmMessageRole::Assistant => "assistant",
            },
            content: vec![ResponsesInputContentPart {
                part_type: "input_text",
                text: message.content.clone(),
            }],
        })
        .collect()
}
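// A minimal sketch of the wire shape build_request_body produces for the Responses
// protocol; the model name and prompts are placeholders. Serializing the untagged
// LlmRequestBody enum yields the raw JSON sent upstream.
#[allow(dead_code)]
fn example_responses_wire_shape() -> Result<String, serde_json::Error> {
    let request = LlmTextRequest::single_turn("系统", "用户")
        .with_responses_api()
        .with_web_search(true);
    let body = build_request_body(&request, "demo-model", false);
    // Expected shape, roughly:
    // {"model":"demo-model","stream":false,
    //  "input":[{"role":"system","content":[{"type":"input_text","text":"系统"}]}, ...],
    //  "tools":[{"type":"web_search","max_keyword":3}]}
    serde_json::to_string_pretty(&body)
}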
fn log_llm_raw_failure(
    config: &LlmConfig,
    request: &LlmTextRequest,
    stream: bool,
    attempt: u32,
    failure_stage: &str,
    raw_output: &str,
) {
    if let Err(error) =
        write_llm_raw_failure(config, request, stream, attempt, failure_stage, raw_output)
    {
        warn!(
            "LLM 失败原文日志落盘失败,主错误流程继续执行: failure_stage={}, error={}",
            failure_stage, error
        );
    }
}

fn write_llm_raw_failure(
    config: &LlmConfig,
    request: &LlmTextRequest,
    stream: bool,
    attempt: u32,
    failure_stage: &str,
    raw_output: &str,
) -> Result<(), String> {
    let log_dir = env::var("LLM_RAW_LOG_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|_| PathBuf::from(DEFAULT_LLM_RAW_LOG_DIR));
    fs::create_dir_all(&log_dir).map_err(|error| format!("创建日志目录失败:{error}"))?;
    let prefix = build_llm_raw_log_prefix(failure_stage);
    let model = request.resolved_model(config.model());
    let input_log = LlmRawFailureInputLog {
        provider: config.provider().as_str(),
        protocol: request.protocol.as_str(),
        model,
        stream,
        attempt,
        max_tokens: request.max_tokens,
        messages: request.messages.as_slice(),
    };
    let input_text = serde_json::to_string_pretty(&input_log)
        .map_err(|error| format!("序列化模型输入日志失败:{error}"))?;
    fs::write(log_dir.join(format!("{prefix}.input.json")), input_text)
        .map_err(|error| format!("写入模型输入日志失败:{error}"))?;
    fs::write(log_dir.join(format!("{prefix}.output.txt")), raw_output)
        .map_err(|error| format!("写入模型输出日志失败:{error}"))?;
    Ok(())
}

fn build_llm_raw_log_prefix(failure_stage: &str) -> String {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|duration| duration.as_millis())
        .unwrap_or_default();
    let sequence = LLM_RAW_LOG_SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let safe_stage = sanitize_log_file_segment(failure_stage);
    format!("{millis}-{}-{sequence:06}-{safe_stage}", std::process::id())
}

fn sanitize_log_file_segment(value: &str) -> String {
    let sanitized = value
        .chars()
        .map(|character| {
            if character.is_ascii_alphanumeric() || character == '-' || character == '_' {
                character
            } else {
                '_'
            }
        })
        .collect::<String>();
    if sanitized.is_empty() {
        "unknown".to_string()
    } else {
        sanitized
    }
}

fn parse_text_response(
    protocol: LlmTextProtocol,
    provider: LlmProvider,
    fallback_model: &str,
    raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
    match protocol {
        LlmTextProtocol::ChatCompletions => {
            parse_chat_completions_response(provider, fallback_model, raw_text)
        }
        LlmTextProtocol::Responses => parse_responses_response(provider, fallback_model, raw_text),
    }
}

fn parse_chat_completions_response(
    provider: LlmProvider,
    fallback_model: &str,
    raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
    let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(raw_text)
        .map_err(|error| LlmError::Deserialize(format!("解析 LLM JSON 响应失败:{error}")))?;
    let first_choice = parsed
        .choices
        .first()
        .ok_or_else(|| LlmError::Deserialize("LLM 响应缺少 choices[0]".to_string()))?;
    let content = extract_message_text(first_choice)
        .ok_or(LlmError::EmptyResponse)?
        .trim()
        .to_string();
    if content.is_empty() {
        return Err(LlmError::EmptyResponse);
    }
    Ok(LlmTextResponse {
        provider,
        model: parsed.model.unwrap_or_else(|| fallback_model.to_string()),
        content,
        finish_reason: first_choice.finish_reason.clone(),
        response_id: parsed.id,
        usage: parsed.usage,
    })
}
fn parse_responses_response(
    provider: LlmProvider,
    fallback_model: &str,
    raw_text: &str,
) -> Result<LlmTextResponse, LlmError> {
    let parsed: ResponsesResponseEnvelope = serde_json::from_str(raw_text).map_err(|error| {
        LlmError::Deserialize(format!("解析 LLM Responses JSON 响应失败:{error}"))
    })?;
    let content = extract_responses_text(&parsed)
        .ok_or(LlmError::EmptyResponse)?
        .trim()
        .to_string();
    if content.is_empty() {
        return Err(LlmError::EmptyResponse);
    }
    Ok(LlmTextResponse {
        provider,
        model: parsed.model.unwrap_or_else(|| fallback_model.to_string()),
        content,
        finish_reason: parsed.status,
        response_id: parsed.id,
        usage: parsed.usage.map(|usage| LlmTokenUsage {
            prompt_tokens: usage.input_tokens,
            completion_tokens: usage.output_tokens,
            total_tokens: usage.total_tokens,
        }),
    })
}

fn extract_responses_text(parsed: &ResponsesResponseEnvelope) -> Option<String> {
    parsed
        .output_text
        .as_deref()
        .map(str::to_string)
        .filter(|text| !text.is_empty())
        .or_else(|| {
            let text = parsed
                .output
                .iter()
                .flat_map(|item| item.content.iter())
                .filter_map(|part| part.text.as_deref())
                .collect::<Vec<_>>()
                .join("");
            if text.is_empty() { None } else { Some(text) }
        })
}

fn extract_message_text(choice: &ChatCompletionsChoice) -> Option<String> {
    choice
        .message
        .as_ref()
        .and_then(|message| message.content.as_ref())
        .and_then(extract_content_text)
        .or_else(|| {
            choice
                .delta
                .as_ref()
                .and_then(|message| message.content.as_ref())
                .and_then(extract_content_text)
        })
}

fn extract_content_text(content: &ChatCompletionsContent) -> Option<String> {
    match content {
        ChatCompletionsContent::Text(text) => Some(text.clone()),
        ChatCompletionsContent::Parts(parts) => {
            let text = parts
                .iter()
                .filter_map(|part| part.text.as_deref())
                .collect::<Vec<_>>()
                .join("");
            if text.is_empty() { None } else { Some(text) }
        }
    }
}

fn decode_utf8_stream_chunk(bytes: &[u8]) -> Result<(String, Vec<u8>), LlmError> {
    match std_str::from_utf8(bytes) {
        Ok(text) => Ok((text.to_string(), Vec::new())),
        Err(error) => {
            let valid_up_to = error.valid_up_to();
            let Some(_) = error.error_len() else {
                let decoded = std_str::from_utf8(&bytes[..valid_up_to]).map_err(|inner_error| {
                    LlmError::Deserialize(format!("解析 LLM 流式 UTF-8 响应失败:{inner_error}"))
                })?;
                return Ok((decoded.to_string(), bytes[valid_up_to..].to_vec()));
            };
            Err(LlmError::Deserialize(format!(
                "解析 LLM 流式 UTF-8 响应失败:{error}"
            )))
        }
    }
}

fn parse_sse_event_block(
    protocol: LlmTextProtocol,
    block: &str,
) -> Result<Option<ParsedStreamEvent>, LlmError> {
    let data_lines = block
        .lines()
        .filter_map(|line| line.trim().strip_prefix("data:"))
        .map(str::trim_start)
        .collect::<Vec<_>>();
    if data_lines.is_empty() {
        return Ok(None);
    }
    let data = data_lines.join("\n");
    if data.trim().is_empty() || data.trim() == "[DONE]" {
        return Ok(None);
    }
    if protocol == LlmTextProtocol::Responses {
        return parse_responses_sse_event(data.as_str());
    }
    let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(data.as_str())
        .map_err(|error| LlmError::Deserialize(format!("解析 LLM SSE 事件失败:{error}")))?;
    let first_choice = parsed
        .choices
        .first()
        .ok_or_else(|| LlmError::Deserialize("LLM SSE 响应缺少 choices[0]".to_string()))?;
    Ok(Some(ParsedStreamEvent {
        delta_text: extract_message_text(first_choice),
        finish_reason: first_choice.finish_reason.clone(),
    }))
}
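// A minimal sketch of the block format parse_sse_event_block consumes: a single SSE
// event, already split on the blank-line boundary by the parser. The JSON payload here
// is illustrative, not a captured upstream response.
#[allow(dead_code)]
fn example_parse_chat_sse_block() -> Result<Option<ParsedStreamEvent>, LlmError> {
    parse_sse_event_block(
        LlmTextProtocol::ChatCompletions,
        "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}",
    )
}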
fn parse_responses_sse_event(data: &str) -> Result<Option<ParsedStreamEvent>, LlmError> {
    let parsed: serde_json::Value = serde_json::from_str(data).map_err(|error| {
        LlmError::Deserialize(format!("解析 LLM Responses SSE 事件失败:{error}"))
    })?;
    let event_type = parsed
        .get("type")
        .and_then(serde_json::Value::as_str)
        .unwrap_or_default();
    match event_type {
        "response.output_text.delta" => Ok(Some(ParsedStreamEvent {
            delta_text: parsed
                .get("delta")
                .and_then(serde_json::Value::as_str)
                .map(str::to_string),
            finish_reason: None,
        })),
        "response.completed" => Ok(Some(ParsedStreamEvent {
            delta_text: None,
            finish_reason: Some("completed".to_string()),
        })),
        "response.failed" | "error" => {
            let message = parsed
                .get("error")
                .and_then(|error| error.get("message"))
                .and_then(serde_json::Value::as_str)
                .or_else(|| parsed.get("message").and_then(serde_json::Value::as_str))
                .unwrap_or("LLM Responses SSE 返回失败事件")
                .to_string();
            Err(LlmError::Upstream {
                status_code: 502,
                message,
            })
        }
        _ => Ok(None),
    }
}

fn should_retry_status(status: StatusCode) -> bool {
    status == StatusCode::REQUEST_TIMEOUT
        || status == StatusCode::TOO_MANY_REQUESTS
        || status.is_server_error()
}

fn extract_api_error_message(raw_text: &str, fallback_message: &str) -> String {
    let trimmed = raw_text.trim();
    if trimmed.is_empty() {
        return fallback_message.to_string();
    }
    let parsed = serde_json::from_str::<serde_json::Value>(trimmed);
    if let Ok(value) = parsed {
        if let Some(message) = value
            .get("error")
            .and_then(|error| error.get("message"))
            .and_then(serde_json::Value::as_str)
            .map(str::trim)
            .filter(|message| !message.is_empty())
        {
            return message.to_string();
        }
        if let Some(message) = value
            .get("message")
            .and_then(serde_json::Value::as_str)
            .map(str::trim)
            .filter(|message| !message.is_empty())
        {
            return message.to_string();
        }
    }
    trimmed.to_string()
}

fn map_stream_read_error(error: reqwest::Error, attempts: u32) -> LlmError {
    if error.is_timeout() {
        return LlmError::Timeout { attempts };
    }
    if error.is_connect() {
        return LlmError::Connectivity {
            attempts,
            message: error.to_string(),
        };
    }
    LlmError::Transport(error.to_string())
}
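// A hedged sketch of how a caller such as api-server might map LlmError onto HTTP
// status codes, per the comment on the enum. The exact contract lives outside this
// module; these particular status choices are illustrative assumptions.
#[allow(dead_code)]
fn example_http_status_for(error: &LlmError) -> u16 {
    match error {
        LlmError::InvalidConfig(_) => 500,
        LlmError::InvalidRequest(_) => 400,
        LlmError::Timeout { .. } => 504,
        LlmError::Connectivity { .. }
        | LlmError::Upstream { .. }
        | LlmError::StreamUnavailable
        | LlmError::EmptyResponse
        | LlmError::Transport(_)
        | LlmError::Deserialize(_) => 502,
    }
}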
#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        net::TcpListener,
        thread,
        time::Duration as StdDuration,
    };

    use super::*;

    struct MockResponse {
        status_line: &'static str,
        content_type: &'static str,
        body: String,
        extra_headers: Vec<(&'static str, &'static str)>,
    }

    #[test]
    fn llm_config_rejects_blank_api_key() {
        let error = LlmConfig::new(
            LlmProvider::Ark,
            DEFAULT_ARK_BASE_URL.to_string(),
            " ".to_string(),
            "model-a".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
        .expect_err("blank api key should be rejected");
        assert_eq!(
            error,
            LlmError::InvalidConfig("LLM api_key 不能为空".to_string())
        );
    }

    #[test]
    fn llm_chat_completion_url_normalizes_trailing_slash() {
        let config = LlmConfig::new(
            LlmProvider::OpenAiCompatible,
            "https://example.com/base///".to_string(),
            "secret".to_string(),
            "model-a".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            DEFAULT_MAX_RETRIES,
            DEFAULT_RETRY_BACKOFF_MS,
        )
        .expect("config should be valid");
        assert_eq!(
            config.chat_completions_url(),
            "https://example.com/base/chat/completions"
        );
        assert_eq!(config.responses_url(), "https://example.com/base/responses");
    }

    #[test]
    fn sse_parser_handles_split_chunks_and_done_marker() {
        let mut parser = OpenAiCompatibleSseParser::new(LlmTextProtocol::ChatCompletions);
        let events_a = parser
            .push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"你\"}}]}\r\n\r\n")
            .expect("first chunk should parse");
        let events_b = parser
            .push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"好\"},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n")
            .expect("second chunk should parse");
        assert_eq!(events_a.len(), 1);
        assert_eq!(events_a[0].delta_text.as_deref(), Some("你"));
        assert_eq!(events_b.len(), 1);
        assert_eq!(events_b[0].delta_text.as_deref(), Some("好"));
        assert_eq!(events_b[0].finish_reason.as_deref(), Some("stop"));
    }

    #[test]
    fn responses_sse_parser_only_emits_output_text_delta() {
        let mut parser = OpenAiCompatibleSseParser::new(LlmTextProtocol::Responses);
        let events = parser
            .push_chunk(concat!(
                "data: {\"type\":\"response.created\"}\n\n",
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"你\"}\n\n",
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"好\"}\n\n",
                "data: {\"type\":\"response.completed\"}\n\n",
            ))
            .expect("responses stream should parse");
        assert_eq!(events.len(), 3);
        assert_eq!(events[0].delta_text.as_deref(), Some("你"));
        assert_eq!(events[1].delta_text.as_deref(), Some("好"));
        assert_eq!(events[2].finish_reason.as_deref(), Some("completed"));
    }

    #[test]
    fn decode_utf8_stream_chunk_preserves_incomplete_multibyte_suffix() {
        let full_bytes = "你好".as_bytes();
        let first_result = decode_utf8_stream_chunk(&full_bytes[..2])
            .expect("incomplete utf-8 chunk should be buffered");
        assert_eq!(first_result.0, "");
        assert_eq!(first_result.1, full_bytes[..2].to_vec());
        let mut combined = first_result.1;
        combined.extend_from_slice(&full_bytes[2..]);
        let second_result = decode_utf8_stream_chunk(combined.as_slice())
            .expect("completed utf-8 bytes should decode");
        assert_eq!(second_result.0, "你好");
        assert!(second_result.1.is_empty());
    }

    #[tokio::test]
    async fn request_text_parses_non_stream_response() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "application/json; charset=utf-8",
            body: r#"{"id":"resp_01","model":"ark-test-model","choices":[{"message":{"content":"测试成功"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":6,"total_tokens":16}}"#.to_string(),
            extra_headers: Vec::new(),
        }]);
        let client = build_test_client(server_url, 0);
        let response = client
            .request_single_message_text("系统", "用户")
            .await
            .expect("request_text should succeed");
        assert_eq!(response.provider, LlmProvider::Ark);
        assert_eq!(response.model, "ark-test-model");
        assert_eq!(response.content, "测试成功");
        assert_eq!(response.finish_reason.as_deref(), Some("stop"));
        assert_eq!(response.response_id.as_deref(), Some("resp_01"));
        assert_eq!(
            response.usage,
            Some(LlmTokenUsage {
                prompt_tokens: 10,
                completion_tokens: 6,
                total_tokens: 16,
            })
        );
    }

    #[tokio::test]
    async fn request_text_retries_after_upstream_500() {
        let server_url = spawn_mock_server(vec![
            MockResponse {
                status_line: "500 Internal Server Error",
                content_type: "application/json; charset=utf-8",
                body: r#"{"error":{"message":"temporary upstream failure"}}"#.to_string(),
                extra_headers: Vec::new(),
            },
            MockResponse {
                status_line: "200 OK",
                content_type: "application/json; charset=utf-8",
                body: r#"{"id":"resp_retry","choices":[{"message":{"content":"第二次成功"},"finish_reason":"stop"}]}"#.to_string(),
                extra_headers: Vec::new(),
            },
        ]);
        let client = build_test_client(server_url, 1);
        let response = client
            .request_single_message_text("系统", "用户")
            .await
            .expect("second attempt should succeed");
        assert_eq!(response.content, "第二次成功");
        assert_eq!(response.response_id.as_deref(), Some("resp_retry"));
    }

    #[tokio::test]
    async fn request_text_sends_web_search_options_when_enabled() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        let server_handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("request should connect");
            let request_text = read_request(&mut stream);
            write_response(
                &mut stream,
                MockResponse {
                    status_line: "200 OK",
                    content_type: "application/json; charset=utf-8",
                    body: r#"{"id":"resp_search","model":"test-model","choices":[{"message":{"content":"搜索成功"},"finish_reason":"stop"}]}"#.to_string(),
                    extra_headers: Vec::new(),
                },
            );
            request_text
        });
        let client = build_test_client(format!("http://{address}"), 0);
        let response = client
            .request_text(
                LlmTextRequest::single_turn("系统", "用户")
                    .with_web_search(true)
                    .with_max_tokens(128),
            )
            .await
            .expect("request_text should succeed");
        let request_text = server_handle.join().expect("server thread should join");
        let request_body = request_text
            .split("\r\n\r\n")
            .nth(1)
            .expect("request body should exist");
        let request_json: serde_json::Value =
            serde_json::from_str(request_body).expect("request body should be json");
        assert_eq!(response.content, "搜索成功");
        assert_eq!(request_json["web_search_options"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn request_text_sends_responses_body_with_web_search_tool() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        let server_handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("request should connect");
            let request_text = read_request(&mut stream);
            write_response(
                &mut stream,
                MockResponse {
                    status_line: "200 OK",
                    content_type: "application/json; charset=utf-8",
                    body: r#"{"id":"resp_responses","model":"deepseek-v3-2-251201","output_text":"Responses 成功","status":"completed","usage":{"input_tokens":9,"output_tokens":4,"total_tokens":13}}"#.to_string(),
                    extra_headers: Vec::new(),
                },
            );
            request_text
        });
        let client = build_test_client(format!("http://{address}"), 0);
        let response = client
            .request_text(
                LlmTextRequest::single_turn("系统", "用户")
                    .with_model("deepseek-v3-2-251201")
                    .with_responses_api()
                    .with_web_search(true)
                    .with_max_tokens(128),
            )
            .await
            .expect("responses request_text should succeed");
        let request_text = server_handle.join().expect("server thread should join");
        let request_line = request_text.lines().next().unwrap_or_default();
        let request_body = request_text
            .split("\r\n\r\n")
            .nth(1)
            .expect("request body should exist");
        let request_json: serde_json::Value =
            serde_json::from_str(request_body).expect("request body should be json");
        assert!(request_line.contains("POST /responses HTTP/1.1"));
        assert_eq!(response.content, "Responses 成功");
        assert_eq!(response.model, "deepseek-v3-2-251201");
        assert_eq!(
            response.usage,
            Some(LlmTokenUsage {
                prompt_tokens: 9,
                completion_tokens: 4,
                total_tokens: 13,
            })
        );
        assert_eq!(
            request_json["model"],
            serde_json::json!("deepseek-v3-2-251201")
        );
        assert_eq!(request_json["stream"], serde_json::json!(false));
        assert_eq!(
            request_json["tools"],
            serde_json::json!([{ "type": "web_search", "max_keyword": 3 }])
        );
        assert_eq!(
            request_json["input"][0]["content"][0],
            serde_json::json!({ "type": "input_text", "text": "系统" })
        );
    }

    #[tokio::test]
    async fn stream_text_accumulates_sse_response() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "text/event-stream; charset=utf-8",
            body: concat!(
                "data: {\"choices\":[{\"delta\":{\"content\":\"你\"}}]}\n\n",
                "data: {\"choices\":[{\"delta\":{\"content\":\"好\"}}]}\n\n",
                "data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n",
                "data: [DONE]\n\n"
            )
            .to_string(),
            extra_headers: vec![("x-request-id", "req_stream_01")],
        }]);
        let client = build_test_client(server_url, 0);
        let mut updates = Vec::new();
        let response = client
            .stream_single_message_text("系统", "用户", |delta| {
                updates.push(delta.accumulated_text.clone());
            })
            .await
            .expect("stream_text should succeed");
        assert_eq!(updates, vec!["你".to_string(), "你好".to_string()]);
        assert_eq!(response.content, "你好");
        assert_eq!(response.finish_reason.as_deref(), Some("stop"));
        assert_eq!(response.response_id.as_deref(), Some("req_stream_01"));
    }
    #[tokio::test]
    async fn stream_text_accumulates_responses_sse_response() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "text/event-stream; charset=utf-8",
            body: concat!(
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"你\"}\n\n",
                "data: {\"type\":\"response.output_text.delta\",\"delta\":\"好\"}\n\n",
                "data: {\"type\":\"response.completed\"}\n\n"
            )
            .to_string(),
            extra_headers: vec![("x-request-id", "req_responses_stream_01")],
        }]);
        let client = build_test_client(server_url, 0);
        let mut updates = Vec::new();
        let response = client
            .stream_text(
                LlmTextRequest::single_turn("系统", "用户").with_responses_api(),
                |delta| {
                    updates.push(delta.accumulated_text.clone());
                },
            )
            .await
            .expect("responses stream_text should succeed");
        assert_eq!(updates, vec!["你".to_string(), "你好".to_string()]);
        assert_eq!(response.content, "你好");
        assert_eq!(response.finish_reason.as_deref(), Some("completed"));
        assert_eq!(
            response.response_id.as_deref(),
            Some("req_responses_stream_01")
        );
    }

    #[tokio::test]
    async fn request_text_writes_raw_failure_logs_after_parse_error() {
        let log_dir = std::env::temp_dir().join(format!(
            "platform-llm-raw-log-test-{}",
            build_llm_raw_log_prefix("parse_error")
        ));
        unsafe {
            std::env::set_var("LLM_RAW_LOG_DIR", &log_dir);
        }
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "application/json; charset=utf-8",
            body: "不是合法 JSON".to_string(),
            extra_headers: Vec::new(),
        }]);
        let client = build_test_client(server_url, 0);
        let error = client
            .request_single_message_text("系统原文", "用户原文")
            .await
            .expect_err("invalid json should fail");
        assert!(matches!(error, LlmError::Deserialize(_)));
        let mut input_logs = Vec::new();
        let mut output_logs = Vec::new();
        for entry in fs::read_dir(&log_dir).expect("log dir should exist") {
            let path = entry.expect("log entry should be readable").path();
            let file_name = path
                .file_name()
                .and_then(|name| name.to_str())
                .unwrap_or_default()
                .to_string();
            if file_name.ends_with(".input.json") {
                input_logs.push(path);
            } else if file_name.ends_with(".output.txt") {
                output_logs.push(path);
            }
        }
        assert_eq!(input_logs.len(), 1);
        assert_eq!(output_logs.len(), 1);
        let input_text = fs::read_to_string(&input_logs[0]).expect("input log should be readable");
        let output_text =
            fs::read_to_string(&output_logs[0]).expect("output log should be readable");
        assert!(input_text.contains("系统原文"));
        assert!(input_text.contains("用户原文"));
        assert!(!input_text.contains("test-key"));
        assert_eq!(output_text, "不是合法 JSON");
        unsafe {
            std::env::remove_var("LLM_RAW_LOG_DIR");
        }
        fs::remove_dir_all(log_dir).expect("log dir should be removed");
    }

    fn build_test_client(base_url: String, max_retries: u32) -> LlmClient {
        let config = LlmConfig::new(
            LlmProvider::Ark,
            base_url,
            "test-key".to_string(),
            "test-model".to_string(),
            DEFAULT_REQUEST_TIMEOUT_MS,
            max_retries,
            1,
        )
        .expect("config should be valid");
        LlmClient::new(config).expect("client should be created")
    }

    fn spawn_mock_server(responses: Vec<MockResponse>) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        thread::spawn(move || {
            for response in responses {
                let (mut stream, _) = listener.accept().expect("request should connect");
                read_request(&mut stream);
                write_response(&mut stream, response);
            }
        });
        format!("http://{address}")
    }
    fn read_request(stream: &mut std::net::TcpStream) -> String {
        stream
            .set_read_timeout(Some(StdDuration::from_secs(1)))
            .expect("read timeout should be set");
        let mut buffer = Vec::new();
        let mut chunk = [0_u8; 1024];
        let mut expected_total = None;
        loop {
            match stream.read(&mut chunk) {
                Ok(0) => break,
                Ok(bytes_read) => {
                    buffer.extend_from_slice(&chunk[..bytes_read]);
                    if expected_total.is_none()
                        && let Some(header_end) = find_header_end(&buffer)
                    {
                        let content_length =
                            read_content_length(&buffer[..header_end]).unwrap_or(0);
                        expected_total = Some(header_end + content_length);
                    }
                    if let Some(total_bytes) = expected_total
                        && buffer.len() >= total_bytes
                    {
                        break;
                    }
                }
                Err(error)
                    if error.kind() == std::io::ErrorKind::WouldBlock
                        || error.kind() == std::io::ErrorKind::TimedOut =>
                {
                    break;
                }
                Err(error) => panic!("mock server failed to read request: {error}"),
            }
        }
        String::from_utf8_lossy(buffer.as_slice()).to_string()
    }

    fn write_response(stream: &mut std::net::TcpStream, response: MockResponse) {
        let body = response.body;
        let mut raw_response = format!(
            "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n",
            response.status_line,
            response.content_type,
            body.len()
        );
        for (name, value) in response.extra_headers {
            raw_response.push_str(format!("{name}: {value}\r\n").as_str());
        }
        raw_response.push_str("\r\n");
        raw_response.push_str(body.as_str());
        stream
            .write_all(raw_response.as_bytes())
            .expect("mock response should be written");
        stream.flush().expect("mock response should flush");
    }

    fn find_header_end(buffer: &[u8]) -> Option<usize> {
        buffer
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|index| index + 4)
    }

    fn read_content_length(headers: &[u8]) -> Option<usize> {
        let text = String::from_utf8_lossy(headers);
        text.lines().find_map(|line| {
            let (name, value) = line.split_once(':')?;
            if name.eq_ignore_ascii_case("content-length") {
                return value.trim().parse::<usize>().ok();
            }
            None
        })
    }
}