Files
Genarrative/server-rs/crates/api-server/src/creation_agent_llm_turn.rs
2026-04-30 17:49:07 +08:00

206 lines
6.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use platform_llm::{LlmClient, LlmMessage, LlmStreamDelta, LlmTextRequest};
use serde_json::Value as JsonValue;
use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;
/// Caller-supplied messages used when a streamed creation-agent turn fails.
///
/// Each field is handed to the caller's error builder for one specific
/// failure stage of `stream_creation_agent_json_turn`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct CreationAgentLlmTurnErrorMessages<'a> {
    /// Used when no LLM client is available for the turn.
    pub model_unavailable: &'a str,
    /// Used when the streaming LLM call itself returns an error.
    pub generation_failed: &'a str,
    /// Used when the final response text cannot be parsed as JSON.
    pub parse_failed: &'a str,
}
/// Result of a successful streamed creation-agent JSON turn.
#[derive(Clone, Debug)]
pub(crate) struct CreationAgentJsonTurnOutput {
    /// JSON value parsed from the model's complete response text.
    pub parsed: JsonValue,
}
/// Generic streaming JSON turn call shared by the creation agents.
///
/// Only the LLM-call skeleton that is identical across gameplay modes lives
/// here; prompt content and domain-specific JSON interpretation remain the
/// caller's responsibility.
///
/// While the model streams, the `replyText` value is extracted from the
/// partial JSON and forwarded to `on_reply_update` whenever it changes. Once
/// the full response has been parsed, the trimmed final `replyText` is
/// forwarded one more time if it differs from the last streamed value.
///
/// # Errors
///
/// Every failure is converted through `build_error` using the matching entry
/// of `messages`: a missing client, a failed streaming call, or a response
/// that cannot be parsed as JSON.
pub(crate) async fn stream_creation_agent_json_turn<F, E>(
    llm_client: Option<&LlmClient>,
    system_prompt: String,
    user_prompt: impl Into<String>,
    enable_web_search: bool,
    messages: CreationAgentLlmTurnErrorMessages<'_>,
    mut on_reply_update: F,
    build_error: impl Fn(String) -> E,
) -> Result<CreationAgentJsonTurnOutput, E>
where
    F: FnMut(&str),
{
    let Some(client) = llm_client else {
        return Err(build_error(messages.model_unavailable.to_string()));
    };

    // Remembers the last text pushed to the callback so unchanged deltas stay silent.
    let mut last_streamed_reply = String::new();
    let request =
        build_creation_agent_llm_request(system_prompt, user_prompt.into(), enable_web_search);

    let streamed = client
        .stream_text(request, |delta: &LlmStreamDelta| {
            match extract_reply_text_from_partial_json(delta.accumulated_text.as_str()) {
                Some(progress) if progress != last_streamed_reply => {
                    last_streamed_reply.clone_from(&progress);
                    on_reply_update(progress.as_str());
                }
                _ => {}
            }
        })
        .await;
    let response = match streamed {
        Ok(response) => response,
        Err(_) => return Err(build_error(messages.generation_failed.to_string())),
    };

    let parsed = match parse_json_response_text(response.content.as_str()) {
        Ok(value) => value,
        Err(_) => return Err(build_error(messages.parse_failed.to_string())),
    };

    // The final reply is trimmed, so it may differ from the last streamed snapshot.
    if let Some(final_reply) =
        read_reply_text(&parsed).filter(|reply| *reply != last_streamed_reply)
    {
        on_reply_update(final_reply.as_str());
    }

    Ok(CreationAgentJsonTurnOutput { parsed })
}
/// Builds the LLM request for a streamed creation-agent turn.
///
/// The request carries one system message and one user message, selects
/// `CREATION_TEMPLATE_LLM_MODEL`, uses the responses API, and applies the
/// given web-search flag.
fn build_creation_agent_llm_request(
    system_prompt: String,
    user_prompt: String,
    enable_web_search: bool,
) -> LlmTextRequest {
    let conversation = vec![
        LlmMessage::system(system_prompt),
        LlmMessage::user(user_prompt),
    ];
    let request = LlmTextRequest::new(conversation)
        .with_model(CREATION_TEMPLATE_LLM_MODEL)
        .with_responses_api();
    // Whether the creation agent may search the web is passed in centrally from
    // the api-server configuration, so gameplay modes do not each scatter their
    // own default.
    request.with_web_search(enable_web_search)
}
/// Performs a single non-streaming creation-agent turn and parses the reply as JSON.
///
/// The request carries one system message and one user message, selects
/// `CREATION_TEMPLATE_LLM_MODEL`, and uses the responses API.
///
/// # Errors
///
/// Both a failed LLM call and an unparseable response are converted through
/// `build_error`, which receives the underlying error rendered as a string.
pub(crate) async fn request_creation_agent_json_turn<E>(
    llm_client: &LlmClient,
    system_prompt: String,
    user_prompt: String,
    build_error: impl Fn(String) -> E,
) -> Result<JsonValue, E> {
    let conversation = vec![
        LlmMessage::system(system_prompt),
        LlmMessage::user(user_prompt),
    ];
    let request = LlmTextRequest::new(conversation)
        .with_model(CREATION_TEMPLATE_LLM_MODEL)
        .with_responses_api();

    let response = match llm_client.request_text(request).await {
        Ok(response) => response,
        Err(error) => return Err(build_error(error.to_string())),
    };

    match parse_json_response_text(response.content.as_str()) {
        Ok(value) => Ok(value),
        Err(error) => Err(build_error(error.to_string())),
    }
}
/// Parses model output as JSON, tolerating surrounding noise such as Markdown fences.
///
/// After trimming, if the text contains a `{` that comes before its last `}`,
/// only the span from that first `{` to that last `}` (inclusive) is parsed.
/// Otherwise the whole trimmed text is parsed.
///
/// # Errors
///
/// Returns the `serde_json` error when the chosen span is not valid JSON.
pub(crate) fn parse_json_response_text(text: &str) -> Result<JsonValue, serde_json::Error> {
    let trimmed = text.trim();
    let candidate = match (trimmed.find('{'), trimmed.rfind('}')) {
        (Some(open), Some(close)) if close > open => &trimmed[open..=close],
        _ => trimmed,
    };
    serde_json::from_str::<JsonValue>(candidate)
}
/// Extracts and decodes the `replyText` string value from possibly incomplete JSON text.
///
/// The text is scanned for the first `"replyText"` key, the following `:`, and
/// an opening `"`. The string value is then decoded up to its closing quote, or
/// up to the end of the text when the string is still unterminated (which is
/// the normal case while a response is streaming in).
///
/// JSON escapes are decoded, including `\uXXXX`. A UTF-16 surrogate pair written
/// as two consecutive `\u` escapes is combined into a single character, so
/// characters outside the Basic Multilingual Plane (for example emoji) survive.
/// Lone surrogates and `\u` escapes with non-hex digits are dropped.
///
/// Returns `None` when:
/// - the key, the colon, or the opening quote is missing, or
/// - the text ends in the middle of an escape sequence (including between the
///   two halves of a surrogate pair); the caller is expected to retry once
///   more text has arrived.
pub(crate) fn extract_reply_text_from_partial_json(text: &str) -> Option<String> {
    let key_index = text.find("\"replyText\"")?;
    let colon_index = text[key_index..].find(':')? + key_index;
    // ':' is a single ASCII byte, so `colon_index + 1` is always a char boundary.
    let value = text[colon_index + 1..]
        .trim_start_matches(|character: char| character.is_ascii_whitespace())
        .strip_prefix('"')?;

    let mut decoded = String::new();
    let mut characters = value.chars();
    while let Some(current) = characters.next() {
        match current {
            '"' => return Some(decoded),
            // A backslash at the very end means the escape is still in flight.
            '\\' => match characters.next()? {
                'b' => decoded.push('\u{0008}'),
                'f' => decoded.push('\u{000C}'),
                'n' => decoded.push('\n'),
                'r' => decoded.push('\r'),
                't' => decoded.push('\t'),
                'u' => {
                    if let Some(character) = decode_unicode_escape(&mut characters)? {
                        decoded.push(character);
                    }
                }
                // `\"`, `\\`, `\/` and any unknown escape map to the character itself.
                other => decoded.push(other),
            },
            other => decoded.push(other),
        }
    }
    // Unterminated string: report what has been received so far.
    Some(decoded)
}

