321 lines
11 KiB
Rust
321 lines
11 KiB
Rust
use module_puzzle::{PuzzleAgentStage, PuzzleAnchorPack, PuzzleAnchorStatus, empty_anchor_pack};
|
||
use platform_llm::LlmClient;
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json::Value as JsonValue;
|
||
use spacetime_client::{PuzzleAgentMessageFinalizeRecordInput, PuzzleAgentSessionRecord};
|
||
|
||
use crate::creation_agent_llm_turn::{
|
||
CreationAgentLlmTurnErrorMessages, stream_creation_agent_json_turn,
|
||
};
|
||
use crate::prompt::puzzle::agent_chat::{
|
||
PUZZLE_AGENT_JSON_TURN_USER_PROMPT, PUZZLE_AGENT_SYSTEM_PROMPT, build_puzzle_agent_prompt,
|
||
serialize_puzzle_record_anchor_pack,
|
||
};
|
||
|
||
#[derive(Clone, Debug)]
|
||
pub(crate) struct PuzzleAgentTurnRequest<'a> {
|
||
pub llm_client: Option<&'a LlmClient>,
|
||
pub session: &'a PuzzleAgentSessionRecord,
|
||
pub quick_fill_requested: bool,
|
||
pub enable_web_search: bool,
|
||
}
|
||
|
||
/// Outcome of a puzzle-agent chat turn, ready to be turned into a finalize record.
#[derive(Clone, Debug)]
pub(crate) struct PuzzleAgentTurnResult {
    /// Reply text shown to the user for this turn.
    pub assistant_reply_text: String,
    /// Stage name in its string form.
    pub stage: String,
    /// Progress reported for the session after this turn.
    pub progress_percent: u32,
    /// Anchor pack serialized as a JSON string.
    pub anchor_pack_json: String,
    /// Optional error text carried alongside the result.
    pub error_message: Option<String>,
}
|
||
|
||
/// Error raised while running or parsing a puzzle-agent turn.
///
/// It wraps a single human-readable message, which is also its `Display` output.
#[derive(Clone, Debug)]
pub(crate) struct PuzzleAgentTurnError {
    /// Text rendered by the `Display` implementation.
    message: String,
}

impl PuzzleAgentTurnError {
    /// Builds an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for PuzzleAgentTurnError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message.as_str())
    }
}

