Files
Genarrative/server-rs/crates/api-server/src/puzzle_agent_turn.rs
2026-04-30 17:49:07 +08:00

321 lines
11 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use module_puzzle::{PuzzleAgentStage, PuzzleAnchorPack, PuzzleAnchorStatus, empty_anchor_pack};
use platform_llm::LlmClient;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use spacetime_client::{PuzzleAgentMessageFinalizeRecordInput, PuzzleAgentSessionRecord};
use crate::creation_agent_llm_turn::{
CreationAgentLlmTurnErrorMessages, stream_creation_agent_json_turn,
};
use crate::prompt::puzzle::agent_chat::{
PUZZLE_AGENT_JSON_TURN_USER_PROMPT, PUZZLE_AGENT_SYSTEM_PROMPT, build_puzzle_agent_prompt,
serialize_puzzle_record_anchor_pack,
};
/// Inputs for a single puzzle-agent chat turn, consumed by `run_puzzle_agent_turn`.
#[derive(Clone, Debug)]
pub(crate) struct PuzzleAgentTurnRequest<'a> {
    /// LLM client forwarded to `stream_creation_agent_json_turn`; `None` is passed through
    /// unchanged (presumably surfaced as "model unavailable" — confirm in that helper).
    pub llm_client: Option<&'a LlmClient>,
    /// Session record used to build the agent prompt.
    pub session: &'a PuzzleAgentSessionRecord,
    /// Builds the prompt in quick-fill mode; the turn's reported progress is forced to 100.
    pub quick_fill_requested: bool,
    /// Forwarded to the LLM turn to toggle web search.
    pub enable_web_search: bool,
}
/// Outcome of a puzzle-agent turn, ready to be turned into a finalize record input.
#[derive(Clone, Debug)]
pub(crate) struct PuzzleAgentTurnResult {
    /// Trimmed, non-empty reply text produced by the model.
    pub assistant_reply_text: String,
    /// String form (`PuzzleAgentStage::as_str`) of the stage resolved for this turn.
    pub stage: String,
    /// Progress percentage in `0..=100`.
    pub progress_percent: u32,
    /// JSON serialization of the next `PuzzleAnchorPack`.
    pub anchor_pack_json: String,
    /// Error text to persist with the turn; `None` on the success path in this file.
    pub error_message: Option<String>,
}
/// Error raised when a puzzle-agent turn cannot produce a usable result.
///
/// It carries only a user-facing message, which `Display` writes out verbatim
/// (formatter width/fill flags are not applied).
#[derive(Clone, Debug)]
pub(crate) struct PuzzleAgentTurnError {
    message: String,
}

impl PuzzleAgentTurnError {
    /// Creates an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self { message }
    }
}

impl std::fmt::Display for PuzzleAgentTurnError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "{}", self.message)
    }
}

