Files
Genarrative/server-rs/crates/api-server/src/puzzle_agent_turn.rs
2026-04-27 14:23:19 +08:00

511 lines
18 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use module_puzzle::{PuzzleAgentStage, PuzzleAnchorPack, PuzzleAnchorStatus, empty_anchor_pack};
use platform_llm::LlmClient;
use serde::{Deserialize, Serialize};
use serde_json::{Value as JsonValue, json};
use spacetime_client::{
PuzzleAgentMessageFinalizeRecordInput, PuzzleAgentMessageRecord, PuzzleAgentSessionRecord,
};
use crate::creation_agent_anchor_templates::{
get_creation_agent_anchor_template, render_anchor_question_block,
};
use crate::creation_agent_chat::render_quick_fill_extra_rules;
use crate::creation_agent_llm_turn::{
CreationAgentLlmTurnErrorMessages, stream_creation_agent_json_turn,
};
/// Inputs for one puzzle-agent chat turn.
#[derive(Clone, Debug)]
pub(crate) struct PuzzleAgentTurnRequest<'a> {
/// LLM client forwarded to `stream_creation_agent_json_turn`; may be absent.
pub llm_client: Option<&'a LlmClient>,
/// Session record the prompt is built from (turn counter, progress, anchor pack, messages).
pub session: &'a PuzzleAgentSessionRecord,
/// When true, quick-fill rules are added to the prompt and the reported
/// progress of the turn is forced to 100.
pub quick_fill_requested: bool,
/// Forwarded unchanged to `stream_creation_agent_json_turn`.
pub enable_web_search: bool,
}
/// Outcome of a puzzle-agent turn, consumed by `build_finalize_record_input`.
#[derive(Clone, Debug)]
pub(crate) struct PuzzleAgentTurnResult {
/// Reply text taken from the model's `replyText` field (trimmed, non-empty).
pub assistant_reply_text: String,
/// String form (`as_str()`) of the `PuzzleAgentStage` chosen for this turn.
pub stage: String,
/// Progress reported for this turn.
pub progress_percent: u32,
/// JSON serialization of the anchor pack to persist after this turn.
pub anchor_pack_json: String,
/// Error text to persist with the turn; `run_puzzle_agent_turn` always sets `None`.
pub error_message: Option<String>,
}
/// Error raised while running a puzzle-agent turn.
///
/// It carries a single human-readable message, which is also its `Display`
/// output.
#[derive(Clone, Debug)]
pub(crate) struct PuzzleAgentTurnError {
    message: String,
}

