This commit is contained in:
170
server-rs/crates/api-server/src/creation_agent_llm_turn.rs
Normal file
170
server-rs/crates/api-server/src/creation_agent_llm_turn.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use platform_llm::{LlmClient, LlmMessage, LlmStreamDelta, LlmTextRequest};
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(crate) struct CreationAgentLlmTurnErrorMessages<'a> {
|
||||
pub model_unavailable: &'a str,
|
||||
pub generation_failed: &'a str,
|
||||
pub parse_failed: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct CreationAgentJsonTurnOutput {
|
||||
pub parsed: JsonValue,
|
||||
}
|
||||
|
||||
/**
|
||||
* 创作 Agent 的通用流式 JSON turn 调用。
|
||||
* 这里只处理跨玩法一致的 LLM 调用骨架,prompt 内容和领域 JSON 解析仍由调用方负责。
|
||||
*/
|
||||
pub(crate) async fn stream_creation_agent_json_turn<F, E>(
|
||||
llm_client: Option<&LlmClient>,
|
||||
system_prompt: String,
|
||||
user_prompt: impl Into<String>,
|
||||
messages: CreationAgentLlmTurnErrorMessages<'_>,
|
||||
mut on_reply_update: F,
|
||||
build_error: impl Fn(String) -> E,
|
||||
) -> Result<CreationAgentJsonTurnOutput, E>
|
||||
where
|
||||
F: FnMut(&str),
|
||||
{
|
||||
let llm_client =
|
||||
llm_client.ok_or_else(|| build_error(messages.model_unavailable.to_string()))?;
|
||||
let mut latest_reply_text = String::new();
|
||||
let response = llm_client
|
||||
.stream_text(
|
||||
LlmTextRequest::new(vec![
|
||||
LlmMessage::system(system_prompt),
|
||||
LlmMessage::user(user_prompt.into()),
|
||||
]),
|
||||
|delta: &LlmStreamDelta| {
|
||||
if let Some(reply_progress) =
|
||||
extract_reply_text_from_partial_json(delta.accumulated_text.as_str())
|
||||
&& reply_progress != latest_reply_text
|
||||
{
|
||||
latest_reply_text = reply_progress.clone();
|
||||
on_reply_update(reply_progress.as_str());
|
||||
}
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|_| build_error(messages.generation_failed.to_string()))?;
|
||||
let parsed = parse_json_response_text(response.content.as_str())
|
||||
.map_err(|_| build_error(messages.parse_failed.to_string()))?;
|
||||
let reply_text = read_reply_text(&parsed);
|
||||
if let Some(reply_text) = reply_text.as_deref()
|
||||
&& reply_text != latest_reply_text
|
||||
{
|
||||
on_reply_update(reply_text);
|
||||
}
|
||||
|
||||
Ok(CreationAgentJsonTurnOutput { parsed })
|
||||
}
|
||||
|
||||
pub(crate) async fn request_creation_agent_json_turn<E>(
|
||||
llm_client: &LlmClient,
|
||||
system_prompt: String,
|
||||
user_prompt: String,
|
||||
build_error: impl Fn(String) -> E,
|
||||
) -> Result<JsonValue, E> {
|
||||
let response = llm_client
|
||||
.request_text(LlmTextRequest::new(vec![
|
||||
LlmMessage::system(system_prompt),
|
||||
LlmMessage::user(user_prompt),
|
||||
]))
|
||||
.await
|
||||
.map_err(|error| build_error(error.to_string()))?;
|
||||
parse_json_response_text(response.content.as_str())
|
||||
.map_err(|error| build_error(error.to_string()))
|
||||
}
|
||||
|
||||
pub(crate) fn parse_json_response_text(text: &str) -> Result<JsonValue, serde_json::Error> {
|
||||
let trimmed = text.trim();
|
||||
if let Some(start) = trimmed.find('{')
|
||||
&& let Some(end) = trimmed.rfind('}')
|
||||
&& end > start
|
||||
{
|
||||
return serde_json::from_str::<JsonValue>(&trimmed[start..=end]);
|
||||
}
|
||||
serde_json::from_str::<JsonValue>(trimmed)
|
||||
}
|
||||
|
||||
pub(crate) fn extract_reply_text_from_partial_json(text: &str) -> Option<String> {
|
||||
let key_index = text.find("\"replyText\"")?;
|
||||
let colon_index = text[key_index..].find(':')? + key_index;
|
||||
let mut cursor = colon_index + 1;
|
||||
while cursor < text.len() && text.as_bytes()[cursor].is_ascii_whitespace() {
|
||||
cursor += 1;
|
||||
}
|
||||
if text.as_bytes().get(cursor).copied() != Some(b'"') {
|
||||
return None;
|
||||
}
|
||||
cursor += 1;
|
||||
let mut decoded = String::new();
|
||||
let remainder = text.get(cursor..)?;
|
||||
let mut characters = remainder.chars().peekable();
|
||||
while let Some(current) = characters.next() {
|
||||
if current == '"' {
|
||||
return Some(decoded);
|
||||
}
|
||||
if current == '\\' {
|
||||
let escaped = characters.next()?;
|
||||
match escaped {
|
||||
'"' => decoded.push('"'),
|
||||
'\\' => decoded.push('\\'),
|
||||
'/' => decoded.push('/'),
|
||||
'b' => decoded.push('\u{0008}'),
|
||||
'f' => decoded.push('\u{000C}'),
|
||||
'n' => decoded.push('\n'),
|
||||
'r' => decoded.push('\r'),
|
||||
't' => decoded.push('\t'),
|
||||
'u' => {
|
||||
let mut hex = String::new();
|
||||
for _ in 0..4 {
|
||||
hex.push(characters.next()?);
|
||||
}
|
||||
if let Ok(code) = u16::from_str_radix(hex.as_str(), 16)
|
||||
&& let Some(character) = char::from_u32(code as u32)
|
||||
{
|
||||
decoded.push(character);
|
||||
}
|
||||
}
|
||||
other => decoded.push(other),
|
||||
}
|
||||
continue;
|
||||
}
|
||||
decoded.push(current);
|
||||
}
|
||||
Some(decoded)
|
||||
}
|
||||
|
||||
fn read_reply_text(parsed: &JsonValue) -> Option<String> {
|
||||
parsed
|
||||
.get("replyText")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(str::to_string)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{extract_reply_text_from_partial_json, parse_json_response_text};
|
||||
|
||||
#[test]
|
||||
fn extracts_reply_text_from_partial_json_with_chinese_text() {
|
||||
let partial_json = r#"{"replyText":"你好,潮雾列岛","progressPercent":32"#;
|
||||
|
||||
let extracted = extract_reply_text_from_partial_json(partial_json);
|
||||
|
||||
assert_eq!(extracted.as_deref(), Some("你好,潮雾列岛"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_json_inside_model_markdown_noise() {
|
||||
let parsed = parse_json_response_text("```json\n{\"replyText\":\"好\"}\n```")
|
||||
.expect("应能截取模型返回中的 JSON 对象");
|
||||
|
||||
assert_eq!(parsed["replyText"].as_str(), Some("好"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user