//! Creation-agent LLM turn helpers: a shared streaming JSON-turn skeleton with
//! web-search fallback, plus partial-JSON `replyText` extraction for streaming.
use platform_llm::{LlmClient, LlmError, LlmMessage, LlmStreamDelta, LlmTextRequest};
use serde_json::Value as JsonValue;

use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;
/// User-facing error messages for one creation-agent LLM turn.
///
/// Each field is the message handed to the caller-supplied error builder for
/// the corresponding failure mode, so every gameplay mode can localize its
/// own wording while sharing the call skeleton.
#[derive(Clone, Copy, Debug)]
pub(crate) struct CreationAgentLlmTurnErrorMessages<'a> {
    /// Used when no LLM client is configured (client is `None`).
    pub model_unavailable: &'a str,
    /// Used when the streaming LLM request itself fails.
    pub generation_failed: &'a str,
    /// Used when the model replied but the reply could not be parsed as JSON.
    pub parse_failed: &'a str,
}
/// Successful result of one creation-agent JSON turn.
#[derive(Clone, Debug)]
pub(crate) struct CreationAgentJsonTurnOutput {
    /// The model's full reply parsed as a JSON value.
    pub parsed: JsonValue,
}
/// Shared streaming JSON-turn invocation for creation agents.
///
/// Only the cross-gameplay LLM call skeleton lives here; prompt content and
/// domain-specific JSON parsing remain the caller's responsibility.
///
/// Flow:
/// 1. Fails fast with `messages.model_unavailable` when `llm_client` is `None`.
/// 2. Streams the turn, reporting `replyText` progress through
///    `on_reply_update`.
/// 3. If web search was requested and the upstream reports the web-search tool
///    as unavailable, retries once with web search disabled.
/// 4. Maps remaining failures through `build_error` using
///    `messages.generation_failed` / `messages.parse_failed`.
pub(crate) async fn stream_creation_agent_json_turn<F, E>(
    llm_client: Option<&LlmClient>,
    system_prompt: String,
    user_prompt: impl Into<String>,
    enable_web_search: bool,
    messages: CreationAgentLlmTurnErrorMessages<'_>,
    mut on_reply_update: F,
    build_error: impl Fn(String) -> E,
) -> Result<CreationAgentJsonTurnOutput, E>
where
    F: FnMut(&str),
{
    let llm_client =
        llm_client.ok_or_else(|| build_error(messages.model_unavailable.to_string()))?;
    let user_prompt = user_prompt.into();
    // Prompts are cloned for the first attempt so the fallback retry below can
    // reuse the owned originals without re-building them.
    let turn_output = match request_stream_creation_agent_json_turn(
        llm_client,
        system_prompt.clone(),
        user_prompt.clone(),
        enable_web_search,
        &mut on_reply_update,
    )
    .await
    {
        Ok(turn_output) => Ok(turn_output),
        // Degrade gracefully: only when web search was actually requested and
        // the stream error looks like "web search tool not activated".
        Err(CreationAgentJsonTurnFailure::Stream(error))
            if enable_web_search && is_web_search_tool_unavailable(&error) =>
        {
            tracing::warn!(
                error = %error,
                "创作 Agent 联网搜索插件不可用,自动降级为无联网搜索重试"
            );
            request_stream_creation_agent_json_turn(
                llm_client,
                system_prompt,
                user_prompt,
                false,
                &mut on_reply_update,
            )
            .await
        }
        Err(error) => Err(error),
    };

    // Translate internal failure categories into caller-facing errors.
    turn_output.map_err(|error| match error {
        CreationAgentJsonTurnFailure::Stream(error) => {
            tracing::warn!(
                error = %error,
                "创作 Agent 流式 LLM 请求失败"
            );
            build_error(format!("{}:{error}", messages.generation_failed))
        }
        CreationAgentJsonTurnFailure::Parse => build_error(messages.parse_failed.to_string()),
    })
}
/// Internal failure classification for one streamed JSON turn, letting the
/// caller map each case to a distinct user-facing message.
enum CreationAgentJsonTurnFailure {
    /// The streaming LLM request failed.
    Stream(LlmError),
    /// The model replied, but the reply was not parseable JSON.
    Parse,
}
/// Performs one streaming LLM request and parses the final reply as JSON.
///
/// While the stream is in flight, tries to surface `replyText` progress from
/// the accumulated partial JSON via `on_reply_update`, skipping duplicate
/// updates. After the final parse, emits the definitive `replyText` once more
/// if it differs from the last value streamed.
async fn request_stream_creation_agent_json_turn<F>(
    llm_client: &LlmClient,
    system_prompt: String,
    user_prompt: String,
    enable_web_search: bool,
    on_reply_update: &mut F,
) -> Result<CreationAgentJsonTurnOutput, CreationAgentJsonTurnFailure>
where
    F: FnMut(&str),
{
    // Tracks the last reply text pushed to the caller, to deduplicate updates.
    let mut latest_reply_text = String::new();
    let response = llm_client
        .stream_text(
            build_creation_agent_llm_request(system_prompt, user_prompt, enable_web_search),
            |delta: &LlmStreamDelta| {
                // Best-effort extraction from the (possibly truncated) JSON
                // accumulated so far; only notify when the text changed.
                if let Some(reply_progress) =
                    extract_reply_text_from_partial_json(delta.accumulated_text.as_str())
                    && reply_progress != latest_reply_text
                {
                    latest_reply_text = reply_progress.clone();
                    on_reply_update(reply_progress.as_str());
                }
            },
        )
        .await
        .map_err(CreationAgentJsonTurnFailure::Stream)?;
    let parsed = parse_json_response_text(response.content.as_str())
        .map_err(|_| CreationAgentJsonTurnFailure::Parse)?;
    // Final authoritative replyText: emit once if streaming missed it.
    let reply_text = read_reply_text(&parsed);
    if let Some(reply_text) = reply_text.as_deref()
        && reply_text != latest_reply_text
    {
        on_reply_update(reply_text);
    }

    Ok(CreationAgentJsonTurnOutput { parsed })
}
fn is_web_search_tool_unavailable(error: &LlmError) -> bool {
|
||
let message = error.to_string();
|
||
message.contains("ToolNotOpen")
|
||
|| message.contains("has not activated web search")
|
||
|| message.contains("未开通")
|
||
}
|
||
|
||
fn build_creation_agent_llm_request(
|
||
system_prompt: String,
|
||
user_prompt: String,
|
||
enable_web_search: bool,
|
||
) -> LlmTextRequest {
|
||
// 创作 Agent 是否联网由 api-server 配置集中传入,避免各玩法各自散落默认值。
|
||
LlmTextRequest::new(vec![
|
||
LlmMessage::system(system_prompt),
|
||
LlmMessage::user(user_prompt),
|
||
])
|
||
.with_model(CREATION_TEMPLATE_LLM_MODEL)
|
||
.with_responses_api()
|
||
.with_web_search(enable_web_search)
|
||
}
|
||
|
||
pub(crate) async fn request_creation_agent_json_turn<E>(
|
||
llm_client: &LlmClient,
|
||
system_prompt: String,
|
||
user_prompt: String,
|
||
build_error: impl Fn(String) -> E,
|
||
) -> Result<JsonValue, E> {
|
||
let response = llm_client
|
||
.request_text(
|
||
LlmTextRequest::new(vec![
|
||
LlmMessage::system(system_prompt),
|
||
LlmMessage::user(user_prompt),
|
||
])
|
||
.with_model(CREATION_TEMPLATE_LLM_MODEL)
|
||
.with_responses_api(),
|
||
)
|
||
.await
|
||
.map_err(|error| build_error(error.to_string()))?;
|
||
parse_json_response_text(response.content.as_str())
|
||
.map_err(|error| build_error(error.to_string()))
|
||
}
|
||
|
||
pub(crate) fn parse_json_response_text(text: &str) -> Result<JsonValue, serde_json::Error> {
|
||
let trimmed = text.trim();
|
||
if let Some(start) = trimmed.find('{')
|
||
&& let Some(end) = trimmed.rfind('}')
|
||
&& end > start
|
||
{
|
||
return serde_json::from_str::<JsonValue>(&trimmed[start..=end]);
|
||
}
|
||
serde_json::from_str::<JsonValue>(trimmed)
|
||
}
|
||
|
||
/// Best-effort extraction of the `"replyText"` string value from a *partial*
/// JSON document, so streaming deltas can surface reply progress before the
/// full JSON has arrived.
///
/// Returns `None` when the key is absent, the value is not a string, or the
/// text ends in the middle of an escape sequence. An unterminated string
/// value yields the characters decoded so far.
///
/// Fix over the previous version: UTF-16 surrogate pairs (e.g. emoji escaped
/// as `\uD83D\uDE00`) are now combined into the real character instead of
/// being silently dropped (`char::from_u32` rejects lone surrogates).
pub(crate) fn extract_reply_text_from_partial_json(text: &str) -> Option<String> {
    let key_index = text.find("\"replyText\"")?;
    let colon_index = text[key_index..].find(':')? + key_index;
    let mut cursor = colon_index + 1;
    // Skip ASCII whitespace between the colon and the value.
    while cursor < text.len() && text.as_bytes()[cursor].is_ascii_whitespace() {
        cursor += 1;
    }
    // Only a string value carries visible reply text.
    if text.as_bytes().get(cursor).copied() != Some(b'"') {
        return None;
    }
    cursor += 1;
    let mut decoded = String::new();
    let remainder = text.get(cursor..)?;
    let mut characters = remainder.chars().peekable();
    while let Some(current) = characters.next() {
        if current == '"' {
            // Closing quote: the whole string value is present.
            return Some(decoded);
        }
        if current == '\\' {
            let escaped = characters.next()?;
            match escaped {
                '"' => decoded.push('"'),
                '\\' => decoded.push('\\'),
                '/' => decoded.push('/'),
                'b' => decoded.push('\u{0008}'),
                'f' => decoded.push('\u{000C}'),
                'n' => decoded.push('\n'),
                'r' => decoded.push('\r'),
                't' => decoded.push('\t'),
                'u' => {
                    // `?`: a truncated `\uXXXX` escape means nothing further
                    // can be decoded from this delta yet.
                    let hex = take_four_chars(&mut characters)?;
                    if let Ok(code) = u16::from_str_radix(hex.as_str(), 16) {
                        let mut code_point = u32::from(code);
                        if (0xD800..=0xDBFF).contains(&code) {
                            // High surrogate: try to combine with a following
                            // `\uDC00`..`\uDFFF` low-surrogate escape. Use a
                            // cloned iterator so a non-matching lookahead
                            // consumes nothing.
                            let mut lookahead = characters.clone();
                            if lookahead.next() == Some('\\') && lookahead.next() == Some('u') {
                                if let Some(low_hex) = take_four_chars(&mut lookahead) {
                                    if let Ok(low) = u16::from_str_radix(low_hex.as_str(), 16) {
                                        if (0xDC00..=0xDFFF).contains(&low) {
                                            code_point = 0x10000
                                                + ((u32::from(code) - 0xD800) << 10)
                                                + (u32::from(low) - 0xDC00);
                                            characters = lookahead;
                                        }
                                    }
                                }
                            }
                        }
                        // Lone surrogates and other undecodable escapes are
                        // dropped silently, as before.
                        if let Some(character) = char::from_u32(code_point) {
                            decoded.push(character);
                        }
                    }
                }
                // Unknown escapes degrade to the escaped character itself.
                other => decoded.push(other),
            }
            continue;
        }
        decoded.push(current);
    }
    // String value not terminated yet: report the progress decoded so far.
    Some(decoded)
}