impl PuzzleAgentTurnError {
    /// Builds an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for PuzzleAgentTurnError {
    /// Writes the stored message verbatim; width/alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for PuzzleAgentTurnError {}
/// Validated model output for one turn.
///
/// The serde derive renames fields to camelCase (`replyText`,
/// `progressPercent`, `nextAnchorPack`). Within this file the value is built
/// field by field in `parse_model_output` rather than through `Deserialize`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PuzzleAgentModelOutput {
/// Reply shown to the user.
reply_text: String,
/// Progress reported by the model; `parse_model_output` caps it at 100.
progress_percent: u32,
/// Complete anchor pack for the next round (not a patch).
next_anchor_pack: PuzzleAnchorPack,
}
// System prompt placed in front of the per-turn prompt by `run_puzzle_agent_turn`.
// It casts the model as a Chinese-speaking creative planner who co-creates a
// puzzle image with the user, requires `replyText`, `progressPercent` and a
// complete `nextAnchorPack` in every answer, and lists the hard constraints
// (JSON only, no internal vocabulary in the reply, one key question per turn,
// progress within 0..=100, status limited to missing/inferred/confirmed/locked).
// The text is sent to the model verbatim, so it must not be edited casually.
const PUZZLE_AGENT_SYSTEM_PROMPT: &str = r#"你是一个负责和创作者共创拼图画面的中文创意策划。
你要帮助用户把一句灵感逐步收束成可以发布成拼图关卡的视觉方案。
你必须同时输出:
1. 一段直接发给用户的中文回复 replyText
2. 当前进度 progressPercent
3. 下一轮完整可用的 nextAnchorPack
硬约束:
1. 只能输出 JSON不能输出代码块或解释
2. nextAnchorPack 必须是完整对象,不能只输出 patch
3. replyText 必须是自然中文不能提“字段”“锚点”“结构”“JSON”等内部词
4. replyText 一次最多推进一个最关键问题
5. 如果用户已经给出明确方向,就优先吸收和收束,不要机械反问
6. progressPercent 范围只能是 0 到 100
7. status 只能使用 missing / inferred / confirmed / locked
"#;
// JSON skeleton appended at the end of the prompt by `build_puzzle_agent_prompt`.
// It shows the model the exact camelCase shape expected back: `replyText`,
// `progressPercent`, and a `nextAnchorPack` with the five anchors
// (themePromise, visualSubject, visualMood, compositionHooks,
// tagsAndForbidden), each carrying key/label/value/status.
// `parse_model_output` reads these same key names, so the two must stay in sync.
const PUZZLE_AGENT_OUTPUT_CONTRACT: &str = r#"请严格按以下 JSON 输出,不要输出其他文字:
{
"replyText": "",
"progressPercent": 0,
"nextAnchorPack": {
"themePromise": {
"key": "themePromise",
"label": "题材承诺",
"value": "",
"status": "missing"
},
"visualSubject": {
"key": "visualSubject",
"label": "画面主体",
"value": "",
"status": "missing"
},
"visualMood": {
"key": "visualMood",
"label": "视觉气质",
"value": "",
"status": "missing"
},
"compositionHooks": {
"key": "compositionHooks",
"label": "拼图记忆点",
"value": "",
"status": "missing"
},
"tagsAndForbidden": {
"key": "tagsAndForbidden",
"label": "标签与禁忌",
"value": "",
"status": "missing"
}
}
}"#;
/// Runs one puzzle-agent chat turn against the LLM and assembles the result
/// to persist.
///
/// The prompt is built from the session, streamed through
/// `stream_creation_agent_json_turn` (with `on_reply_update` passed along as
/// the streaming callback), and the parsed JSON is validated by
/// `parse_model_output`.
///
/// When `quick_fill_requested` is set, the reported progress is forced to 100.
/// The stage is derived from that same effective progress, so the stored
/// stage and progress can never disagree (for example progress 100 with a
/// "collecting" stage).
///
/// # Errors
///
/// Returns `PuzzleAgentTurnError` when the streaming helper fails or when the
/// model output lacks a usable `replyText` / `nextAnchorPack`.
pub(crate) async fn run_puzzle_agent_turn<F>(
    request: PuzzleAgentTurnRequest<'_>,
    on_reply_update: F,
) -> Result<PuzzleAgentTurnResult, PuzzleAgentTurnError>
where
    F: FnMut(&str),
{
    let prompt = build_puzzle_agent_prompt(request.session, request.quick_fill_requested);
    let turn_output = stream_creation_agent_json_turn(
        request.llm_client,
        format!("{PUZZLE_AGENT_SYSTEM_PROMPT}\n\n{prompt}"),
        "请按约定输出这一轮的 JSON。",
        request.enable_web_search,
        CreationAgentLlmTurnErrorMessages {
            model_unavailable: "当前模型不可用,请稍后重试。",
            generation_failed: "拼图聊天生成失败,请稍后重试。",
            parse_failed: "拼图聊天结果解析失败,请稍后重试。",
        },
        on_reply_update,
        PuzzleAgentTurnError::new,
    )
    .await?;
    let output = parse_model_output(&turn_output.parsed)?;

    // Quick fill always completes the pack, whatever number the model reported.
    let progress_percent = if request.quick_fill_requested {
        100
    } else {
        output.progress_percent
    };
    // Resolve the stage from the effective progress, not the raw model value.
    let stage = resolve_puzzle_agent_stage(progress_percent)
        .as_str()
        .to_string();
    // Fall back to an empty pack, then to "{}", if serialization ever fails.
    let anchor_pack_json = serde_json::to_string(&output.next_anchor_pack).unwrap_or_else(|_| {
        serde_json::to_string(&empty_anchor_pack()).unwrap_or_else(|_| "{}".to_string())
    });

    Ok(PuzzleAgentTurnResult {
        assistant_reply_text: output.reply_text,
        stage,
        progress_percent,
        anchor_pack_json,
        error_message: None,
    })
}
pub(crate) fn build_finalize_record_input(
session_id: String,
owner_user_id: String,
assistant_message_id: String,
result: PuzzleAgentTurnResult,
updated_at_micros: i64,
) -> PuzzleAgentMessageFinalizeRecordInput {
PuzzleAgentMessageFinalizeRecordInput {
session_id,
owner_user_id,
assistant_message_id: Some(assistant_message_id),
assistant_reply_text: Some(result.assistant_reply_text),
stage: result.stage,
progress_percent: result.progress_percent,
anchor_pack_json: result.anchor_pack_json,
error_message: result.error_message,
updated_at_micros,
}
}
/// Builds the finalize input for a failed turn.
///
/// No assistant message is recorded. Stage, progress and anchor pack are
/// copied from the current `session` so a failure leaves the session state
/// unchanged, and `error_message` is stored alongside.
pub(crate) fn build_failed_finalize_record_input(
    session_id: String,
    owner_user_id: String,
    session: &PuzzleAgentSessionRecord,
    error_message: String,
    updated_at_micros: i64,
) -> PuzzleAgentMessageFinalizeRecordInput {
    let current_pack = map_record_anchor_pack(&session.anchor_pack);
    let anchor_pack_json = match serde_json::to_string(&current_pack) {
        Ok(serialized) => serialized,
        // Fall back to an empty pack, then to a bare object, on failure.
        Err(_) => match serde_json::to_string(&empty_anchor_pack()) {
            Ok(serialized_empty) => serialized_empty,
            Err(_) => "{}".to_string(),
        },
    };

    PuzzleAgentMessageFinalizeRecordInput {
        session_id,
        owner_user_id,
        assistant_message_id: None,
        assistant_reply_text: None,
        stage: session.stage.clone(),
        progress_percent: session.progress_percent,
        anchor_pack_json,
        error_message: Some(error_message),
        updated_at_micros,
    }
}
/// Builds the per-turn prompt that follows the system prompt.
///
/// The prompt consists of, in order: the anchor question block for the
/// "puzzle" template (or a fixed fallback sentence when no template is
/// found), the optional quick-fill rules, the upcoming turn number and current
/// progress, an explicit yes/no line stating whether quick fill was requested,
/// the current anchor pack as pretty JSON, the chat history as pretty JSON,
/// and the output contract.
fn build_puzzle_agent_prompt(
    session: &PuzzleAgentSessionRecord,
    quick_fill_requested: bool,
) -> String {
    let anchor_question_block = get_creation_agent_anchor_template("puzzle")
        .map(render_anchor_question_block)
        .unwrap_or_else(|| "模板目标:收束成可以发布为拼图关卡的视觉方案。".to_string());
    let quick_fill_rules = if quick_fill_requested {
        format!(
            "\n\n{}",
            render_quick_fill_extra_rules(
                "当前题材方向里的拼图关键词",
                "不要要求用户再提供素材、风格或禁忌",
                "输出完整 nextAnchorPack直接补齐 value 为空或 status 为 missing 的项",
                "生成结果页",
            )
        )
    } else {
        String::new()
    };
    // Both branches used to be empty strings, so the "是否…" line never carried
    // an answer; state the flag explicitly.
    let quick_fill_requested_text = if quick_fill_requested { "是" } else { "否" };
    let anchor_pack = serde_json::to_string_pretty(&map_record_anchor_pack(&session.anchor_pack))
        .unwrap_or_else(|_| "{}".to_string());
    let chat_history =
        serde_json::to_string_pretty(&build_chat_history(session.messages.as_slice()))
            .unwrap_or_else(|_| "[]".to_string());
    format!(
        "{anchor_question_block}{quick_fill_rules}\n\n当前是第 {turn} 轮,当前进度 {progress}% 。\n\n是否要求自动补充剩余关键字:{quick_fill_requested_text}\n\n当前 anchor pack\n{anchor_pack}\n\n最近聊天记录:\n{chat_history}\n\n{contract}",
        anchor_question_block = anchor_question_block,
        quick_fill_rules = quick_fill_rules,
        // The prompt describes the turn about to be produced, hence +1.
        turn = session.current_turn.saturating_add(1),
        progress = session.progress_percent,
        quick_fill_requested_text = quick_fill_requested_text,
        anchor_pack = anchor_pack,
        chat_history = chat_history,
        contract = PUZZLE_AGENT_OUTPUT_CONTRACT,
    )
}
/// Projects stored chat messages into the JSON shape embedded in the prompt.
///
/// Each message becomes an object with `role`, `kind` and `content` (the
/// message text); the original order is kept.
fn build_chat_history(messages: &[PuzzleAgentMessageRecord]) -> Vec<JsonValue> {
    let mut history = Vec::with_capacity(messages.len());
    for entry in messages {
        history.push(json!({
            "role": entry.role,
            "kind": entry.kind,
            "content": entry.text,
        }));
    }
    history
}
/// Validates the JSON returned by the model and converts it into
/// `PuzzleAgentModelOutput`.
///
/// * `replyText` must be a string that is non-empty after trimming.
/// * `progressPercent` is read as an unsigned integer and capped at 100; a
///   missing, negative or non-integer value is treated as 0.
/// * `nextAnchorPack` must be present and is parsed by
///   `parse_model_anchor_pack`.
///
/// # Errors
///
/// Returns `PuzzleAgentTurnError` when `replyText` is missing or blank, when
/// `nextAnchorPack` is missing, or when the anchor pack lacks one of its items.
fn parse_model_output(parsed: &JsonValue) -> Result<PuzzleAgentModelOutput, PuzzleAgentTurnError> {
    let reply_text = parsed
        .get("replyText")
        .and_then(JsonValue::as_str)
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .ok_or_else(|| PuzzleAgentTurnError::new("拼图聊天结果缺少有效回复,请稍后重试。"))?
        .to_string();
    let progress_percent = parsed
        .get("progressPercent")
        .and_then(JsonValue::as_u64)
        // Capping before the cast keeps the narrowing conversion lossless.
        .map(|value| value.min(100) as u32)
        .unwrap_or(0);
    // Borrow the subtree directly; the parser only needs a reference, so a
    // deep clone of the JSON value is unnecessary.
    let next_anchor_pack_value = parsed
        .get("nextAnchorPack")
        .ok_or_else(|| PuzzleAgentTurnError::new("拼图聊天结果缺少 nextAnchorPack。"))?;
    let next_anchor_pack = parse_model_anchor_pack(next_anchor_pack_value)?;
    Ok(PuzzleAgentModelOutput {
        reply_text,
        progress_percent,
        next_anchor_pack,
    })
}
/// Parses the model's `nextAnchorPack` object into the domain `PuzzleAnchorPack`.
///
/// # Errors
///
/// Fails with the first missing anchor item, checked in declaration order.
fn parse_model_anchor_pack(value: &JsonValue) -> Result<PuzzleAnchorPack, PuzzleAgentTurnError> {
    // The LLM output contract is shared with the front end and the prompt and
    // uses camelCase keys, while the Rust domain model keeps snake_case fields.
    // The translation is done explicitly at this boundary so the JSON naming
    // difference does not leak into the domain crate.
    let theme_promise = parse_model_anchor_item(value, "themePromise")?;
    let visual_subject = parse_model_anchor_item(value, "visualSubject")?;
    let visual_mood = parse_model_anchor_item(value, "visualMood")?;
    let composition_hooks = parse_model_anchor_item(value, "compositionHooks")?;
    let tags_and_forbidden = parse_model_anchor_item(value, "tagsAndForbidden")?;

    Ok(PuzzleAnchorPack {
        theme_promise,
        visual_subject,
        visual_mood,
        composition_hooks,
        tags_and_forbidden,
    })
}
fn parse_model_anchor_item(
pack: &JsonValue,
field_name: &str,
) -> Result<module_puzzle::PuzzleAnchorItem, PuzzleAgentTurnError> {
let value = pack.get(field_name).ok_or_else(|| {
PuzzleAgentTurnError::new(format!("拼图 anchor pack 缺少 {field_name}"))
})?;
let key = value
.get("key")
.and_then(JsonValue::as_str)
.map(str::trim)
.filter(|text| !text.is_empty())
.unwrap_or(field_name)
.to_string();
let label = value
.get("label")
.and_then(JsonValue::as_str)
.map(str::trim)
.filter(|text| !text.is_empty())
.unwrap_or_else(|| default_puzzle_anchor_label(field_name))
.to_string();
let item_value = value
.get("value")
.and_then(JsonValue::as_str)
.map(str::trim)
.unwrap_or_default()
.to_string();
let status = value
.get("status")
.and_then(JsonValue::as_str)
.map(parse_anchor_status)
.unwrap_or(PuzzleAnchorStatus::Missing);
Ok(module_puzzle::PuzzleAnchorItem {
key,
label,
value: item_value,
status,
})
}
/// Returns the default display label for an anchor field name.
///
/// Matching is exact and case-sensitive; unknown names get the generic label.
fn default_puzzle_anchor_label(field_name: &str) -> &'static str {
    const FALLBACK_LABEL: &str = "拼图锚点";
    const LABELS: [(&str, &str); 5] = [
        ("themePromise", "题材承诺"),
        ("visualSubject", "画面主体"),
        ("visualMood", "视觉气质"),
        ("compositionHooks", "拼图记忆点"),
        ("tagsAndForbidden", "标签与禁忌"),
    ];

    LABELS
        .iter()
        .find(|(name, _)| *name == field_name)
        .map_or(FALLBACK_LABEL, |(_, label)| *label)
}
/// Maps a progress percentage to the session stage: 85 and above means the
/// draft is ready, anything lower means anchors are still being collected.
fn resolve_puzzle_agent_stage(progress_percent: u32) -> PuzzleAgentStage {
    const DRAFT_READY_THRESHOLD: u32 = 85;

    match progress_percent {
        reached if reached >= DRAFT_READY_THRESHOLD => PuzzleAgentStage::DraftReady,
        _ => PuzzleAgentStage::CollectingAnchors,
    }
}
/// Converts a stored anchor pack record into the domain `PuzzleAnchorPack`,
/// item by item.
fn map_record_anchor_pack(record: &spacetime_client::PuzzleAnchorPackRecord) -> PuzzleAnchorPack {
    let theme_promise = map_record_anchor_item(&record.theme_promise);
    let visual_subject = map_record_anchor_item(&record.visual_subject);
    let visual_mood = map_record_anchor_item(&record.visual_mood);
    let composition_hooks = map_record_anchor_item(&record.composition_hooks);
    let tags_and_forbidden = map_record_anchor_item(&record.tags_and_forbidden);

    PuzzleAnchorPack {
        theme_promise,
        visual_subject,
        visual_mood,
        composition_hooks,
        tags_and_forbidden,
    }
}
/// Converts one stored anchor item record into the domain item.
///
/// Text fields are copied; the status string is parsed with
/// `parse_anchor_status`.
fn map_record_anchor_item(
    record: &spacetime_client::PuzzleAnchorItemRecord,
) -> module_puzzle::PuzzleAnchorItem {
    let key = record.key.clone();
    let label = record.label.clone();
    let value = record.value.clone();
    let status = parse_anchor_status(record.status.as_str());

    module_puzzle::PuzzleAnchorItem {
        key,
        label,
        value,
        status,
    }
}
/// Parses an anchor status string.
///
/// Recognizes exactly `"confirmed"`, `"locked"` and `"inferred"`; any other
/// text (including `"missing"`) maps to `PuzzleAnchorStatus::Missing`.
fn parse_anchor_status(value: &str) -> PuzzleAnchorStatus {
    if value == "confirmed" {
        return PuzzleAnchorStatus::Confirmed;
    }
    if value == "locked" {
        return PuzzleAnchorStatus::Locked;
    }
    if value == "inferred" {
        return PuzzleAnchorStatus::Inferred;
    }
    PuzzleAnchorStatus::Missing
}
#[cfg(test)]
mod tests {
    use module_puzzle::PuzzleAnchorStatus;
    use serde_json::json;