/// Decodes the payload of a `\u` escape whose `\u` prefix was already consumed.
///
/// Returns `None` when the text ends before the escape (or the second half of a
/// surrogate pair) is complete, `Some(None)` when the escape cannot be decoded
/// and should be dropped, and `Some(Some(character))` on success.
fn decode_unicode_escape(characters: &mut std::str::Chars<'_>) -> Option<Option<char>> {
    let Some(unit) = read_hex_code_unit(characters)? else {
        return Some(None);
    };
    if !(0xD800..=0xDBFF).contains(&unit) {
        // Plain BMP scalar; a lone low surrogate yields `None` and is dropped.
        return Some(char::from_u32(unit));
    }

    // High surrogate: only consume the next escape if it really is the low half,
    // so anything else is left for the main loop to decode normally.
    let mut lookahead = characters.clone();
    if lookahead.next()? != '\\' || lookahead.next()? != 'u' {
        return Some(None);
    }
    match read_hex_code_unit(&mut lookahead)? {
        Some(low @ 0xDC00..=0xDFFF) => {
            *characters = lookahead;
            Some(char::from_u32(
                0x10000 + ((unit - 0xD800) << 10) + (low - 0xDC00),
            ))
        }
        _ => Some(None),
    }
}

/// Reads exactly four characters and interprets them as one hexadecimal UTF-16 code unit.
///
/// Returns `None` when fewer than four characters remain, and `Some(None)` when
/// four characters were consumed but at least one is not a hex digit.
fn read_hex_code_unit(characters: &mut std::str::Chars<'_>) -> Option<Option<u32>> {
    let mut unit = 0u32;
    let mut is_valid = true;
    for _ in 0..4 {
        match characters.next()?.to_digit(16) {
            Some(digit) => unit = unit * 16 + digit,
            None => is_valid = false,
        }
    }
    Some(is_valid.then_some(unit))
}
fn read_reply_text(parsed: &JsonValue) -> Option<String> {
parsed
.get("replyText")
.and_then(JsonValue::as_str)
.map(str::trim)
.filter(|value| !value.is_empty())
.map(str::to_string)
}
#[cfg(test)]
mod tests {
    use super::{
        build_creation_agent_llm_request, extract_reply_text_from_partial_json,
        parse_json_response_text,
    };
    use crate::llm_model_routing::CREATION_TEMPLATE_LLM_MODEL;

