use platform_llm::{LlmClient, LlmError, LlmMessage, LlmStreamDelta, LlmTextRequest}; use serde_json::Value as JsonValue; use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL; #[derive(Clone, Copy, Debug)] pub(crate) struct CreationAgentLlmTurnErrorMessages<'a> { pub model_unavailable: &'a str, pub generation_failed: &'a str, pub parse_failed: &'a str, } #[derive(Clone, Debug)] pub(crate) struct CreationAgentJsonTurnOutput { pub parsed: JsonValue, } /** * 创作 Agent 的通用流式 JSON turn 调用。 * 这里只处理跨玩法一致的 LLM 调用骨架,prompt 内容和领域 JSON 解析仍由调用方负责。 */ pub(crate) async fn stream_creation_agent_json_turn( llm_client: Option<&LlmClient>, system_prompt: String, user_prompt: impl Into, enable_web_search: bool, messages: CreationAgentLlmTurnErrorMessages<'_>, mut on_reply_update: F, build_error: impl Fn(String) -> E, ) -> Result where F: FnMut(&str), { let llm_client = llm_client.ok_or_else(|| build_error(messages.model_unavailable.to_string()))?; let user_prompt = user_prompt.into(); let turn_output = match request_stream_creation_agent_json_turn( llm_client, system_prompt.clone(), user_prompt.clone(), enable_web_search, &mut on_reply_update, ) .await { Ok(turn_output) => Ok(turn_output), Err(CreationAgentJsonTurnFailure::Stream(error)) if enable_web_search && is_web_search_tool_unavailable(&error) => { tracing::warn!( error = %error, "创作 Agent 联网搜索插件不可用,自动降级为无联网搜索重试" ); request_stream_creation_agent_json_turn( llm_client, system_prompt, user_prompt, false, &mut on_reply_update, ) .await } Err(error) => Err(error), }; turn_output.map_err(|error| match error { CreationAgentJsonTurnFailure::Stream(error) => { tracing::warn!( error = %error, "创作 Agent 流式 LLM 请求失败" ); build_error(format!("{}:{error}", messages.generation_failed)) } CreationAgentJsonTurnFailure::Parse => build_error(messages.parse_failed.to_string()), }) } enum CreationAgentJsonTurnFailure { Stream(LlmError), Parse, } async fn request_stream_creation_agent_json_turn( llm_client: &LlmClient, system_prompt: String, user_prompt: String, enable_web_search: bool, on_reply_update: &mut F, ) -> Result where F: FnMut(&str), { let mut latest_reply_text = String::new(); let response = llm_client .stream_text( build_creation_agent_llm_request(system_prompt, user_prompt, enable_web_search), |delta: &LlmStreamDelta| { if let Some(reply_progress) = extract_reply_text_from_partial_json(delta.accumulated_text.as_str()) && reply_progress != latest_reply_text { latest_reply_text = reply_progress.clone(); on_reply_update(reply_progress.as_str()); } }, ) .await .map_err(CreationAgentJsonTurnFailure::Stream)?; let parsed = parse_json_response_text(response.content.as_str()) .map_err(|_| CreationAgentJsonTurnFailure::Parse)?; let reply_text = read_reply_text(&parsed); if let Some(reply_text) = reply_text.as_deref() && reply_text != latest_reply_text { on_reply_update(reply_text); } Ok(CreationAgentJsonTurnOutput { parsed }) } fn is_web_search_tool_unavailable(error: &LlmError) -> bool { let message = error.to_string(); message.contains("ToolNotOpen") || message.contains("has not activated web search") || message.contains("未开通") } fn build_creation_agent_llm_request( system_prompt: String, user_prompt: String, enable_web_search: bool, ) -> LlmTextRequest { // 创作 Agent 是否联网由 api-server 配置集中传入,避免各玩法各自散落默认值。 LlmTextRequest::new(vec![ LlmMessage::system(system_prompt), LlmMessage::user(user_prompt), ]) .with_model(CREATION_TEMPLATE_LLM_MODEL) .with_responses_api() .with_web_search(enable_web_search) } pub(crate) async fn request_creation_agent_json_turn( llm_client: &LlmClient, system_prompt: String, user_prompt: String, build_error: impl Fn(String) -> E, ) -> Result { let response = llm_client .request_text( LlmTextRequest::new(vec![ LlmMessage::system(system_prompt), LlmMessage::user(user_prompt), ]) .with_model(CREATION_TEMPLATE_LLM_MODEL) .with_responses_api(), ) .await .map_err(|error| build_error(error.to_string()))?; parse_json_response_text(response.content.as_str()) .map_err(|error| build_error(error.to_string())) } pub(crate) fn parse_json_response_text(text: &str) -> Result { let trimmed = text.trim(); if let Some(start) = trimmed.find('{') && let Some(end) = trimmed.rfind('}') && end > start { return serde_json::from_str::(&trimmed[start..=end]); } serde_json::from_str::(trimmed) } pub(crate) fn extract_reply_text_from_partial_json(text: &str) -> Option { let key_index = text.find("\"replyText\"")?; let colon_index = text[key_index..].find(':')? + key_index; let mut cursor = colon_index + 1; while cursor < text.len() && text.as_bytes()[cursor].is_ascii_whitespace() { cursor += 1; } if text.as_bytes().get(cursor).copied() != Some(b'"') { return None; } cursor += 1; let mut decoded = String::new(); let remainder = text.get(cursor..)?; let mut characters = remainder.chars().peekable(); while let Some(current) = characters.next() { if current == '"' { return Some(decoded); } if current == '\\' { let escaped = characters.next()?; match escaped { '"' => decoded.push('"'), '\\' => decoded.push('\\'), '/' => decoded.push('/'), 'b' => decoded.push('\u{0008}'), 'f' => decoded.push('\u{000C}'), 'n' => decoded.push('\n'), 'r' => decoded.push('\r'), 't' => decoded.push('\t'), 'u' => { let mut hex = String::new(); for _ in 0..4 { hex.push(characters.next()?); } if let Ok(code) = u16::from_str_radix(hex.as_str(), 16) && let Some(character) = char::from_u32(code as u32) { decoded.push(character); } } other => decoded.push(other), } continue; } decoded.push(current); } Some(decoded) } fn read_reply_text(parsed: &JsonValue) -> Option { parsed .get("replyText") .and_then(JsonValue::as_str) .map(str::trim) .filter(|value| !value.is_empty()) .map(str::to_string) } #[cfg(test)] mod tests { use std::{ fs, io::{Read, Write}, net::TcpListener, sync::{Arc, Mutex}, thread, time::{Duration as StdDuration, SystemTime, UNIX_EPOCH}, }; use platform_llm::{LlmConfig, LlmProvider}; use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL; use super::{ CreationAgentLlmTurnErrorMessages, build_creation_agent_llm_request, extract_reply_text_from_partial_json, is_web_search_tool_unavailable, parse_json_response_text, stream_creation_agent_json_turn, }; #[test] fn extracts_reply_text_from_partial_json_with_chinese_text() { let partial_json = r#"{"replyText":"你好,潮雾列岛","progressPercent":32"#; let extracted = extract_reply_text_from_partial_json(partial_json); assert_eq!(extracted.as_deref(), Some("你好,潮雾列岛")); } #[test] fn parses_json_inside_model_markdown_noise() { let parsed = parse_json_response_text("```json\n{\"replyText\":\"好\"}\n```") .expect("应能截取模型返回中的 JSON 对象"); assert_eq!(parsed["replyText"].as_str(), Some("好")); } #[test] fn builds_stream_request_with_web_search_when_enabled() { let request = build_creation_agent_llm_request("系统提示".to_string(), "用户提示".to_string(), true); assert!(request.enable_web_search); assert_eq!(request.model.as_deref(), Some(CREATION_TEMPLATE_LLM_MODEL)); assert_eq!(request.protocol, platform_llm::LlmTextProtocol::Responses); assert_eq!(request.messages.len(), 2); } #[test] fn detects_upstream_web_search_tool_unavailable_error() { let error = platform_llm::LlmError::Upstream { status_code: 502, message: "Your account has not activated web search. code=ToolNotOpen".to_string(), }; assert!(is_web_search_tool_unavailable(&error)); } #[tokio::test] async fn stream_turn_retries_without_web_search_when_tool_is_unavailable() { let log_dir = std::env::temp_dir().join(format!( "api-server-creation-agent-raw-log-test-{}-{}", std::process::id(), SystemTime::now() .duration_since(UNIX_EPOCH) .expect("system time should be after epoch") .as_nanos() )); unsafe { std::env::set_var("LLM_RAW_LOG_DIR", &log_dir); } let success_json = serde_json::json!({ "replyText": "好,我们先把玩具王国定住。", "progressPercent": 12, "nextAnchorContent": { "worldPromise": "玩具王国初步方向", "playerFantasy": null, "themeBoundary": null, "playerEntryPoint": null, "coreConflict": null, "keyRelationships": null, "hiddenLines": null, "iconicElements": null } }) .to_string(); let server = spawn_capturing_mock_server(vec![ MockResponse { body: concat!( "data: {\"type\":\"error\",\"code\":\"ToolNotOpen\",\"message\":\"Your account has not activated web search.\"}\n\n", "data: [DONE]\n\n" ) .to_string(), }, MockResponse { body: format!( "data: {}\n\n", serde_json::json!({ "type": "response.output_text.delta", "delta": success_json }) ) + "data: {\"type\":\"response.completed\"}\n\n", }, ]); let config = LlmConfig::new( LlmProvider::Ark, server.base_url, "test-key".to_string(), "test-model".to_string(), 30_000, 0, 1, ) .expect("LLM config should build"); let llm_client = platform_llm::LlmClient::new(config).expect("LLM client should build"); let mut visible_replies = Vec::new(); let output = stream_creation_agent_json_turn( Some(&llm_client), "系统提示".to_string(), "用户提示", true, CreationAgentLlmTurnErrorMessages { model_unavailable: "模型不可用", generation_failed: "生成失败", parse_failed: "解析失败", }, |text| visible_replies.push(text.to_string()), |message| message, ) .await .expect("web search fallback should succeed"); assert_eq!( output.parsed["replyText"].as_str(), Some("好,我们先把玩具王国定住。") ); assert_eq!(visible_replies, vec!["好,我们先把玩具王国定住。"]); let requests = server.requests.lock().expect("requests lock").clone(); assert_eq!(requests.len(), 2); assert!(requests[0].contains("\"tools\"")); assert!(requests[0].contains("\"web_search\"")); assert!(!requests[1].contains("\"tools\"")); unsafe { std::env::remove_var("LLM_RAW_LOG_DIR"); } if log_dir.exists() { fs::remove_dir_all(log_dir).expect("temporary LLM raw log dir should be removed"); } } struct MockResponse { body: String, } struct CapturingMockServer { base_url: String, requests: Arc>>, } fn spawn_capturing_mock_server(responses: Vec) -> CapturingMockServer { let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind"); let address = listener.local_addr().expect("listener should have addr"); let requests = Arc::new(Mutex::new(Vec::new())); let requests_for_thread = Arc::clone(&requests); thread::spawn(move || { for response in responses { let (mut stream, _) = listener.accept().expect("request should connect"); let request_text = read_request(&mut stream); requests_for_thread .lock() .expect("requests lock") .push(request_text); write_sse_response(&mut stream, response); } }); CapturingMockServer { base_url: format!("http://{address}"), requests, } } fn read_request(stream: &mut std::net::TcpStream) -> String { stream .set_read_timeout(Some(StdDuration::from_secs(1))) .expect("read timeout should be set"); let mut buffer = Vec::new(); let mut chunk = [0_u8; 1024]; let mut expected_total = None; loop { match stream.read(&mut chunk) { Ok(0) => break, Ok(bytes_read) => { buffer.extend_from_slice(&chunk[..bytes_read]); if expected_total.is_none() && let Some(header_end) = find_header_end(&buffer) { let content_length = read_content_length(&buffer[..header_end]).unwrap_or(0); expected_total = Some(header_end + content_length); } if let Some(total_bytes) = expected_total && buffer.len() >= total_bytes { break; } } Err(error) if error.kind() == std::io::ErrorKind::WouldBlock || error.kind() == std::io::ErrorKind::TimedOut => { break; } Err(error) => panic!("mock server failed to read request: {error}"), } } String::from_utf8_lossy(buffer.as_slice()).to_string() } fn write_sse_response(stream: &mut std::net::TcpStream, response: MockResponse) { let raw_response = format!( "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", response.body.len(), response.body ); stream .write_all(raw_response.as_bytes()) .expect("mock response should be written"); stream.flush().expect("mock response should flush"); } fn find_header_end(buffer: &[u8]) -> Option { buffer .windows(4) .position(|window| window == b"\r\n\r\n") .map(|index| index + 4) } fn read_content_length(headers: &[u8]) -> Option { let text = String::from_utf8_lossy(headers); text.lines().find_map(|line| { let (name, value) = line.split_once(':')?; if name.eq_ignore_ascii_case("content-length") { return value.trim().parse::().ok(); } None }) } }