use platform_llm::{LlmClient, LlmMessage, LlmStreamDelta, LlmTextRequest};
use serde_json::Value as JsonValue;

use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;

#[derive(Clone, Copy, Debug)]
pub(crate) struct CreationAgentLlmTurnErrorMessages<'a> {
    pub model_unavailable: &'a str,
    pub generation_failed: &'a str,
    pub parse_failed: &'a str,
}

#[derive(Clone, Debug)]
pub(crate) struct CreationAgentJsonTurnOutput {
    pub parsed: JsonValue,
}

/// The shared streaming JSON-turn call for creation agents.
/// This only covers the LLM-call skeleton that is identical across game modes;
/// prompt content and domain-specific JSON parsing remain the caller's responsibility.
pub(crate) async fn stream_creation_agent_json_turn<F, E>(
    llm_client: Option<&LlmClient>,
    system_prompt: String,
    user_prompt: impl Into<String>,
    enable_web_search: bool,
    messages: CreationAgentLlmTurnErrorMessages<'_>,
    mut on_reply_update: F,
    build_error: impl Fn(String) -> E,
) -> Result<CreationAgentJsonTurnOutput, E>
where
    F: FnMut(&str),
{
    let llm_client =
        llm_client.ok_or_else(|| build_error(messages.model_unavailable.to_string()))?;
    let mut latest_reply_text = String::new();

    let response = llm_client
        .stream_text(
            build_creation_agent_llm_request(system_prompt, user_prompt.into(), enable_web_search),
            |delta: &LlmStreamDelta| {
                if let Some(reply_progress) =
                    extract_reply_text_from_partial_json(delta.accumulated_text.as_str())
                    && reply_progress != latest_reply_text
                {
                    latest_reply_text = reply_progress.clone();
                    on_reply_update(reply_progress.as_str());
                }
            },
        )
        .await
        .map_err(|_| build_error(messages.generation_failed.to_string()))?;

    let parsed = parse_json_response_text(response.content.as_str())
        .map_err(|_| build_error(messages.parse_failed.to_string()))?;

    // Flush the final reply text in case the stream ended before the last update.
    let reply_text = read_reply_text(&parsed);
    if let Some(reply_text) = reply_text.as_deref()
        && reply_text != latest_reply_text
    {
        on_reply_update(reply_text);
    }

    Ok(CreationAgentJsonTurnOutput { parsed })
}

fn build_creation_agent_llm_request(
    system_prompt: String,
    user_prompt: String,
    enable_web_search: bool,
) -> LlmTextRequest {
    // Whether the creation agent may use web search is passed in centrally from the
    // api-server config, so individual game modes do not scatter their own defaults.
    LlmTextRequest::new(vec![
        LlmMessage::system(system_prompt),
        LlmMessage::user(user_prompt),
    ])
    .with_model(CREATION_TEMPLATE_LLM_MODEL)
    .with_responses_api()
    .with_web_search(enable_web_search)
}

pub(crate) async fn request_creation_agent_json_turn<E>(
    llm_client: &LlmClient,
    system_prompt: String,
    user_prompt: String,
    build_error: impl Fn(String) -> E,
) -> Result<JsonValue, E> {
    let response = llm_client
        .request_text(
            LlmTextRequest::new(vec![
                LlmMessage::system(system_prompt),
                LlmMessage::user(user_prompt),
            ])
            .with_model(CREATION_TEMPLATE_LLM_MODEL)
            .with_responses_api(),
        )
        .await
        .map_err(|error| build_error(error.to_string()))?;

    parse_json_response_text(response.content.as_str())
        .map_err(|error| build_error(error.to_string()))
}

pub(crate) fn parse_json_response_text(text: &str) -> Result<JsonValue, serde_json::Error> {
    // Models sometimes wrap the JSON in markdown fences or other noise; slice out
    // the outermost object before falling back to parsing the raw text.
    let trimmed = text.trim();
    if let Some(start) = trimmed.find('{')
        && let Some(end) = trimmed.rfind('}')
        && end > start
    {
        return serde_json::from_str::<JsonValue>(&trimmed[start..=end]);
    }
    serde_json::from_str::<JsonValue>(trimmed)
}
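
/// Scans a partially streamed JSON buffer for the `"replyText"` field and decodes
/// however much of its string value has arrived, without waiting for the closing
/// quote. A minimal behavior sketch (marked `ignore` since the function is
/// crate-private and out of reach of doctests):
///
/// ```ignore
/// // The stream was cut off mid-value; the decoded prefix is still returned.
/// let partial = r#"{"replyText":"Drafting the island"#;
/// assert_eq!(
///     extract_reply_text_from_partial_json(partial).as_deref(),
///     Some("Drafting the island"),
/// );
/// ```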
pub(crate) fn extract_reply_text_from_partial_json(text: &str) -> Option<String> {
    let key_index = text.find("\"replyText\"")?;
    let colon_index = text[key_index..].find(':')? + key_index;

    // Skip whitespace between the colon and the opening quote of the value.
    let mut cursor = colon_index + 1;
    while cursor < text.len() && text.as_bytes()[cursor].is_ascii_whitespace() {
        cursor += 1;
    }
    if text.as_bytes().get(cursor).copied() != Some(b'"') {
        return None;
    }
    cursor += 1;

    // Decode the string value manually so a truncated buffer still yields the
    // prefix received so far.
    let mut decoded = String::new();
    let remainder = text.get(cursor..)?;
    let mut characters = remainder.chars();
    while let Some(current) = characters.next() {
        if current == '"' {
            return Some(decoded);
        }
        if current == '\\' {
            let escaped = characters.next()?;
            match escaped {
                '"' => decoded.push('"'),
                '\\' => decoded.push('\\'),
                '/' => decoded.push('/'),
                'b' => decoded.push('\u{0008}'),
                'f' => decoded.push('\u{000C}'),
                'n' => decoded.push('\n'),
                'r' => decoded.push('\r'),
                't' => decoded.push('\t'),
                'u' => {
                    let mut hex = String::new();
                    for _ in 0..4 {
                        hex.push(characters.next()?);
                    }
                    // Invalid code points (e.g. lone surrogates) are dropped.
                    if let Ok(code) = u16::from_str_radix(hex.as_str(), 16)
                        && let Some(character) = char::from_u32(code as u32)
                    {
                        decoded.push(character);
                    }
                }
                other => decoded.push(other),
            }
            continue;
        }
        decoded.push(current);
    }

    // No closing quote yet: the stream is mid-string, so report the decoded prefix.
    Some(decoded)
}

fn read_reply_text(parsed: &JsonValue) -> Option<String> {
    parsed
        .get("replyText")
        .and_then(JsonValue::as_str)
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(str::to_string)
}

#[cfg(test)]
mod tests {
    use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;

    use super::{
        build_creation_agent_llm_request, extract_reply_text_from_partial_json,
        parse_json_response_text,
    };

    #[test]
    fn extracts_reply_text_from_partial_json_with_chinese_text() {
        let partial_json = r#"{"replyText":"你好,潮雾列岛","progressPercent":32"#;
        let extracted = extract_reply_text_from_partial_json(partial_json);
        assert_eq!(extracted.as_deref(), Some("你好,潮雾列岛"));
    }

    #[test]
    fn parses_json_inside_model_markdown_noise() {
        let parsed = parse_json_response_text("```json\n{\"replyText\":\"好\"}\n```")
            .expect("should extract the JSON object embedded in the model reply");
        assert_eq!(parsed["replyText"].as_str(), Some("好"));
    }

    #[test]
    fn builds_stream_request_with_web_search_when_enabled() {
        let request = build_creation_agent_llm_request(
            "system prompt".to_string(),
            "user prompt".to_string(),
            true,
        );
        assert!(request.enable_web_search);
        assert_eq!(request.model.as_deref(), Some(CREATION_TEMPLATE_LLM_MODEL));
        assert_eq!(request.protocol, platform_llm::LlmTextProtocol::Responses);
        assert_eq!(request.messages.len(), 2);
    }
}
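
// Edge-case checks for the partial-JSON decoder above: a minimal sketch of its
// expected behavior as read off the implementation, covering truncated buffers,
// escape sequences, and basic `\uXXXX` decoding.
#[cfg(test)]
mod partial_json_edge_case_tests {
    use super::extract_reply_text_from_partial_json;

    #[test]
    fn decodes_escape_sequences_in_truncated_reply_text() {
        // The buffer ends mid-string; the decoded prefix, including the `\n`
        // escape, should still surface as streaming progress.
        let partial_json = r#"{"replyText":"line one\nline t"#;
        let extracted = extract_reply_text_from_partial_json(partial_json);
        assert_eq!(extracted.as_deref(), Some("line one\nline t"));
    }

    #[test]
    fn decodes_basic_unicode_escapes() {
        // `\u4f60\u597d` decodes to "你好"; lone surrogates would be dropped,
        // since `char::from_u32` rejects them.
        let partial_json = r#"{"replyText":"\u4f60\u597d"}"#;
        let extracted = extract_reply_text_from_partial_json(partial_json);
        assert_eq!(extracted.as_deref(), Some("你好"));
    }

    #[test]
    fn returns_none_before_the_value_string_opens() {
        // Until the value's opening quote has streamed in, there is no progress
        // to report.
        assert!(extract_reply_text_from_partial_json(r#"{"replyText":"#).is_none());
    }
}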