    /// Multi-byte text must survive extraction from an unfinished JSON object.
    #[test]
    fn extracts_reply_text_from_partial_json_with_chinese_text() {
        let reply = extract_reply_text_from_partial_json(
            r#"{"replyText":"你好,潮雾列岛","progressPercent":32"#,
        );

        assert_eq!(reply.as_deref(), Some("你好,潮雾列岛"));
    }

    /// A Markdown code fence around the object must not prevent parsing.
    #[test]
    fn parses_json_inside_model_markdown_noise() {
        let fenced = "```json\n{\"replyText\":\"\"}\n```";

        let value = parse_json_response_text(fenced).expect("应能截取模型返回中的 JSON 对象");

        assert_eq!(value["replyText"].as_str(), Some(""));
    }

    /// The stream request carries the web-search flag, model, protocol and both messages.
    #[test]
    fn builds_stream_request_with_web_search_when_enabled() {
        let system = "系统提示".to_string();
        let user = "用户提示".to_string();

        let built = build_creation_agent_llm_request(system, user, true);

        assert!(built.enable_web_search);
        assert_eq!(built.model.as_deref(), Some(CREATION_TEMPLATE_LLM_MODEL));
        assert_eq!(built.protocol, platform_llm::LlmTextProtocol::Responses);
        assert_eq!(built.messages.len(), 2);
    }
}