1
This commit is contained in:
@@ -42,6 +42,7 @@ pub struct LlmConfig {
|
||||
request_timeout_ms: u64,
|
||||
max_retries: u32,
|
||||
retry_backoff_ms: u64,
|
||||
official_fallback: bool,
|
||||
}
|
||||
|
||||
// 首版只冻结当前项目已稳定使用的 system/user/assistant 三种消息角色。
|
||||
@@ -161,9 +162,11 @@ enum LlmRequestBody {
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionsRequestBody {
|
||||
model: String,
|
||||
messages: Vec<LlmMessage>,
|
||||
messages: Vec<ChatCompletionsInputMessage>,
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
official_fallback: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
web_search_options: Option<ChatCompletionsWebSearchOptions>,
|
||||
@@ -172,12 +175,41 @@ struct ChatCompletionsRequestBody {
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionsWebSearchOptions {}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionsInputMessage {
|
||||
role: &'static str,
|
||||
content: ChatCompletionsInputContent,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(untagged)]
|
||||
enum ChatCompletionsInputContent {
|
||||
Text(String),
|
||||
Parts(Vec<ChatCompletionsInputContentPart>),
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum ChatCompletionsInputContentPart {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image_url")]
|
||||
ImageUrl { image_url: ChatCompletionsImageUrl },
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionsImageUrl {
|
||||
url: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ResponsesRequestBody {
|
||||
model: String,
|
||||
stream: bool,
|
||||
input: Vec<ResponsesInputMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
official_fallback: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_output_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<ResponsesWebSearchTool>>,
|
||||
@@ -215,6 +247,15 @@ struct LlmRawFailureInputLog<'a> {
|
||||
messages: &'a [LlmMessage],
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ChatCompletionsResponsePayload {
|
||||
Direct(ChatCompletionsResponseEnvelope),
|
||||
Wrapped {
|
||||
data: ChatCompletionsResponseEnvelope,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatCompletionsResponseEnvelope {
|
||||
id: Option<String>,
|
||||
@@ -344,9 +385,15 @@ impl LlmConfig {
|
||||
request_timeout_ms,
|
||||
max_retries,
|
||||
retry_backoff_ms,
|
||||
official_fallback: false,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_official_fallback(mut self, official_fallback: bool) -> Self {
|
||||
self.official_fallback = official_fallback;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn ark_default(api_key: String, model: String) -> Result<Self, LlmError> {
|
||||
Self::new(
|
||||
LlmProvider::Ark,
|
||||
@@ -387,6 +434,10 @@ impl LlmConfig {
|
||||
self.retry_backoff_ms
|
||||
}
|
||||
|
||||
pub fn official_fallback(&self) -> bool {
|
||||
self.official_fallback
|
||||
}
|
||||
|
||||
pub fn chat_completions_url(&self) -> String {
|
||||
format!(
|
||||
"{}/{}",
|
||||
@@ -886,7 +937,7 @@ impl LlmClient {
|
||||
request: &LlmTextRequest,
|
||||
stream: bool,
|
||||
) -> Result<reqwest::Response, LlmError> {
|
||||
let request_body = build_request_body(request, self.config.model(), stream);
|
||||
let request_body = build_request_body(request, &self.config, stream);
|
||||
let model = request.resolved_model(self.config.model());
|
||||
let url = match request.protocol {
|
||||
LlmTextProtocol::ChatCompletions => self.config.chat_completions_url(),
|
||||
@@ -1097,15 +1148,18 @@ fn normalize_non_empty(value: String, error_message: &str) -> Result<String, Llm
|
||||
|
||||
fn build_request_body(
|
||||
request: &LlmTextRequest,
|
||||
fallback_model: &str,
|
||||
config: &LlmConfig,
|
||||
stream: bool,
|
||||
) -> LlmRequestBody {
|
||||
let fallback_model = config.model();
|
||||
let official_fallback = config.official_fallback().then_some(true);
|
||||
match request.protocol {
|
||||
LlmTextProtocol::ChatCompletions => {
|
||||
LlmRequestBody::ChatCompletions(ChatCompletionsRequestBody {
|
||||
model: request.resolved_model(fallback_model).to_string(),
|
||||
messages: request.messages.clone(),
|
||||
messages: map_chat_completions_input_messages(request.messages.as_slice()),
|
||||
stream,
|
||||
official_fallback,
|
||||
max_tokens: request.max_tokens,
|
||||
web_search_options: request
|
||||
.enable_web_search
|
||||
@@ -1116,6 +1170,7 @@ fn build_request_body(
|
||||
model: request.resolved_model(fallback_model).to_string(),
|
||||
stream,
|
||||
input: map_responses_input_messages(request.messages.as_slice()),
|
||||
official_fallback,
|
||||
max_output_tokens: request.max_tokens,
|
||||
tools: request.enable_web_search.then(|| {
|
||||
vec![ResponsesWebSearchTool {
|
||||
@@ -1127,20 +1182,61 @@ fn build_request_body(
|
||||
}
|
||||
}
|
||||
|
||||
fn map_chat_completions_input_messages(
|
||||
messages: &[LlmMessage],
|
||||
) -> Vec<ChatCompletionsInputMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| ChatCompletionsInputMessage {
|
||||
role: map_llm_message_role(message.role),
|
||||
content: map_chat_completions_content(message),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn map_chat_completions_content(message: &LlmMessage) -> ChatCompletionsInputContent {
|
||||
if message.content_parts.is_empty() {
|
||||
return ChatCompletionsInputContent::Text(message.content.clone());
|
||||
}
|
||||
|
||||
ChatCompletionsInputContent::Parts(
|
||||
message
|
||||
.content_parts
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
LlmMessageContentPart::InputText { text } => {
|
||||
ChatCompletionsInputContentPart::Text { text: text.clone() }
|
||||
}
|
||||
LlmMessageContentPart::InputImage { image_url } => {
|
||||
ChatCompletionsInputContentPart::ImageUrl {
|
||||
image_url: ChatCompletionsImageUrl {
|
||||
url: image_url.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn map_responses_input_messages(messages: &[LlmMessage]) -> Vec<ResponsesInputMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| ResponsesInputMessage {
|
||||
role: match message.role {
|
||||
LlmMessageRole::System => "system",
|
||||
LlmMessageRole::User => "user",
|
||||
LlmMessageRole::Assistant => "assistant",
|
||||
},
|
||||
role: map_llm_message_role(message.role),
|
||||
content: map_responses_content_parts(message),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn map_llm_message_role(role: LlmMessageRole) -> &'static str {
|
||||
match role {
|
||||
LlmMessageRole::System => "system",
|
||||
LlmMessageRole::User => "user",
|
||||
LlmMessageRole::Assistant => "assistant",
|
||||
}
|
||||
}
|
||||
|
||||
fn map_responses_content_parts(message: &LlmMessage) -> Vec<ResponsesInputContentPart> {
|
||||
if message.content_parts.is_empty() {
|
||||
return vec![ResponsesInputContentPart::InputText {
|
||||
@@ -1265,8 +1361,12 @@ fn parse_chat_completions_response(
|
||||
fallback_model: &str,
|
||||
raw_text: &str,
|
||||
) -> Result<LlmTextResponse, LlmError> {
|
||||
let parsed: ChatCompletionsResponseEnvelope = serde_json::from_str(raw_text)
|
||||
let parsed: ChatCompletionsResponsePayload = serde_json::from_str(raw_text)
|
||||
.map_err(|error| LlmError::Deserialize(format!("解析 LLM JSON 响应失败:{error}")))?;
|
||||
let parsed = match parsed {
|
||||
ChatCompletionsResponsePayload::Direct(envelope) => envelope,
|
||||
ChatCompletionsResponsePayload::Wrapped { data } => data,
|
||||
};
|
||||
|
||||
let first_choice = parsed
|
||||
.choices
|
||||
@@ -1590,6 +1690,71 @@ mod tests {
|
||||
assert_eq!(config.responses_url(), "https://example.com/base/responses");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_config_official_fallback_is_opt_in() {
|
||||
let config = LlmConfig::new(
|
||||
LlmProvider::OpenAiCompatible,
|
||||
"https://example.com/base".to_string(),
|
||||
"secret".to_string(),
|
||||
"model-a".to_string(),
|
||||
DEFAULT_REQUEST_TIMEOUT_MS,
|
||||
DEFAULT_MAX_RETRIES,
|
||||
DEFAULT_RETRY_BACKOFF_MS,
|
||||
)
|
||||
.expect("config should be valid");
|
||||
|
||||
assert!(!config.official_fallback());
|
||||
assert!(config.with_official_fallback(true).official_fallback());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn request_text_sends_official_fallback_for_openai_compatible_clients() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
|
||||
let address = listener.local_addr().expect("listener should have addr");
|
||||
let server_handle = thread::spawn(move || {
|
||||
let (mut stream, _) = listener.accept().expect("request should connect");
|
||||
let request_text = read_request(&mut stream);
|
||||
write_response(
|
||||
&mut stream,
|
||||
MockResponse {
|
||||
status_line: "200 OK",
|
||||
content_type: "application/json; charset=utf-8",
|
||||
body: r#"{"id":"resp_openai_compatible","model":"gpt-5","output_text":"兼容成功","status":"completed"}"#.to_string(),
|
||||
extra_headers: Vec::new(),
|
||||
},
|
||||
);
|
||||
request_text
|
||||
});
|
||||
|
||||
let config = LlmConfig::new(
|
||||
LlmProvider::OpenAiCompatible,
|
||||
format!("http://{address}"),
|
||||
"test-key".to_string(),
|
||||
"gpt-5".to_string(),
|
||||
DEFAULT_REQUEST_TIMEOUT_MS,
|
||||
0,
|
||||
1,
|
||||
)
|
||||
.expect("config should be valid")
|
||||
.with_official_fallback(true);
|
||||
let client = LlmClient::new(config).expect("client should be created");
|
||||
let response = client
|
||||
.request_text(LlmTextRequest::single_turn("系统", "用户").with_responses_api())
|
||||
.await
|
||||
.expect("request_text should succeed");
|
||||
|
||||
let request_text = server_handle.join().expect("server thread should join");
|
||||
let request_body = request_text
|
||||
.split("\r\n\r\n")
|
||||
.nth(1)
|
||||
.expect("request body should exist");
|
||||
let request_json: serde_json::Value =
|
||||
serde_json::from_str(request_body).expect("request body should be json");
|
||||
|
||||
assert_eq!(response.content, "兼容成功");
|
||||
assert_eq!(request_json["official_fallback"], serde_json::json!(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sse_parser_handles_split_chunks_and_done_marker() {
|
||||
let mut parser = OpenAiCompatibleSseParser::new(LlmTextProtocol::ChatCompletions);
|
||||
@@ -1711,8 +1876,9 @@ mod tests {
|
||||
MockResponse {
|
||||
status_line: "200 OK",
|
||||
content_type: "application/json; charset=utf-8",
|
||||
body: r#"{"choices":[{"message":{"content":"too late"},"finish_reason":"stop"}]}"#
|
||||
.to_string(),
|
||||
body:
|
||||
r#"{"choices":[{"message":{"content":"too late"},"finish_reason":"stop"}]}"#
|
||||
.to_string(),
|
||||
extra_headers: Vec::new(),
|
||||
},
|
||||
);
|
||||
@@ -1731,9 +1897,7 @@ mod tests {
|
||||
let client = LlmClient::new(config).expect("client should be created");
|
||||
|
||||
let error = client
|
||||
.request_text(
|
||||
LlmTextRequest::single_turn("系统", "用户").with_request_timeout_ms(20),
|
||||
)
|
||||
.request_text(LlmTextRequest::single_turn("系统", "用户").with_request_timeout_ms(20))
|
||||
.await
|
||||
.expect_err("request override should timeout before the global timeout");
|
||||
|
||||
@@ -1779,6 +1943,75 @@ mod tests {
|
||||
|
||||
assert_eq!(response.content, "搜索成功");
|
||||
assert_eq!(request_json["web_search_options"], serde_json::json!({}));
|
||||
assert!(request_json.get("official_fallback").is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_completions_multimodal_request_sends_text_and_image_url_parts() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
|
||||
let address = listener.local_addr().expect("listener should have addr");
|
||||
let server_handle = thread::spawn(move || {
|
||||
let (mut stream, _) = listener.accept().expect("request should connect");
|
||||
let request_text = read_request(&mut stream);
|
||||
write_response(
|
||||
&mut stream,
|
||||
MockResponse {
|
||||
status_line: "200 OK",
|
||||
content_type: "application/json; charset=utf-8",
|
||||
body: r#"{"id":"chat_multimodal","model":"gpt-4o-mini","choices":[{"message":{"content":"{\"levelName\":\"雨夜猫街\"}"},"finish_reason":"stop"}]}"#.to_string(),
|
||||
extra_headers: Vec::new(),
|
||||
},
|
||||
);
|
||||
request_text
|
||||
});
|
||||
|
||||
let config = LlmConfig::new(
|
||||
LlmProvider::OpenAiCompatible,
|
||||
format!("http://{address}"),
|
||||
"test-key".to_string(),
|
||||
"gpt-4o-mini".to_string(),
|
||||
DEFAULT_REQUEST_TIMEOUT_MS,
|
||||
0,
|
||||
1,
|
||||
)
|
||||
.expect("config should be valid")
|
||||
.with_official_fallback(true);
|
||||
let client = LlmClient::new(config).expect("client should be created");
|
||||
let response = client
|
||||
.request_text(LlmTextRequest::new(vec![
|
||||
LlmMessage::system("你是拼图关卡命名编辑"),
|
||||
LlmMessage::user_multimodal(vec![
|
||||
LlmMessageContentPart::InputText {
|
||||
text: "画面描述:一只猫在雨夜灯牌下回头。".to_string(),
|
||||
},
|
||||
LlmMessageContentPart::InputImage {
|
||||
image_url: "data:image/png;base64,abcd".to_string(),
|
||||
},
|
||||
]),
|
||||
]))
|
||||
.await
|
||||
.expect("request_text should succeed");
|
||||
|
||||
let request_text = server_handle.join().expect("server thread should join");
|
||||
let request_line = request_text.lines().next().unwrap_or_default();
|
||||
let request_body = request_text
|
||||
.split("\r\n\r\n")
|
||||
.nth(1)
|
||||
.expect("request body should exist");
|
||||
let request_json: serde_json::Value =
|
||||
serde_json::from_str(request_body).expect("request body should be json");
|
||||
|
||||
assert!(request_line.contains("POST /chat/completions HTTP/1.1"));
|
||||
assert_eq!(response.model, "gpt-4o-mini");
|
||||
assert_eq!(response.content, r#"{"levelName":"雨夜猫街"}"#);
|
||||
assert_eq!(request_json["official_fallback"], serde_json::json!(true));
|
||||
assert_eq!(
|
||||
request_json["messages"][1]["content"],
|
||||
serde_json::json!([
|
||||
{ "type": "text", "text": "画面描述:一只猫在雨夜灯牌下回头。" },
|
||||
{ "type": "image_url", "image_url": { "url": "data:image/png;base64,abcd" } }
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1841,6 +2074,7 @@ mod tests {
|
||||
request_json["tools"],
|
||||
serde_json::json!([{ "type": "web_search", "max_keyword": 3 }])
|
||||
);
|
||||
assert!(request_json.get("official_fallback").is_none());
|
||||
assert_eq!(
|
||||
request_json["input"][0]["content"][0],
|
||||
serde_json::json!({ "type": "input_text", "text": "系统" })
|
||||
@@ -1896,6 +2130,7 @@ mod tests {
|
||||
|
||||
assert_eq!(response.model, "gpt-5");
|
||||
assert_eq!(request_json["model"], serde_json::json!("gpt-5"));
|
||||
assert!(request_json.get("official_fallback").is_none());
|
||||
assert_eq!(
|
||||
request_json["input"][1]["content"],
|
||||
serde_json::json!([
|
||||
|
||||
Reference in New Issue
Block a user