impl std::error::Error for PuzzleAgentTurnError {}
/// Structured output extracted from the model's JSON reply for one turn.
///
/// Serde derives use camelCase names (`replyText`, `progressPercent`, `nextAnchorPack`).
/// Within this file, however, values are built manually by `parse_model_output`
/// rather than through `Deserialize`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PuzzleAgentModelOutput {
    /// Trimmed, non-empty assistant reply.
    reply_text: String,
    /// Progress reported by the model, clamped to at most 100.
    progress_percent: u32,
    /// Anchor pack proposed for the next session state.
    next_anchor_pack: PuzzleAnchorPack,
}
/// Runs one puzzle-agent chat turn against the LLM and converts the JSON reply into a result.
///
/// The system prompt and the session-specific prompt are concatenated and streamed through
/// `stream_creation_agent_json_turn`, which calls `on_reply_update` with partial reply text.
/// The parsed JSON is then validated by `parse_model_output`.
///
/// When `request.quick_fill_requested` is set, progress is forced to 100 and the stage is
/// derived from that effective progress. This keeps `stage` and `progress_percent` consistent.
///
/// # Errors
/// Propagates `PuzzleAgentTurnError` from the LLM turn (model unavailable, generation or parse
/// failure) and from validating the model output.
pub(crate) async fn run_puzzle_agent_turn<F>(
    request: PuzzleAgentTurnRequest<'_>,
    on_reply_update: F,
) -> Result<PuzzleAgentTurnResult, PuzzleAgentTurnError>
where
    F: FnMut(&str),
{
    let prompt = build_puzzle_agent_prompt(request.session, request.quick_fill_requested);
    let turn_output = stream_creation_agent_json_turn(
        request.llm_client,
        format!("{PUZZLE_AGENT_SYSTEM_PROMPT}\n\n{prompt}"),
        PUZZLE_AGENT_JSON_TURN_USER_PROMPT,
        request.enable_web_search,
        CreationAgentLlmTurnErrorMessages {
            model_unavailable: "当前模型不可用,请稍后重试。",
            generation_failed: "拼图聊天生成失败,请稍后重试。",
            parse_failed: "拼图聊天结果解析失败,请稍后重试。",
        },
        on_reply_update,
        PuzzleAgentTurnError::new,
    )
    .await?;
    let output = parse_model_output(&turn_output.parsed)?;

    // Quick fill always completes the draft, so progress is forced to 100. The stage must be
    // resolved from this *effective* progress; resolving it from the model's raw value could
    // report 100% progress while the stage is still `CollectingAnchors`.
    let progress_percent = if request.quick_fill_requested {
        100
    } else {
        output.progress_percent
    };
    let stage = resolve_puzzle_agent_stage(progress_percent)
        .as_str()
        .to_string();
    // Serializing the anchor pack is not expected to fail; fall back to an empty pack (and
    // ultimately to "{}") rather than failing the whole turn.
    let anchor_pack_json = serde_json::to_string(&output.next_anchor_pack).unwrap_or_else(|_| {
        serde_json::to_string(&empty_anchor_pack()).unwrap_or_else(|_| "{}".to_string())
    });

    Ok(PuzzleAgentTurnResult {
        assistant_reply_text: output.reply_text,
        stage,
        progress_percent,
        anchor_pack_json,
        error_message: None,
    })
}
pub(crate) fn build_finalize_record_input(
session_id: String,
owner_user_id: String,
assistant_message_id: String,
result: PuzzleAgentTurnResult,
updated_at_micros: i64,
) -> PuzzleAgentMessageFinalizeRecordInput {
PuzzleAgentMessageFinalizeRecordInput {
session_id,
owner_user_id,
assistant_message_id: Some(assistant_message_id),
assistant_reply_text: Some(result.assistant_reply_text),
stage: result.stage,
progress_percent: result.progress_percent,
anchor_pack_json: result.anchor_pack_json,
error_message: result.error_message,
updated_at_micros,
}
}
/// Builds the finalize input for a failed turn.
///
/// No assistant message is recorded: both `assistant_message_id` and `assistant_reply_text`
/// are `None`. The session's current stage, progress and anchor pack are copied from `session`,
/// and `error_message` is attached.
pub(crate) fn build_failed_finalize_record_input(
    session_id: String,
    owner_user_id: String,
    session: &PuzzleAgentSessionRecord,
    error_message: String,
    updated_at_micros: i64,
) -> PuzzleAgentMessageFinalizeRecordInput {
    let stage = session.stage.clone();
    let progress_percent = session.progress_percent;
    let anchor_pack_json = serialize_puzzle_record_anchor_pack(&session.anchor_pack);

    PuzzleAgentMessageFinalizeRecordInput {
        session_id,
        owner_user_id,
        assistant_message_id: None,
        assistant_reply_text: None,
        stage,
        progress_percent,
        anchor_pack_json,
        error_message: Some(error_message),
        updated_at_micros,
    }
}
/// Validates the model's parsed JSON turn and extracts the puzzle-agent output.
///
/// Expected camelCase fields:
/// - `replyText`: required; trimmed and must be non-blank.
/// - `progressPercent`: optional JSON number. It is clamped to `0..=100` and rounded when
///   fractional; it becomes `0` when absent or not a number.
/// - `nextAnchorPack`: required; parsed by `parse_model_anchor_pack`.
///
/// # Errors
/// Returns a `PuzzleAgentTurnError` with a user-facing message in three cases:
/// `replyText` is missing or blank, `nextAnchorPack` is missing, or an anchor entry is missing.
fn parse_model_output(parsed: &JsonValue) -> Result<PuzzleAgentModelOutput, PuzzleAgentTurnError> {
    let reply_text = parsed
        .get("replyText")
        .and_then(JsonValue::as_str)
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .ok_or_else(|| PuzzleAgentTurnError::new("拼图聊天结果缺少有效回复,请稍后重试。"))?
        .to_string();
    // Models sometimes emit progress as a float (e.g. `46.0`). `as_u64` rejected those and
    // silently reset progress to 0, so accept any JSON number, clamp it to 0..=100 and round.
    // Integers keep their previous behaviour: negatives become 0 and values above 100 become 100.
    let progress_percent = parsed
        .get("progressPercent")
        .and_then(JsonValue::as_f64)
        .map(|value| value.clamp(0.0, 100.0).round() as u32)
        .unwrap_or(0);
    // The anchor pack is only read, so borrow it instead of cloning the JSON subtree.
    let next_anchor_pack_value = parsed
        .get("nextAnchorPack")
        .ok_or_else(|| PuzzleAgentTurnError::new("拼图聊天结果缺少 nextAnchorPack。"))?;
    let next_anchor_pack = parse_model_anchor_pack(next_anchor_pack_value)?;
    Ok(PuzzleAgentModelOutput {
        reply_text,
        progress_percent,
        next_anchor_pack,
    })
}
/// Translates the model's camelCase `nextAnchorPack` object into the domain `PuzzleAnchorPack`.
///
/// The LLM output contract, shared with the frontend and the prompt, uses camelCase keys, while
/// the Rust domain model keeps snake_case fields. The translation happens explicitly at this
/// boundary so the JSON naming difference does not spread into the domain crate.
///
/// Entries are parsed in field order; the first missing entry aborts with its error.
fn parse_model_anchor_pack(value: &JsonValue) -> Result<PuzzleAnchorPack, PuzzleAgentTurnError> {
    let item = |json_key: &str| parse_model_anchor_item(value, json_key);

    let theme_promise = item("themePromise")?;
    let visual_subject = item("visualSubject")?;
    let visual_mood = item("visualMood")?;
    let composition_hooks = item("compositionHooks")?;
    let tags_and_forbidden = item("tagsAndForbidden")?;

    Ok(PuzzleAnchorPack {
        theme_promise,
        visual_subject,
        visual_mood,
        composition_hooks,
        tags_and_forbidden,
    })
}
/// Parses the single anchor entry `pack[field_name]` from the model's anchor pack.
///
/// The entry itself must exist; otherwise an error naming `field_name` is returned.
/// Inside the entry every sub-field is optional:
/// - `key`: trimmed; absent or blank falls back to `field_name`.
/// - `label`: trimmed; absent or blank falls back to `default_puzzle_anchor_label(field_name)`.
/// - `value`: trimmed; absent (or non-string) becomes an empty string.
/// - `status`: passed to `parse_anchor_status` as-is; absent or non-string becomes `Missing`.
fn parse_model_anchor_item(
    pack: &JsonValue,
    field_name: &str,
) -> Result<module_puzzle::PuzzleAnchorItem, PuzzleAgentTurnError> {
    /// Reads `object[name]` as a string and trims it; `None` when absent or not a string.
    fn trimmed_text<'a>(object: &'a JsonValue, name: &str) -> Option<&'a str> {
        object.get(name).and_then(JsonValue::as_str).map(str::trim)
    }

    let entry = match pack.get(field_name) {
        Some(entry) => entry,
        None => {
            return Err(PuzzleAgentTurnError::new(format!(
                "拼图 anchor pack 缺少 {field_name}"
            )));
        }
    };

    let key = match trimmed_text(entry, "key") {
        Some(text) if !text.is_empty() => text,
        _ => field_name,
    }
    .to_string();
    let label = match trimmed_text(entry, "label") {
        Some(text) if !text.is_empty() => text,
        _ => default_puzzle_anchor_label(field_name),
    }
    .to_string();
    let value = trimmed_text(entry, "value").unwrap_or("").to_owned();
    let status = entry
        .get("status")
        .and_then(JsonValue::as_str)
        .map_or(PuzzleAnchorStatus::Missing, parse_anchor_status);

    Ok(module_puzzle::PuzzleAnchorItem {
        key,
        label,
        value,
        status,
    })
}
/// Returns the default display label for a camelCase anchor field name.
///
/// Unknown field names get the generic label "拼图锚点".
fn default_puzzle_anchor_label(field_name: &str) -> &'static str {
    const LABELS: [(&str, &str); 5] = [
        ("themePromise", "题材承诺"),
        ("visualSubject", "画面主体"),
        ("visualMood", "视觉气质"),
        ("compositionHooks", "拼图记忆点"),
        ("tagsAndForbidden", "标签与禁忌"),
    ];

    LABELS
        .iter()
        .find(|(name, _)| *name == field_name)
        .map_or("拼图锚点", |&(_, label)| label)
}
/// Maps a progress percentage to the agent stage.
///
/// Progress below 85 yields `CollectingAnchors`; 85 or more yields `DraftReady`.
fn resolve_puzzle_agent_stage(progress_percent: u32) -> PuzzleAgentStage {
    const DRAFT_READY_THRESHOLD: u32 = 85;
    if progress_percent < DRAFT_READY_THRESHOLD {
        PuzzleAgentStage::CollectingAnchors
    } else {
        PuzzleAgentStage::DraftReady
    }
}
/// Maps a model-supplied anchor status string to `PuzzleAnchorStatus`.
///
/// The value comes straight from LLM output, so matching ignores surrounding whitespace and
/// ASCII case. For example, `"Confirmed"` or `" locked "` are recognised instead of silently
/// degrading to `Missing`. Any unrecognised value maps to `Missing`.
fn parse_anchor_status(value: &str) -> PuzzleAnchorStatus {
    match value.trim().to_ascii_lowercase().as_str() {
        "confirmed" => PuzzleAnchorStatus::Confirmed,
        "locked" => PuzzleAnchorStatus::Locked,
        "inferred" => PuzzleAnchorStatus::Inferred,
        _ => PuzzleAnchorStatus::Missing,
    }
}
// Unit tests for the JSON boundary between the LLM output and the puzzle domain model.
#[cfg(test)]
mod tests {
    use module_puzzle::PuzzleAnchorStatus;
    use serde_json::json;
    use super::parse_model_output;
    use crate::creation_agent_llm_turn::extract_reply_text_from_partial_json;

