use std::{error::Error, fmt, str as std_str, time::Duration}; use log::{debug, warn}; use reqwest::{Client, StatusCode}; use serde::{Deserialize, Serialize}; use tokio::time::sleep; pub const DEFAULT_ARK_BASE_URL: &str = "https://ark.cn-beijing.volces.com/api/v3"; pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 30_000; pub const DEFAULT_MAX_RETRIES: u32 = 1; pub const DEFAULT_RETRY_BACKOFF_MS: u64 = 500; pub const CHAT_COMPLETIONS_PATH: &str = "/chat/completions"; // 冻结平台来源,避免上层继续散落 provider 字符串。 #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum LlmProvider { Ark, DashScope, OpenAiCompatible, } // 统一收口文本模型网关配置,避免 api-server 和业务模块各自重复解析环境变量。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct LlmConfig { provider: LlmProvider, base_url: String, api_key: String, model: String, request_timeout_ms: u64, max_retries: u32, retry_backoff_ms: u64, } // 首版只冻结当前项目已稳定使用的 system/user/assistant 三种消息角色。 #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum LlmMessageRole { System, User, Assistant, } // 单条消息保持 OpenAI 兼容格式,供统一请求体直接序列化。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct LlmMessage { pub role: LlmMessageRole, pub content: String, } // 文本补全请求冻结为“消息列表 + 可选模型覆盖 + 可选 max_tokens”最小闭环。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct LlmTextRequest { pub model: Option, pub messages: Vec, pub max_tokens: Option, } // 上层在流式消费时拿到的是“累计文本 + 当前增量”,避免每层重新自己拼接。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct LlmStreamDelta { pub accumulated_text: String, pub delta_text: String, pub finish_reason: Option, } // 用于保留 token 计数,后续模块可以决定是否写入审计或成本统计。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct LlmTokenUsage { pub prompt_tokens: u64, pub completion_tokens: u64, pub total_tokens: u64, } // 统一文本响应,避免业务层再去解析 choices/message/content。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct LlmTextResponse { pub provider: LlmProvider, pub model: String, pub content: String, pub finish_reason: Option, pub response_id: Option, pub usage: Option, } // 将上游错误归一到稳定的领域枚举,后续 api-server 可以直接映射成 HTTP error contract。 #[derive(Debug, PartialEq, Eq)] pub enum LlmError { InvalidConfig(String), InvalidRequest(String), Timeout { attempts: u32 }, Connectivity { attempts: u32, message: String }, Upstream { status_code: u16, message: String }, StreamUnavailable, EmptyResponse, Transport(String), Deserialize(String), } // 统一 OpenAI 兼容文本网关 client。 #[derive(Clone, Debug)] pub struct LlmClient { config: LlmConfig, http_client: Client, } #[derive(Serialize)] struct ChatCompletionsRequestBody<'a> { model: &'a str, messages: &'a [LlmMessage], stream: bool, #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, } #[derive(Deserialize)] struct ChatCompletionsResponseEnvelope { id: Option, model: Option, choices: Vec, usage: Option, } #[derive(Deserialize)] struct ChatCompletionsChoice { #[serde(default)] message: Option, #[serde(default)] delta: Option, #[serde(default)] finish_reason: Option, } #[derive(Deserialize)] struct ChatCompletionsMessage { #[serde(default)] content: Option, } #[derive(Deserialize)] #[serde(untagged)] enum ChatCompletionsContent { Text(String), Parts(Vec), } #[derive(Deserialize)] struct ChatCompletionsContentPart { #[serde(rename = "type")] #[allow(dead_code)] part_type: Option, #[serde(default)] text: Option, } #[derive(Default)] struct OpenAiCompatibleSseParser { buffer: String, } #[derive(Debug)] struct ParsedStreamEvent { delta_text: Option, finish_reason: Option, } impl LlmProvider { pub fn as_str(&self) -> &'static str { match self { Self::Ark => "ark", Self::DashScope => "dash_scope", Self::OpenAiCompatible => "openai_compatible", } } } impl LlmConfig { #[allow(clippy::too_many_arguments)] pub fn new( provider: LlmProvider, base_url: String, api_key: String, model: String, request_timeout_ms: u64, max_retries: u32, retry_backoff_ms: u64, ) -> Result { let base_url = normalize_non_empty(base_url, "LLM base_url 不能为空")?; let api_key = normalize_non_empty(api_key, "LLM api_key 不能为空")?; let model = normalize_non_empty(model, "LLM model 不能为空")?; if request_timeout_ms == 0 { return Err(LlmError::InvalidConfig( "LLM request_timeout_ms 必须大于 0".to_string(), )); } Ok(Self { provider, base_url, api_key, model, request_timeout_ms, max_retries, retry_backoff_ms, }) } pub fn ark_default(api_key: String, model: String) -> Result { Self::new( LlmProvider::Ark, DEFAULT_ARK_BASE_URL.to_string(), api_key, model, DEFAULT_REQUEST_TIMEOUT_MS, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_BACKOFF_MS, ) } pub fn provider(&self) -> LlmProvider { self.provider } pub fn base_url(&self) -> &str { &self.base_url } pub fn api_key(&self) -> &str { &self.api_key } pub fn model(&self) -> &str { &self.model } pub fn request_timeout_ms(&self) -> u64 { self.request_timeout_ms } pub fn max_retries(&self) -> u32 { self.max_retries } pub fn retry_backoff_ms(&self) -> u64 { self.retry_backoff_ms } pub fn chat_completions_url(&self) -> String { format!( "{}/{}", self.base_url.trim_end_matches('/'), CHAT_COMPLETIONS_PATH.trim_start_matches('/') ) } } impl LlmMessage { pub fn new(role: LlmMessageRole, content: impl Into) -> Self { Self { role, content: content.into(), } } pub fn system(content: impl Into) -> Self { Self::new(LlmMessageRole::System, content) } pub fn user(content: impl Into) -> Self { Self::new(LlmMessageRole::User, content) } pub fn assistant(content: impl Into) -> Self { Self::new(LlmMessageRole::Assistant, content) } } impl LlmTextRequest { pub fn new(messages: Vec) -> Self { Self { model: None, messages, max_tokens: None, } } pub fn single_turn(system_prompt: impl Into, user_prompt: impl Into) -> Self { Self::new(vec![ LlmMessage::system(system_prompt), LlmMessage::user(user_prompt), ]) } pub fn with_model(mut self, model: impl Into) -> Self { self.model = Some(model.into()); self } pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { self.max_tokens = Some(max_tokens); self } fn validate(&self) -> Result<(), LlmError> { if self.messages.is_empty() { return Err(LlmError::InvalidRequest( "LLM messages 不能为空".to_string(), )); } for message in &self.messages { if message.content.trim().is_empty() { return Err(LlmError::InvalidRequest( "LLM message.content 不能为空".to_string(), )); } } if let Some(model) = &self.model && model.trim().is_empty() { return Err(LlmError::InvalidRequest( "LLM request.model 不能为空字符串".to_string(), )); } Ok(()) } fn resolved_model<'a>(&'a self, fallback_model: &'a str) -> &'a str { self.model .as_deref() .map(str::trim) .filter(|value| !value.is_empty()) .unwrap_or(fallback_model) } } impl fmt::Display for LlmError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidConfig(message) | Self::InvalidRequest(message) | Self::Transport(message) | Self::Deserialize(message) => write!(f, "{message}"), Self::Timeout { attempts } => { write!(f, "LLM 请求超时,累计尝试 {attempts} 次") } Self::Connectivity { attempts, message } => { write!(f, "LLM 连接失败,累计尝试 {attempts} 次:{message}") } Self::Upstream { status_code, message, } => write!(f, "LLM 上游返回 {status_code}:{message}"), Self::StreamUnavailable => write!(f, "LLM 流式响应体不可用"), Self::EmptyResponse => write!(f, "LLM 返回内容为空"), } } } impl Error for LlmError {} impl LlmClient { pub fn new(config: LlmConfig) -> Result { let http_client = Client::builder().build().map_err(|error| { LlmError::InvalidConfig(format!("构建 reqwest client 失败:{error}")) })?; Ok(Self { config, http_client, }) } pub fn config(&self) -> &LlmConfig { &self.config } pub async fn request_text(&self, request: LlmTextRequest) -> Result { request.validate()?; let resolved_model = request.resolved_model(self.config.model()).to_string(); let response = self.execute_request(&request, false).await?; let raw_text = response .text() .await .map_err(|error| map_stream_read_error(error, 1))?; parse_chat_completions_response(self.config.provider(), &resolved_model, raw_text.as_str()) } pub async fn request_single_message_text( &self, system_prompt: impl Into, user_prompt: impl Into, ) -> Result { self.request_text(LlmTextRequest::single_turn(system_prompt, user_prompt)) .await } pub async fn stream_text( &self, request: LlmTextRequest, mut on_delta: F, ) -> Result where F: FnMut(&LlmStreamDelta), { request.validate()?; let resolved_model = request.resolved_model(self.config.model()).to_string(); let mut response = self.execute_request(&request, true).await?; let response_id = response .headers() .get("x-request-id") .and_then(|value| value.to_str().ok()) .map(str::to_string); let mut parser = OpenAiCompatibleSseParser::default(); let mut accumulated_text = String::new(); let mut finish_reason = None; let mut undecoded_chunk_bytes = Vec::new(); loop { let next_chunk = response .chunk() .await .map_err(|error| map_stream_read_error(error, 1))?; let Some(chunk) = next_chunk else { break; }; undecoded_chunk_bytes.extend_from_slice(chunk.as_ref()); let (chunk_text, remaining_bytes) = decode_utf8_stream_chunk(undecoded_chunk_bytes.as_slice())?; undecoded_chunk_bytes = remaining_bytes; if chunk_text.is_empty() { continue; } for event in parser.push_chunk(chunk_text.as_ref())? { if let Some(delta_text) = event.delta_text && !delta_text.is_empty() { accumulated_text.push_str(delta_text.as_str()); let update = LlmStreamDelta { accumulated_text: accumulated_text.clone(), delta_text, finish_reason: event.finish_reason.clone(), }; on_delta(&update); } if event.finish_reason.is_some() { finish_reason = event.finish_reason; } } } if !undecoded_chunk_bytes.is_empty() { let trailing_text = std_str::from_utf8(undecoded_chunk_bytes.as_slice()) .map_err(|error| { LlmError::Deserialize(format!( "解析 LLM 流式 UTF-8 响应失败:{error}" )) })?; if !trailing_text.is_empty() { for event in parser.push_chunk(trailing_text)? { if let Some(delta_text) = event.delta_text && !delta_text.is_empty() { accumulated_text.push_str(delta_text.as_str()); let update = LlmStreamDelta { accumulated_text: accumulated_text.clone(), delta_text, finish_reason: event.finish_reason.clone(), }; on_delta(&update); } if event.finish_reason.is_some() { finish_reason = event.finish_reason; } } } } for event in parser.finish()? { if let Some(delta_text) = event.delta_text && !delta_text.is_empty() { accumulated_text.push_str(delta_text.as_str()); let update = LlmStreamDelta { accumulated_text: accumulated_text.clone(), delta_text, finish_reason: event.finish_reason.clone(), }; on_delta(&update); } if event.finish_reason.is_some() { finish_reason = event.finish_reason; } } let content = accumulated_text.trim().to_string(); if content.is_empty() { return Err(LlmError::EmptyResponse); } Ok(LlmTextResponse { provider: self.config.provider(), model: resolved_model, content, finish_reason, response_id, usage: None, }) } pub async fn stream_single_message_text( &self, system_prompt: impl Into, user_prompt: impl Into, on_delta: F, ) -> Result where F: FnMut(&LlmStreamDelta), { self.stream_text( LlmTextRequest::single_turn(system_prompt, user_prompt), on_delta, ) .await } async fn execute_request( &self, request: &LlmTextRequest, stream: bool, ) -> Result { let request_body = ChatCompletionsRequestBody { model: request.resolved_model(self.config.model()), messages: request.messages.as_slice(), stream, max_tokens: request.max_tokens, }; let max_attempts = self.config.max_retries().saturating_add(1); for attempt in 1..=max_attempts { debug!( "platform-llm request started: provider={}, stream={}, attempt={}, model={}", self.config.provider().as_str(), stream, attempt, request_body.model ); let send_result = self .http_client .post(self.config.chat_completions_url()) .bearer_auth(self.config.api_key()) .json(&request_body) .timeout(Duration::from_millis(self.config.request_timeout_ms())) .send() .await; match send_result { Ok(response) if response.status().is_success() => { debug!( "platform-llm request succeeded: provider={}, stream={}, attempt={}, status={}", self.config.provider().as_str(), stream, attempt, response.status().as_u16() ); return Ok(response); } Ok(response) => { let status = response.status(); let raw_text = response.text().await.unwrap_or_default(); let message = extract_api_error_message(&raw_text, "LLM 上游请求失败"); if should_retry_status(status) && attempt < max_attempts { warn!( "platform-llm request retrying after upstream status: provider={}, attempt={}, status={}, message={}", self.config.provider().as_str(), attempt, status.as_u16(), message ); self.sleep_before_retry(attempt).await; continue; } return Err(LlmError::Upstream { status_code: status.as_u16(), message, }); } Err(error) if error.is_timeout() => { if attempt < max_attempts { warn!( "platform-llm request retrying after timeout: provider={}, attempt={}", self.config.provider().as_str(), attempt ); self.sleep_before_retry(attempt).await; continue; } return Err(LlmError::Timeout { attempts: attempt }); } Err(error) if error.is_connect() => { let message = error.to_string(); if attempt < max_attempts { warn!( "platform-llm request retrying after connectivity failure: provider={}, attempt={}, error={}", self.config.provider().as_str(), attempt, message ); self.sleep_before_retry(attempt).await; continue; } return Err(LlmError::Connectivity { attempts: attempt, message, }); } Err(error) => { return Err(LlmError::Transport(error.to_string())); } } } Err(LlmError::Transport( "LLM 请求在重试循环后仍未返回结果".to_string(), )) } async fn sleep_before_retry(&self, attempt: u32) { let backoff_ms = self .config .retry_backoff_ms() .saturating_mul(u64::from(attempt)); if backoff_ms > 0 { sleep(Duration::from_millis(backoff_ms)).await; } } } impl OpenAiCompatibleSseParser { fn push_chunk(&mut self, chunk: &str) -> Result, LlmError> { self.buffer.push_str(chunk); self.buffer = self.buffer.replace("\r\n", "\n"); self.drain_complete_events() } fn finish(&mut self) -> Result, LlmError> { if self.buffer.trim().is_empty() { return Ok(Vec::new()); } self.buffer.push_str("\n\n"); self.drain_complete_events() } fn drain_complete_events(&mut self) -> Result, LlmError> { let mut events = Vec::new(); while let Some(boundary) = self.buffer.find("\n\n") { let block = self.buffer[..boundary].to_string(); self.buffer = self.buffer[(boundary + 2)..].to_string(); if let Some(event) = parse_sse_event_block(block.as_str())? { events.push(event); } } Ok(events) } } fn normalize_non_empty(value: String, error_message: &str) -> Result { let trimmed = value.trim().to_string(); if trimmed.is_empty() { return Err(LlmError::InvalidConfig(error_message.to_string())); } Ok(trimmed) } fn parse_chat_completions_response( provider: LlmProvider, fallback_model: &str, raw_text: &str, ) -> Result { let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(raw_text) .map_err(|error| LlmError::Deserialize(format!("解析 LLM JSON 响应失败:{error}")))?; let first_choice = parsed .choices .first() .ok_or_else(|| LlmError::Deserialize("LLM 响应缺少 choices[0]".to_string()))?; let content = extract_message_text(first_choice) .ok_or(LlmError::EmptyResponse)? .trim() .to_string(); if content.is_empty() { return Err(LlmError::EmptyResponse); } Ok(LlmTextResponse { provider, model: parsed.model.unwrap_or_else(|| fallback_model.to_string()), content, finish_reason: first_choice.finish_reason.clone(), response_id: parsed.id, usage: parsed.usage, }) } fn extract_message_text(choice: &ChatCompletionsChoice) -> Option { choice .message .as_ref() .and_then(|message| message.content.as_ref()) .and_then(extract_content_text) .or_else(|| { choice .delta .as_ref() .and_then(|message| message.content.as_ref()) .and_then(extract_content_text) }) } fn extract_content_text(content: &ChatCompletionsContent) -> Option { match content { ChatCompletionsContent::Text(text) => Some(text.clone()), ChatCompletionsContent::Parts(parts) => { let text = parts .iter() .filter_map(|part| part.text.as_deref()) .collect::>() .join(""); if text.is_empty() { None } else { Some(text) } } } } fn decode_utf8_stream_chunk(bytes: &[u8]) -> Result<(String, Vec), LlmError> { match std_str::from_utf8(bytes) { Ok(text) => Ok((text.to_string(), Vec::new())), Err(error) => { let valid_up_to = error.valid_up_to(); let Some(_) = error.error_len() else { let decoded = std_str::from_utf8(&bytes[..valid_up_to]).map_err(|inner_error| { LlmError::Deserialize(format!( "解析 LLM 流式 UTF-8 响应失败:{inner_error}" )) })?; return Ok((decoded.to_string(), bytes[valid_up_to..].to_vec())); }; Err(LlmError::Deserialize(format!( "解析 LLM 流式 UTF-8 响应失败:{error}" ))) } } } fn parse_sse_event_block(block: &str) -> Result, LlmError> { let data_lines = block .lines() .filter_map(|line| line.trim().strip_prefix("data:")) .map(str::trim_start) .collect::>(); if data_lines.is_empty() { return Ok(None); } let data = data_lines.join("\n"); if data.trim().is_empty() || data.trim() == "[DONE]" { return Ok(None); } let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(data.as_str()) .map_err(|error| LlmError::Deserialize(format!("解析 LLM SSE 事件失败:{error}")))?; let first_choice = parsed .choices .first() .ok_or_else(|| LlmError::Deserialize("LLM SSE 响应缺少 choices[0]".to_string()))?; Ok(Some(ParsedStreamEvent { delta_text: extract_message_text(first_choice), finish_reason: first_choice.finish_reason.clone(), })) } fn should_retry_status(status: StatusCode) -> bool { status == StatusCode::REQUEST_TIMEOUT || status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error() } fn extract_api_error_message(raw_text: &str, fallback_message: &str) -> String { let trimmed = raw_text.trim(); if trimmed.is_empty() { return fallback_message.to_string(); } let parsed = serde_json::from_str::(trimmed); if let Ok(value) = parsed { if let Some(message) = value .get("error") .and_then(|error| error.get("message")) .and_then(serde_json::Value::as_str) .map(str::trim) .filter(|message| !message.is_empty()) { return message.to_string(); } if let Some(message) = value .get("message") .and_then(serde_json::Value::as_str) .map(str::trim) .filter(|message| !message.is_empty()) { return message.to_string(); } } trimmed.to_string() } fn map_stream_read_error(error: reqwest::Error, attempts: u32) -> LlmError { if error.is_timeout() { return LlmError::Timeout { attempts }; } if error.is_connect() { return LlmError::Connectivity { attempts, message: error.to_string(), }; } LlmError::Transport(error.to_string()) } #[cfg(test)] mod tests { use std::{ io::{Read, Write}, net::TcpListener, thread, time::Duration as StdDuration, }; use super::*; struct MockResponse { status_line: &'static str, content_type: &'static str, body: String, extra_headers: Vec<(&'static str, &'static str)>, } #[test] fn llm_config_rejects_blank_api_key() { let error = LlmConfig::new( LlmProvider::Ark, DEFAULT_ARK_BASE_URL.to_string(), " ".to_string(), "model-a".to_string(), DEFAULT_REQUEST_TIMEOUT_MS, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_BACKOFF_MS, ) .expect_err("blank api key should be rejected"); assert_eq!( error, LlmError::InvalidConfig("LLM api_key 不能为空".to_string()) ); } #[test] fn llm_chat_completion_url_normalizes_trailing_slash() { let config = LlmConfig::new( LlmProvider::OpenAiCompatible, "https://example.com/base///".to_string(), "secret".to_string(), "model-a".to_string(), DEFAULT_REQUEST_TIMEOUT_MS, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_BACKOFF_MS, ) .expect("config should be valid"); assert_eq!( config.chat_completions_url(), "https://example.com/base/chat/completions" ); } #[test] fn sse_parser_handles_split_chunks_and_done_marker() { let mut parser = OpenAiCompatibleSseParser::default(); let events_a = parser .push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"你\"}}]}\r\n\r\n") .expect("first chunk should parse"); let events_b = parser .push_chunk("data: {\"choices\":[{\"delta\":{\"content\":\"好\"},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n") .expect("second chunk should parse"); assert_eq!(events_a.len(), 1); assert_eq!(events_a[0].delta_text.as_deref(), Some("你")); assert_eq!(events_b.len(), 1); assert_eq!(events_b[0].delta_text.as_deref(), Some("好")); assert_eq!(events_b[0].finish_reason.as_deref(), Some("stop")); } #[test] fn decode_utf8_stream_chunk_preserves_incomplete_multibyte_suffix() { let full_bytes = "你好".as_bytes(); let first_result = decode_utf8_stream_chunk(&full_bytes[..2]) .expect("incomplete utf-8 chunk should be buffered"); assert_eq!(first_result.0, ""); assert_eq!(first_result.1, full_bytes[..2].to_vec()); let mut combined = first_result.1; combined.extend_from_slice(&full_bytes[2..]); let second_result = decode_utf8_stream_chunk(combined.as_slice()) .expect("completed utf-8 bytes should decode"); assert_eq!(second_result.0, "你好"); assert!(second_result.1.is_empty()); } #[tokio::test] async fn request_text_parses_non_stream_response() { let server_url = spawn_mock_server(vec![MockResponse { status_line: "200 OK", content_type: "application/json; charset=utf-8", body: r#"{"id":"resp_01","model":"ark-test-model","choices":[{"message":{"content":"测试成功"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":6,"total_tokens":16}}"#.to_string(), extra_headers: Vec::new(), }]); let client = build_test_client(server_url, 0); let response = client .request_single_message_text("系统", "用户") .await .expect("request_text should succeed"); assert_eq!(response.provider, LlmProvider::Ark); assert_eq!(response.model, "ark-test-model"); assert_eq!(response.content, "测试成功"); assert_eq!(response.finish_reason.as_deref(), Some("stop")); assert_eq!(response.response_id.as_deref(), Some("resp_01")); assert_eq!( response.usage, Some(LlmTokenUsage { prompt_tokens: 10, completion_tokens: 6, total_tokens: 16, }) ); } #[tokio::test] async fn request_text_retries_after_upstream_500() { let server_url = spawn_mock_server(vec![ MockResponse { status_line: "500 Internal Server Error", content_type: "application/json; charset=utf-8", body: r#"{"error":{"message":"temporary upstream failure"}}"#.to_string(), extra_headers: Vec::new(), }, MockResponse { status_line: "200 OK", content_type: "application/json; charset=utf-8", body: r#"{"id":"resp_retry","choices":[{"message":{"content":"第二次成功"},"finish_reason":"stop"}]}"#.to_string(), extra_headers: Vec::new(), }, ]); let client = build_test_client(server_url, 1); let response = client .request_single_message_text("系统", "用户") .await .expect("second attempt should succeed"); assert_eq!(response.content, "第二次成功"); assert_eq!(response.response_id.as_deref(), Some("resp_retry")); } #[tokio::test] async fn stream_text_accumulates_sse_response() { let server_url = spawn_mock_server(vec![MockResponse { status_line: "200 OK", content_type: "text/event-stream; charset=utf-8", body: concat!( "data: {\"choices\":[{\"delta\":{\"content\":\"你\"}}]}\n\n", "data: {\"choices\":[{\"delta\":{\"content\":\"好\"}}]}\n\n", "data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n", "data: [DONE]\n\n" ) .to_string(), extra_headers: vec![("x-request-id", "req_stream_01")], }]); let client = build_test_client(server_url, 0); let mut updates = Vec::new(); let response = client .stream_single_message_text("系统", "用户", |delta| { updates.push(delta.accumulated_text.clone()); }) .await .expect("stream_text should succeed"); assert_eq!(updates, vec!["你".to_string(), "你好".to_string()]); assert_eq!(response.content, "你好"); assert_eq!(response.finish_reason.as_deref(), Some("stop")); assert_eq!(response.response_id.as_deref(), Some("req_stream_01")); } fn build_test_client(base_url: String, max_retries: u32) -> LlmClient { let config = LlmConfig::new( LlmProvider::Ark, base_url, "test-key".to_string(), "test-model".to_string(), DEFAULT_REQUEST_TIMEOUT_MS, max_retries, 1, ) .expect("config should be valid"); LlmClient::new(config).expect("client should be created") } fn spawn_mock_server(responses: Vec) -> String { let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind"); let address = listener.local_addr().expect("listener should have addr"); thread::spawn(move || { for response in responses { let (mut stream, _) = listener.accept().expect("request should connect"); read_request(&mut stream); write_response(&mut stream, response); } }); format!("http://{address}") } fn read_request(stream: &mut std::net::TcpStream) { stream .set_read_timeout(Some(StdDuration::from_secs(1))) .expect("read timeout should be set"); let mut buffer = Vec::new(); let mut chunk = [0_u8; 1024]; let mut expected_total = None; loop { match stream.read(&mut chunk) { Ok(0) => break, Ok(bytes_read) => { buffer.extend_from_slice(&chunk[..bytes_read]); if expected_total.is_none() && let Some(header_end) = find_header_end(&buffer) { let content_length = read_content_length(&buffer[..header_end]).unwrap_or(0); expected_total = Some(header_end + content_length); } if let Some(total_bytes) = expected_total && buffer.len() >= total_bytes { break; } } Err(error) if error.kind() == std::io::ErrorKind::WouldBlock || error.kind() == std::io::ErrorKind::TimedOut => { break; } Err(error) => panic!("mock server failed to read request: {error}"), } } } fn write_response(stream: &mut std::net::TcpStream, response: MockResponse) { let body = response.body; let mut raw_response = format!( "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n", response.status_line, response.content_type, body.len() ); for (name, value) in response.extra_headers { raw_response.push_str(format!("{name}: {value}\r\n").as_str()); } raw_response.push_str("\r\n"); raw_response.push_str(body.as_str()); stream .write_all(raw_response.as_bytes()) .expect("mock response should be written"); stream.flush().expect("mock response should flush"); } fn find_header_end(buffer: &[u8]) -> Option { buffer .windows(4) .position(|window| window == b"\r\n\r\n") .map(|index| index + 4) } fn read_content_length(headers: &[u8]) -> Option { let text = String::from_utf8_lossy(headers); text.lines().find_map(|line| { let (name, value) = line.split_once(':')?; if name.eq_ignore_ascii_case("content-length") { return value.trim().parse::().ok(); } None }) } }