// No `source()` override: the error carries only its message.
impl std::error::Error for PuzzleAgentTurnError {}
|
||
|
||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||
#[serde(rename_all = "camelCase")]
|
||
struct PuzzleAgentModelOutput {
|
||
reply_text: String,
|
||
progress_percent: u32,
|
||
next_anchor_pack: PuzzleAnchorPack,
|
||
}
|
||
|
||
pub(crate) async fn run_puzzle_agent_turn<F>(
|
||
request: PuzzleAgentTurnRequest<'_>,
|
||
on_reply_update: F,
|
||
) -> Result<PuzzleAgentTurnResult, PuzzleAgentTurnError>
|
||
where
|
||
F: FnMut(&str),
|
||
{
|
||
let prompt = build_puzzle_agent_prompt(request.session, request.quick_fill_requested);
|
||
let turn_output = stream_creation_agent_json_turn(
|
||
request.llm_client,
|
||
format!("{PUZZLE_AGENT_SYSTEM_PROMPT}\n\n{prompt}"),
|
||
PUZZLE_AGENT_JSON_TURN_USER_PROMPT,
|
||
request.enable_web_search,
|
||
CreationAgentLlmTurnErrorMessages {
|
||
model_unavailable: "当前模型不可用,请稍后重试。",
|
||
generation_failed: "拼图聊天生成失败,请稍后重试。",
|
||
parse_failed: "拼图聊天结果解析失败,请稍后重试。",
|
||
},
|
||
on_reply_update,
|
||
PuzzleAgentTurnError::new,
|
||
)
|
||
.await?;
|
||
let output = parse_model_output(&turn_output.parsed)?;
|
||
|
||
Ok(PuzzleAgentTurnResult {
|
||
assistant_reply_text: output.reply_text,
|
||
stage: resolve_puzzle_agent_stage(output.progress_percent)
|
||
.as_str()
|
||
.to_string(),
|
||
progress_percent: if request.quick_fill_requested {
|
||
100
|
||
} else {
|
||
output.progress_percent
|
||
},
|
||
anchor_pack_json: serde_json::to_string(&output.next_anchor_pack).unwrap_or_else(|_| {
|
||
serde_json::to_string(&empty_anchor_pack()).unwrap_or_else(|_| "{}".to_string())
|
||
}),
|
||
error_message: None,
|
||
})
|
||
}
|
||
|
||
pub(crate) fn build_finalize_record_input(
|
||
session_id: String,
|
||
owner_user_id: String,
|
||
assistant_message_id: String,
|
||
result: PuzzleAgentTurnResult,
|
||
updated_at_micros: i64,
|
||
) -> PuzzleAgentMessageFinalizeRecordInput {
|
||
PuzzleAgentMessageFinalizeRecordInput {
|
||
session_id,
|
||
owner_user_id,
|
||
assistant_message_id: Some(assistant_message_id),
|
||
assistant_reply_text: Some(result.assistant_reply_text),
|
||
stage: result.stage,
|
||
progress_percent: result.progress_percent,
|
||
anchor_pack_json: result.anchor_pack_json,
|
||
error_message: result.error_message,
|
||
updated_at_micros,
|
||
}
|
||
}
|
||
|
||
pub(crate) fn build_failed_finalize_record_input(
|
||
session_id: String,
|
||
owner_user_id: String,
|
||
session: &PuzzleAgentSessionRecord,
|
||
error_message: String,
|
||
updated_at_micros: i64,
|
||
) -> PuzzleAgentMessageFinalizeRecordInput {
|
||
PuzzleAgentMessageFinalizeRecordInput {
|
||
session_id,
|
||
owner_user_id,
|
||
assistant_message_id: None,
|
||
assistant_reply_text: None,
|
||
stage: session.stage.clone(),
|
||
progress_percent: session.progress_percent,
|
||
anchor_pack_json: serialize_puzzle_record_anchor_pack(&session.anchor_pack),
|
||
error_message: Some(error_message),
|
||
updated_at_micros,
|
||
}
|
||
}
|
||
|
||
fn parse_model_output(parsed: &JsonValue) -> Result<PuzzleAgentModelOutput, PuzzleAgentTurnError> {
|
||
let reply_text = parsed
|
||
.get("replyText")
|
||
.and_then(JsonValue::as_str)
|
||
.map(str::trim)
|
||
.filter(|value| !value.is_empty())
|
||
.ok_or_else(|| PuzzleAgentTurnError::new("拼图聊天结果缺少有效回复,请稍后重试。"))?
|
||
.to_string();
|
||
let progress_percent = parsed
|
||
.get("progressPercent")
|
||
.and_then(JsonValue::as_u64)
|
||
.map(|value| value.min(100) as u32)
|
||
.unwrap_or(0);
|
||
let next_anchor_pack_value = parsed
|
||
.get("nextAnchorPack")
|
||
.cloned()
|
||
.ok_or_else(|| PuzzleAgentTurnError::new("拼图聊天结果缺少 nextAnchorPack。"))?;
|
||
let next_anchor_pack = parse_model_anchor_pack(&next_anchor_pack_value)?;
|
||
Ok(PuzzleAgentModelOutput {
|
||
reply_text,
|
||
progress_percent,
|
||
next_anchor_pack,
|
||
})
|
||
}
|
||
|
||
fn parse_model_anchor_pack(value: &JsonValue) -> Result<PuzzleAnchorPack, PuzzleAgentTurnError> {
|
||
Ok(PuzzleAnchorPack {
|
||
// LLM 输出契约面向前端与 prompt,使用 camelCase;Rust 领域模型仍保持 snake_case,
|
||
// 因此这里显式做边界翻译,避免把 JSON 命名差异扩散到领域 crate。
|
||
theme_promise: parse_model_anchor_item(value, "themePromise")?,
|
||
visual_subject: parse_model_anchor_item(value, "visualSubject")?,
|
||
visual_mood: parse_model_anchor_item(value, "visualMood")?,
|
||
composition_hooks: parse_model_anchor_item(value, "compositionHooks")?,
|
||
tags_and_forbidden: parse_model_anchor_item(value, "tagsAndForbidden")?,
|
||
})
|
||
}
|
||
|
||
fn parse_model_anchor_item(
|
||
pack: &JsonValue,
|
||
field_name: &str,
|
||
) -> Result<module_puzzle::PuzzleAnchorItem, PuzzleAgentTurnError> {
|
||
let value = pack.get(field_name).ok_or_else(|| {
|
||
PuzzleAgentTurnError::new(format!("拼图 anchor pack 缺少 {field_name}。"))
|
||
})?;
|
||
let key = value
|
||
.get("key")
|
||
.and_then(JsonValue::as_str)
|
||
.map(str::trim)
|
||
.filter(|text| !text.is_empty())
|
||
.unwrap_or(field_name)
|
||
.to_string();
|
||
let label = value
|
||
.get("label")
|
||
.and_then(JsonValue::as_str)
|
||
.map(str::trim)
|
||
.filter(|text| !text.is_empty())
|
||
.unwrap_or_else(|| default_puzzle_anchor_label(field_name))
|
||
.to_string();
|
||
let item_value = value
|
||
.get("value")
|
||
.and_then(JsonValue::as_str)
|
||
.map(str::trim)
|
||
.unwrap_or_default()
|
||
.to_string();
|
||
let status = value
|
||
.get("status")
|
||
.and_then(JsonValue::as_str)
|
||
.map(parse_anchor_status)
|
||
.unwrap_or(PuzzleAnchorStatus::Missing);
|
||
|
||
Ok(module_puzzle::PuzzleAnchorItem {
|
||
key,
|
||
label,
|
||
value: item_value,
|
||
status,
|
||
})
|
||
}
|
||
|
||
/// Returns the built-in display label for a camelCase anchor field name.
///
/// Unrecognized names get the generic anchor label.
fn default_puzzle_anchor_label(field_name: &str) -> &'static str {
    /// Label used when `field_name` is not one of the known anchors.
    const GENERIC_LABEL: &str = "拼图锚点";
    /// Known anchor field names paired with their display labels.
    const KNOWN_LABELS: [(&str, &str); 5] = [
        ("themePromise", "题材承诺"),
        ("visualSubject", "画面主体"),
        ("visualMood", "视觉气质"),
        ("compositionHooks", "拼图记忆点"),
        ("tagsAndForbidden", "标签与禁忌"),
    ];

    KNOWN_LABELS
        .iter()
        .copied()
        .find(|&(name, _)| name == field_name)
        .map_or(GENERIC_LABEL, |(_, label)| label)
}
|
||
|
||
fn resolve_puzzle_agent_stage(progress_percent: u32) -> PuzzleAgentStage {
|
||
if progress_percent >= 85 {
|
||
PuzzleAgentStage::DraftReady
|
||
} else {
|
||
PuzzleAgentStage::CollectingAnchors
|
||
}
|
||
}
|
||
|
||
fn parse_anchor_status(value: &str) -> PuzzleAnchorStatus {
|
||
match value {
|
||
"confirmed" => PuzzleAnchorStatus::Confirmed,
|
||
"locked" => PuzzleAnchorStatus::Locked,
|
||
"inferred" => PuzzleAnchorStatus::Inferred,
|
||
_ => PuzzleAnchorStatus::Missing,
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use module_puzzle::PuzzleAnchorStatus;
    use serde_json::json;

    use super::parse_model_output;
    use crate::creation_agent_llm_turn::extract_reply_text_from_partial_json;

    /// A truncated JSON stream must still yield the reply text with multi-byte
    /// characters intact.
    #[test]
    fn extract_reply_text_from_partial_json_preserves_chinese_characters() {
        let truncated = r#"{"replyText":"夜雨猫咪遗迹","progressPercent":42"#;

        let reply = extract_reply_text_from_partial_json(truncated);

        assert_eq!(reply.as_deref(), Some("夜雨猫咪遗迹"));
    }

    /// The camelCase contract emitted by the model maps onto the snake_case domain pack.
    #[test]
    fn parse_model_output_accepts_camel_case_anchor_pack_contract() {
        let payload = json!({
            "replyText": "我先把雨夜猫咪的方向收住。",
            "progressPercent": 46,
            "nextAnchorPack": {
                "themePromise": {
                    "key": "themePromise",
                    "label": "题材承诺",
                    "value": "雨夜中的奇幻探索",
                    "status": "confirmed"
                },
                "visualSubject": {
                    "key": "visualSubject",
                    "label": "画面主体",
                    "value": "发光猫咪站在遗迹台阶上",
                    "status": "confirmed"
                },
                "visualMood": {
                    "key": "visualMood",
                    "label": "视觉气质",
                    "value": "潮湿、梦幻、带轻微悬疑",
                    "status": "inferred"
                },
                "compositionHooks": {
                    "key": "compositionHooks",
                    "label": "拼图记忆点",
                    "value": "台阶透视、倒影、远处遗迹门洞",
                    "status": "inferred"
                },
                "tagsAndForbidden": {
                    "key": "tagsAndForbidden",
                    "label": "标签与禁忌",
                    "value": "雨夜、猫咪、神庙遗迹;禁止文字水印",
                    "status": "inferred"
                }
            }
        });

        let output = parse_model_output(&payload).expect("camelCase 契约应能解析");
        let pack = &output.next_anchor_pack;

        assert_eq!(output.progress_percent, 46);
        assert_eq!(pack.theme_promise.value, "雨夜中的奇幻探索");
        assert_eq!(pack.theme_promise.status, PuzzleAnchorStatus::Confirmed);
        assert_eq!(
            pack.tags_and_forbidden.value,
            "雨夜、猫咪、神庙遗迹;禁止文字水印"
        );
    }
}
|