    use super::{build_puzzle_agent_prompt, parse_model_output};
    use crate::creation_agent_llm_turn::extract_reply_text_from_partial_json;

    /// Builds one stored anchor item for the fixtures below.
    fn anchor_item_record(
        key: &str,
        label: &str,
        value: &str,
        status: &str,
    ) -> spacetime_client::PuzzleAnchorItemRecord {
        spacetime_client::PuzzleAnchorItemRecord {
            key: key.to_string(),
            label: label.to_string(),
            value: value.to_string(),
            status: status.to_string(),
        }
    }

    /// Session fixture: turn 2, 60% progress, only the theme promise confirmed
    /// and no chat messages.
    fn empty_session_record() -> spacetime_client::PuzzleAgentSessionRecord {
        let anchor_pack = spacetime_client::PuzzleAnchorPackRecord {
            theme_promise: anchor_item_record(
                "themePromise",
                "题材承诺",
                "雨夜猫咪遗迹",
                "confirmed",
            ),
            visual_subject: anchor_item_record("visualSubject", "画面主体", "", "missing"),
            visual_mood: anchor_item_record("visualMood", "视觉气质", "", "missing"),
            composition_hooks: anchor_item_record(
                "compositionHooks",
                "拼图记忆点",
                "",
                "missing",
            ),
            tags_and_forbidden: anchor_item_record(
                "tagsAndForbidden",
                "标签与禁忌",
                "",
                "missing",
            ),
        };

        spacetime_client::PuzzleAgentSessionRecord {
            session_id: "puzzle-session-test".to_string(),
            current_turn: 2,
            progress_percent: 60,
            stage: "collecting_anchors".to_string(),
            anchor_pack,
            draft: None,
            messages: Vec::new(),
            last_assistant_reply: None,
            published_profile_id: None,
            suggested_actions: Vec::new(),
            result_preview: None,
            updated_at: "2026-04-24T10:00:00.000Z".to_string(),
        }
    }