/// Pulls exactly four characters of a `\uXXXX` escape; `None` when the text
/// ends mid-escape.
fn take_four_chars(characters: &mut std::iter::Peekable<std::str::Chars<'_>>) -> Option<String> {
    let mut hex = String::with_capacity(4);
    for _ in 0..4 {
        hex.push(characters.next()?);
    }
    Some(hex)
}
fn read_reply_text(parsed: &JsonValue) -> Option<String> {
|
||
parsed
|
||
.get("replyText")
|
||
.and_then(JsonValue::as_str)
|
||
.map(str::trim)
|
||
.filter(|value| !value.is_empty())
|
||
.map(str::to_string)
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use std::{
        fs,
        io::{Read, Write},
        net::TcpListener,
        sync::{Arc, Mutex},
        thread,
        time::{Duration as StdDuration, SystemTime, UNIX_EPOCH},
    };

    use platform_llm::{LlmConfig, LlmProvider};

    use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;

    use super::{
        CreationAgentLlmTurnErrorMessages, build_creation_agent_llm_request,
        extract_reply_text_from_partial_json, is_web_search_tool_unavailable,
        parse_json_response_text, stream_creation_agent_json_turn,
    };

    // The JSON document is truncated after the number, but the string value
    // itself is complete — extraction should still succeed.
    #[test]
    fn extracts_reply_text_from_partial_json_with_chinese_text() {
        let partial_json = r#"{"replyText":"你好,潮雾列岛","progressPercent":32"#;

        let extracted = extract_reply_text_from_partial_json(partial_json);

        assert_eq!(extracted.as_deref(), Some("你好,潮雾列岛"));
    }

    // Models often wrap JSON in markdown fences; the parser slices out the
    // `{...}` span before parsing.
    #[test]
    fn parses_json_inside_model_markdown_noise() {
        let parsed = parse_json_response_text("```json\n{\"replyText\":\"好\"}\n```")
            .expect("应能截取模型返回中的 JSON 对象");

        assert_eq!(parsed["replyText"].as_str(), Some("好"));
    }

    #[test]
    fn builds_stream_request_with_web_search_when_enabled() {
        let request =
            build_creation_agent_llm_request("系统提示".to_string(), "用户提示".to_string(), true);

        assert!(request.enable_web_search);
        assert_eq!(request.model.as_deref(), Some(CREATION_TEMPLATE_LLM_MODEL));
        assert_eq!(request.protocol, platform_llm::LlmTextProtocol::Responses);
        assert_eq!(request.messages.len(), 2);
    }

    #[test]
    fn detects_upstream_web_search_tool_unavailable_error() {
        let error = platform_llm::LlmError::Upstream {
            status_code: 502,
            message: "Your account has not activated web search. code=ToolNotOpen".to_string(),
        };

        assert!(is_web_search_tool_unavailable(&error));
    }

    // End-to-end: first upstream response reports the web-search tool as
    // unavailable, the second (retry without web search) succeeds. Verifies
    // the retry happens, the reply surfaces exactly once, and the second
    // request no longer advertises tools.
    #[tokio::test]
    async fn stream_turn_retries_without_web_search_when_tool_is_unavailable() {
        // Unique per-process/per-run temp dir so parallel test runs don't
        // collide on the raw-log directory.
        let log_dir = std::env::temp_dir().join(format!(
            "api-server-creation-agent-raw-log-test-{}-{}",
            std::process::id(),
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("system time should be after epoch")
                .as_nanos()
        ));
        unsafe {
            std::env::set_var("LLM_RAW_LOG_DIR", &log_dir);
        }
        let success_json = serde_json::json!({
            "replyText": "好,我们先把玩具王国定住。",
            "progressPercent": 12,
            "nextAnchorContent": {
                "worldPromise": "玩具王国初步方向",
                "playerFantasy": null,
                "themeBoundary": null,
                "playerEntryPoint": null,
                "coreConflict": null,
                "keyRelationships": null,
                "hiddenLines": null,
                "iconicElements": null
            }
        })
        .to_string();
        // First response: SSE error stream signalling ToolNotOpen.
        // Second response: normal Responses-API delta + completed events.
        let server = spawn_capturing_mock_server(vec![
            MockResponse {
                body: concat!(
                    "data: {\"type\":\"error\",\"code\":\"ToolNotOpen\",\"message\":\"Your account has not activated web search.\"}\n\n",
                    "data: [DONE]\n\n"
                )
                .to_string(),
            },
            MockResponse {
                body: format!(
                    "data: {}\n\n",
                    serde_json::json!({
                        "type": "response.output_text.delta",
                        "delta": success_json
                    })
                ) + "data: {\"type\":\"response.completed\"}\n\n",
            },
        ]);
        let config = LlmConfig::new(
            LlmProvider::Ark,
            server.base_url,
            "test-key".to_string(),
            "test-model".to_string(),
            30_000,
            0,
            1,
        )
        .expect("LLM config should build");
        let llm_client = platform_llm::LlmClient::new(config).expect("LLM client should build");
        let mut visible_replies = Vec::new();

        let output = stream_creation_agent_json_turn(
            Some(&llm_client),
            "系统提示".to_string(),
            "用户提示",
            true,
            CreationAgentLlmTurnErrorMessages {
                model_unavailable: "模型不可用",
                generation_failed: "生成失败",
                parse_failed: "解析失败",
            },
            |text| visible_replies.push(text.to_string()),
            |message| message,
        )
        .await
        .expect("web search fallback should succeed");

        assert_eq!(
            output.parsed["replyText"].as_str(),
            Some("好,我们先把玩具王国定住。")
        );
        // The reply must be surfaced exactly once (no duplicate updates).
        assert_eq!(visible_replies, vec!["好,我们先把玩具王国定住。"]);

        let requests = server.requests.lock().expect("requests lock").clone();
        assert_eq!(requests.len(), 2);
        // First request advertised web search; the retry must not.
        assert!(requests[0].contains("\"tools\""));
        assert!(requests[0].contains("\"web_search\""));
        assert!(!requests[1].contains("\"tools\""));

        unsafe {
            std::env::remove_var("LLM_RAW_LOG_DIR");
        }
        if log_dir.exists() {
            fs::remove_dir_all(log_dir).expect("temporary LLM raw log dir should be removed");
        }
    }

    // One canned SSE body served per accepted connection.
    struct MockResponse {
        body: String,
    }

    // Minimal HTTP server that records each raw request for later assertions.
    struct CapturingMockServer {
        base_url: String,
        requests: Arc<Mutex<Vec<String>>>,
    }

    fn spawn_capturing_mock_server(responses: Vec<MockResponse>) -> CapturingMockServer {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        let requests = Arc::new(Mutex::new(Vec::new()));
        let requests_for_thread = Arc::clone(&requests);

        // Serve exactly one connection per canned response, then exit.
        thread::spawn(move || {
            for response in responses {
                let (mut stream, _) = listener.accept().expect("request should connect");
                let request_text = read_request(&mut stream);
                requests_for_thread
                    .lock()
                    .expect("requests lock")
                    .push(request_text);
                write_sse_response(&mut stream, response);
            }
        });

        CapturingMockServer {
            base_url: format!("http://{address}"),
            requests,
        }
    }

    // Reads one HTTP request: headers, then Content-Length bytes of body.
    // A read timeout bounds the wait in case the client sends less.
    fn read_request(stream: &mut std::net::TcpStream) -> String {
        stream
            .set_read_timeout(Some(StdDuration::from_secs(1)))
            .expect("read timeout should be set");
        let mut buffer = Vec::new();
        let mut chunk = [0_u8; 1024];
        let mut expected_total = None;

        loop {
            match stream.read(&mut chunk) {
                Ok(0) => break,
                Ok(bytes_read) => {
                    buffer.extend_from_slice(&chunk[..bytes_read]);

                    // Once headers are complete, derive the total expected size.
                    if expected_total.is_none()
                        && let Some(header_end) = find_header_end(&buffer)
                    {
                        let content_length =
                            read_content_length(&buffer[..header_end]).unwrap_or(0);
                        expected_total = Some(header_end + content_length);
                    }

                    if let Some(total_bytes) = expected_total
                        && buffer.len() >= total_bytes
                    {
                        break;
                    }
                }
                // Timeout: treat whatever arrived as the full request.
                Err(error)
                    if error.kind() == std::io::ErrorKind::WouldBlock
                        || error.kind() == std::io::ErrorKind::TimedOut =>
                {
                    break;
                }
                Err(error) => panic!("mock server failed to read request: {error}"),
            }
        }

        String::from_utf8_lossy(buffer.as_slice()).to_string()
    }

    fn write_sse_response(stream: &mut std::net::TcpStream, response: MockResponse) {
        let raw_response = format!(
            "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
            response.body.len(),
            response.body
        );

        stream
            .write_all(raw_response.as_bytes())
            .expect("mock response should be written");
        stream.flush().expect("mock response should flush");
    }

    // Byte offset just past the "\r\n\r\n" header terminator, if present.
    fn find_header_end(buffer: &[u8]) -> Option<usize> {
        buffer
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|index| index + 4)
    }

    fn read_content_length(headers: &[u8]) -> Option<usize> {
        let text = String::from_utf8_lossy(headers);
        text.lines().find_map(|line| {
            let (name, value) = line.split_once(':')?;
            if name.eq_ignore_ascii_case("content-length") {
                return value.trim().parse::<usize>().ok();
            }
            None
        })
    }
}