    // Extracting `replyText` from a truncated JSON prefix (as seen while streaming)
    // must keep multi-byte Chinese characters intact.
    #[test]
    fn extract_reply_text_from_partial_json_preserves_chinese_characters() {
        let partial_json = r#"{"replyText":"夜雨猫咪遗迹","progressPercent":42"#;
        let extracted = extract_reply_text_from_partial_json(partial_json);
        assert_eq!(extracted.as_deref(), Some("夜雨猫咪遗迹"));
    }

    // A fully populated camelCase `nextAnchorPack` must map onto the snake_case domain
    // fields, including status parsing and preservation of value text.
    #[test]
    fn parse_model_output_accepts_camel_case_anchor_pack_contract() {
        let model_output = json!({
            "replyText": "我先把雨夜猫咪的方向收住。",
            "progressPercent": 46,
            "nextAnchorPack": {
                "themePromise": {
                    "key": "themePromise",
                    "label": "题材承诺",
                    "value": "雨夜中的奇幻探索",
                    "status": "confirmed"
                },
                "visualSubject": {
                    "key": "visualSubject",
                    "label": "画面主体",
                    "value": "发光猫咪站在遗迹台阶上",
                    "status": "confirmed"
                },
                "visualMood": {
                    "key": "visualMood",
                    "label": "视觉气质",
                    "value": "潮湿、梦幻、带轻微悬疑",
                    "status": "inferred"
                },
                "compositionHooks": {
                    "key": "compositionHooks",
                    "label": "拼图记忆点",
                    "value": "台阶透视、倒影、远处遗迹门洞",
                    "status": "inferred"
                },
                "tagsAndForbidden": {
                    "key": "tagsAndForbidden",
                    "label": "标签与禁忌",
                    "value": "雨夜、猫咪、神庙遗迹;禁止文字水印",
                    "status": "inferred"
                }
            }
        });
        let parsed = parse_model_output(&model_output).expect("camelCase 契约应能解析");
        assert_eq!(parsed.progress_percent, 46);
        assert_eq!(
            parsed.next_anchor_pack.theme_promise.value,
            "雨夜中的奇幻探索"
        );
        assert_eq!(
            parsed.next_anchor_pack.theme_promise.status,
            PuzzleAnchorStatus::Confirmed
        );
        assert_eq!(
            parsed.next_anchor_pack.tags_and_forbidden.value,
            "雨夜、猫咪、神庙遗迹;禁止文字水印"
        );
    }
}