This commit is contained in:
2026-05-01 22:16:01 +08:00
parent 8d46c05129
commit 33dd105630
36 changed files with 1999 additions and 236 deletions

View File

@@ -1,4 +1,4 @@
use platform_llm::{LlmClient, LlmMessage, LlmStreamDelta, LlmTextRequest};
use platform_llm::{LlmClient, LlmError, LlmMessage, LlmStreamDelta, LlmTextRequest};
use serde_json::Value as JsonValue;
use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;
@@ -33,10 +33,63 @@ where
{
let llm_client =
llm_client.ok_or_else(|| build_error(messages.model_unavailable.to_string()))?;
let user_prompt = user_prompt.into();
let turn_output = match request_stream_creation_agent_json_turn(
llm_client,
system_prompt.clone(),
user_prompt.clone(),
enable_web_search,
&mut on_reply_update,
)
.await
{
Ok(turn_output) => Ok(turn_output),
Err(CreationAgentJsonTurnFailure::Stream(error))
if enable_web_search && is_web_search_tool_unavailable(&error) =>
{
tracing::warn!(
error = %error,
"创作 Agent 联网搜索插件不可用,自动降级为无联网搜索重试"
);
request_stream_creation_agent_json_turn(
llm_client,
system_prompt,
user_prompt,
false,
&mut on_reply_update,
)
.await
}
Err(error) => Err(error),
};
turn_output.map_err(|error| match error {
CreationAgentJsonTurnFailure::Stream(_) => {
build_error(messages.generation_failed.to_string())
}
CreationAgentJsonTurnFailure::Parse => build_error(messages.parse_failed.to_string()),
})
}
/// Internal failure modes of a single creation-agent JSON turn, kept separate
/// from the user-facing error so the caller can pick which message to surface.
enum CreationAgentJsonTurnFailure {
/// The LLM streaming request failed; carries the upstream `LlmError` so the
/// caller can inspect it (e.g. for the web-search-unavailable fallback).
Stream(LlmError),
/// Streaming completed but the accumulated text was not parseable JSON.
Parse,
}
async fn request_stream_creation_agent_json_turn<F>(
llm_client: &LlmClient,
system_prompt: String,
user_prompt: String,
enable_web_search: bool,
on_reply_update: &mut F,
) -> Result<CreationAgentJsonTurnOutput, CreationAgentJsonTurnFailure>
where
F: FnMut(&str),
{
let mut latest_reply_text = String::new();
let response = llm_client
.stream_text(
build_creation_agent_llm_request(system_prompt, user_prompt.into(), enable_web_search),
build_creation_agent_llm_request(system_prompt, user_prompt, enable_web_search),
|delta: &LlmStreamDelta| {
if let Some(reply_progress) =
extract_reply_text_from_partial_json(delta.accumulated_text.as_str())
@@ -48,9 +101,9 @@ where
},
)
.await
.map_err(|_| build_error(messages.generation_failed.to_string()))?;
.map_err(CreationAgentJsonTurnFailure::Stream)?;
let parsed = parse_json_response_text(response.content.as_str())
.map_err(|_| build_error(messages.parse_failed.to_string()))?;
.map_err(|_| CreationAgentJsonTurnFailure::Parse)?;
let reply_text = read_reply_text(&parsed);
if let Some(reply_text) = reply_text.as_deref()
&& reply_text != latest_reply_text
@@ -61,6 +114,13 @@ where
Ok(CreationAgentJsonTurnOutput { parsed })
}
/// Heuristically detects that the upstream account has no web-search tool
/// enabled, by scanning the rendered error text for known provider markers
/// (error code, English message fragment, Chinese message fragment).
fn is_web_search_tool_unavailable(error: &LlmError) -> bool {
    const MARKERS: [&str; 3] = ["ToolNotOpen", "has not activated web search", "未开通"];
    let rendered = error.to_string();
    MARKERS.iter().any(|marker| rendered.contains(marker))
}
fn build_creation_agent_llm_request(
system_prompt: String,
user_prompt: String,
@@ -168,11 +228,23 @@ fn read_reply_text(parsed: &JsonValue) -> Option<String> {
#[cfg(test)]
mod tests {
use std::{
fs,
io::{Read, Write},
net::TcpListener,
sync::{Arc, Mutex},
thread,
time::{Duration as StdDuration, SystemTime, UNIX_EPOCH},
};
use platform_llm::{LlmConfig, LlmProvider};
use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;
use super::{
build_creation_agent_llm_request, extract_reply_text_from_partial_json,
parse_json_response_text,
CreationAgentLlmTurnErrorMessages, build_creation_agent_llm_request,
extract_reply_text_from_partial_json, is_web_search_tool_unavailable,
parse_json_response_text, stream_creation_agent_json_turn,
};
#[test]
@@ -202,4 +274,214 @@ mod tests {
assert_eq!(request.protocol, platform_llm::LlmTextProtocol::Responses);
assert_eq!(request.messages.len(), 2);
}
// The ToolNotOpen detector must match a real upstream error payload.
#[test]
fn detects_upstream_web_search_tool_unavailable_error() {
    let upstream_error = platform_llm::LlmError::Upstream {
        status_code: 502,
        message: "Your account has not activated web search. code=ToolNotOpen".to_owned(),
    };
    assert!(is_web_search_tool_unavailable(&upstream_error));
}
// End-to-end fallback test: the mock server first answers with a
// "ToolNotOpen" SSE error, then with a valid JSON turn. The helper must
// retry without web search, surface only the final reply, and the captured
// raw requests must show the tool list was dropped on the retry.
#[tokio::test]
async fn stream_turn_retries_without_web_search_when_tool_is_unavailable() {
// Unique per-process/per-nanosecond temp dir keeps parallel runs isolated.
let log_dir = std::env::temp_dir().join(format!(
"api-server-creation-agent-raw-log-test-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after epoch")
.as_nanos()
));
// NOTE(review): set_var mutates process-global state; assumes no other test
// reads LLM_RAW_LOG_DIR concurrently — confirm test runner serializes these.
unsafe {
std::env::set_var("LLM_RAW_LOG_DIR", &log_dir);
}
// JSON body the model is expected to produce on the successful retry.
let success_json = serde_json::json!({
"replyText": "好,我们先把玩具王国定住。",
"progressPercent": 12,
"nextAnchorContent": {
"worldPromise": "玩具王国初步方向",
"playerFantasy": null,
"themeBoundary": null,
"playerEntryPoint": null,
"coreConflict": null,
"keyRelationships": null,
"hiddenLines": null,
"iconicElements": null
}
})
.to_string();
// Response 1: upstream web-search-unavailable error; response 2: success.
let server = spawn_capturing_mock_server(vec![
MockResponse {
body: concat!(
"data: {\"type\":\"error\",\"code\":\"ToolNotOpen\",\"message\":\"Your account has not activated web search.\"}\n\n",
"data: [DONE]\n\n"
)
.to_string(),
},
MockResponse {
body: format!(
"data: {}\n\n",
serde_json::json!({
"type": "response.output_text.delta",
"delta": success_json
})
) + "data: {\"type\":\"response.completed\"}\n\n",
},
]);
// Point a real client at the mock server.
// NOTE(review): the trailing positional args (30_000, 0, 1) presumably are
// timeout/retry/concurrency settings — verify against LlmConfig::new.
let config = LlmConfig::new(
LlmProvider::Ark,
server.base_url,
"test-key".to_string(),
"test-model".to_string(),
30_000,
0,
1,
)
.expect("LLM config should build");
let llm_client = platform_llm::LlmClient::new(config).expect("LLM client should build");
let mut visible_replies = Vec::new();
// Run the turn with web search enabled so the fallback path is exercised.
let output = stream_creation_agent_json_turn(
Some(&llm_client),
"系统提示".to_string(),
"用户提示",
true,
CreationAgentLlmTurnErrorMessages {
model_unavailable: "模型不可用",
generation_failed: "生成失败",
parse_failed: "解析失败",
},
|text| visible_replies.push(text.to_string()),
|message| message,
)
.await
.expect("web search fallback should succeed");
// Final parsed reply must come from the second (successful) response.
assert_eq!(
output.parsed["replyText"].as_str(),
Some("好,我们先把玩具王国定住。")
);
// The failed first attempt must not leak any partial reply to the caller.
assert_eq!(visible_replies, vec!["好,我们先把玩具王国定住。"]);
let requests = server.requests.lock().expect("requests lock").clone();
assert_eq!(requests.len(), 2);
// First request carried the web_search tool; the retry dropped it entirely.
assert!(requests[0].contains("\"tools\""));
assert!(requests[0].contains("\"web_search\""));
assert!(!requests[1].contains("\"tools\""));
// Clean up the process-global env var and the temporary raw-log directory.
unsafe {
std::env::remove_var("LLM_RAW_LOG_DIR");
}
if log_dir.exists() {
fs::remove_dir_all(log_dir).expect("temporary LLM raw log dir should be removed");
}
}
/// One scripted SSE payload served verbatim by the capturing mock server.
struct MockResponse {
// Raw bytes written as the HTTP response body (already SSE-framed).
body: String,
}
/// Handle to a spawned mock server that records every raw HTTP request.
struct CapturingMockServer {
// Base URL ("http://127.0.0.1:<port>") that clients should be pointed at.
base_url: String,
// Raw request texts in arrival order, shared with the server thread.
requests: Arc<Mutex<Vec<String>>>,
}
/// Spawns a background TCP server that answers each incoming connection with
/// the next scripted SSE response and records the raw request text it read.
///
/// The server thread serves exactly `responses.len()` connections, in order,
/// then exits; the returned handle exposes the base URL and captured requests.
fn spawn_capturing_mock_server(responses: Vec<MockResponse>) -> CapturingMockServer {
    // Port 0 lets the OS pick a free port, so parallel tests never collide.
    let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
    let address = listener.local_addr().expect("listener should have addr");
    let captured = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&captured);
    thread::spawn(move || {
        // One accept per scripted response, served strictly in order.
        for scripted in responses {
            let (mut connection, _) = listener.accept().expect("request should connect");
            let raw_request = read_request(&mut connection);
            sink.lock().expect("requests lock").push(raw_request);
            write_sse_response(&mut connection, scripted);
        }
    });
    CapturingMockServer {
        base_url: format!("http://{address}"),
        requests: captured,
    }
}
fn read_request(stream: &mut std::net::TcpStream) -> String {
stream
.set_read_timeout(Some(StdDuration::from_secs(1)))
.expect("read timeout should be set");
let mut buffer = Vec::new();
let mut chunk = [0_u8; 1024];
let mut expected_total = None;
loop {
match stream.read(&mut chunk) {
Ok(0) => break,
Ok(bytes_read) => {
buffer.extend_from_slice(&chunk[..bytes_read]);
if expected_total.is_none()
&& let Some(header_end) = find_header_end(&buffer)
{
let content_length =
read_content_length(&buffer[..header_end]).unwrap_or(0);
expected_total = Some(header_end + content_length);
}
if let Some(total_bytes) = expected_total
&& buffer.len() >= total_bytes
{
break;
}
}
Err(error)
if error.kind() == std::io::ErrorKind::WouldBlock
|| error.kind() == std::io::ErrorKind::TimedOut =>
{
break;
}
Err(error) => panic!("mock server failed to read request: {error}"),
}
}
String::from_utf8_lossy(buffer.as_slice()).to_string()
}
/// Writes `response.body` to `stream` as a complete HTTP/1.1 SSE response,
/// with an explicit Content-Length so the client can detect the end, and
/// `Connection: close` so no keep-alive handling is needed.
fn write_sse_response(stream: &mut std::net::TcpStream, response: MockResponse) {
    let payload = response.body;
    let raw_response = format!(
        "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
        payload.len(),
        payload
    );
    stream
        .write_all(raw_response.as_bytes())
        .expect("mock response should be written");
    // Flush explicitly — Drop would swallow any write error silently.
    stream.flush().expect("mock response should flush");
}
/// Returns the index one past the first `\r\n\r\n` header terminator in
/// `buffer`, or `None` if the headers are not complete yet.
fn find_header_end(buffer: &[u8]) -> Option<usize> {
    let terminator = b"\r\n\r\n";
    let mut start = 0;
    while start + terminator.len() <= buffer.len() {
        if &buffer[start..start + terminator.len()] == terminator {
            return Some(start + terminator.len());
        }
        start += 1;
    }
    None
}
/// Extracts the numeric value of the (case-insensitive) `Content-Length`
/// header from raw header bytes.
///
/// Lines without a colon, with a different header name, or with a value that
/// does not parse as `usize` are skipped; `None` means no usable header.
fn read_content_length(headers: &[u8]) -> Option<usize> {
    let header_text = String::from_utf8_lossy(headers);
    for line in header_text.lines() {
        let Some((name, value)) = line.split_once(':') else {
            continue;
        };
        if name.eq_ignore_ascii_case("content-length") {
            if let Ok(length) = value.trim().parse::<usize>() {
                return Some(length);
            }
        }
    }
    None
}
}