    /// Streaming extraction must return multi-byte Chinese text intact even
    /// when the JSON document is still incomplete.
    #[test]
    fn extract_reply_text_from_partial_json_preserves_chinese_characters() {
        let streamed_so_far = r#"{"replyText":"夜雨猫咪遗迹","progressPercent":42"#;

        let reply = extract_reply_text_from_partial_json(streamed_so_far);

        assert_eq!(reply.as_deref(), Some("夜雨猫咪遗迹"));
    }

    /// The parser accepts the camelCase contract and maps it onto the
    /// snake_case domain model.
    #[test]
    fn parse_model_output_accepts_camel_case_anchor_pack_contract() {
        let model_output = json!({
            "replyText": "我先把雨夜猫咪的方向收住。",
            "progressPercent": 46,
            "nextAnchorPack": {
                "themePromise": {
                    "key": "themePromise",
                    "label": "题材承诺",
                    "value": "雨夜中的奇幻探索",
                    "status": "confirmed"
                },
                "visualSubject": {
                    "key": "visualSubject",
                    "label": "画面主体",
                    "value": "发光猫咪站在遗迹台阶上",
                    "status": "confirmed"
                },
                "visualMood": {
                    "key": "visualMood",
                    "label": "视觉气质",
                    "value": "潮湿、梦幻、带轻微悬疑",
                    "status": "inferred"
                },
                "compositionHooks": {
                    "key": "compositionHooks",
                    "label": "拼图记忆点",
                    "value": "台阶透视、倒影、远处遗迹门洞",
                    "status": "inferred"
                },
                "tagsAndForbidden": {
                    "key": "tagsAndForbidden",
                    "label": "标签与禁忌",
                    "value": "雨夜、猫咪、神庙遗迹;禁止文字水印",
                    "status": "inferred"
                }
            }
        });

        let parsed = parse_model_output(&model_output).expect("camelCase 契约应能解析");
        let pack = &parsed.next_anchor_pack;

        assert_eq!(parsed.progress_percent, 46);
        assert_eq!(pack.theme_promise.value, "雨夜中的奇幻探索");
        assert_eq!(pack.theme_promise.status, PuzzleAnchorStatus::Confirmed);
        assert_eq!(
            pack.tags_and_forbidden.value,
            "雨夜、猫咪、神庙遗迹;禁止文字水印"
        );
    }

    /// With quick fill requested the prompt must carry the extra rules that
    /// stop the model from asking further questions.
    #[test]
    fn quick_fill_prompt_forbids_follow_up_questions() {
        let prompt = build_puzzle_agent_prompt(&empty_session_record(), true);

        let required_fragments = [
            "用户刚刚主动要求你自动补充剩余关键字",
            "不要再继续提问",
            "progressPercent 直接输出为 100",
        ];
        for fragment in required_fragments {
            assert!(prompt.contains(fragment));
        }
    }
}