570 lines
20 KiB
Rust
570 lines
20 KiB
Rust
use module_square_hole::{
|
|
SQUARE_HOLE_MAX_DIFFICULTY, SQUARE_HOLE_MAX_SHAPE_COUNT, SQUARE_HOLE_MIN_DIFFICULTY,
|
|
SQUARE_HOLE_MIN_SHAPE_COUNT, SquareHoleHoleOption, SquareHoleShapeOption,
|
|
default_background_prompt, normalize_hole_options, normalize_shape_options,
|
|
};
|
|
use platform_llm::LlmClient;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value as JsonValue;
|
|
use spacetime_client::{SquareHoleAgentMessageFinalizeRecordInput, SquareHoleAgentSessionRecord};
|
|
|
|
use crate::creation_agent_llm_turn::{
|
|
CreationAgentLlmTurnErrorMessages, stream_creation_agent_json_turn,
|
|
};
|
|
use crate::prompt::square_hole::{
|
|
SQUARE_HOLE_AGENT_JSON_TURN_USER_PROMPT, SQUARE_HOLE_AGENT_SYSTEM_PROMPT,
|
|
build_square_hole_agent_prompt,
|
|
};
|
|
|
|
/// Inputs for one square-hole creation-agent chat turn.
#[derive(Clone, Debug)]
pub(crate) struct SquareHoleAgentTurnRequest<'a> {
    /// LLM client forwarded unchanged to `stream_creation_agent_json_turn`;
    /// `None` presumably yields the "model unavailable" error — confirm in that helper.
    pub llm_client: Option<&'a LlmClient>,
    /// Session whose state is rendered into the prompt and whose config supplies
    /// fallbacks for any field the model omits.
    pub session: &'a SquareHoleAgentSessionRecord,
    /// Passed to the prompt builder; when true the turn result is pinned to 100% progress.
    pub quick_fill_requested: bool,
    /// Forwarded to the LLM turn helper to enable or disable web search.
    pub enable_web_search: bool,
}
|
|
|
|
/// Outcome of a successful square-hole agent turn, ready to be persisted.
#[derive(Clone, Debug)]
pub(crate) struct SquareHoleAgentTurnResult {
    /// Trimmed, non-empty reply text produced by the model.
    pub assistant_reply_text: String,
    /// `"ReadyToCompile"` when progress reaches 100, otherwise `"Collecting"`.
    pub stage: String,
    /// Progress in `0..=100`; forced to 100 for quick-fill requests.
    pub progress_percent: u32,
    /// JSON serialization of the normalized next config (camelCase keys).
    pub config_json: String,
    /// Error text to persist with the message; `run_square_hole_agent_turn` sets it to `None`.
    pub error_message: Option<String>,
}
|
|
|
|
/// Error raised by a square-hole agent turn.
///
/// It carries only a human-readable message, which is also its `Display` output.
#[derive(Clone, Debug)]
pub(crate) struct SquareHoleAgentTurnError {
    message: String,
}

impl SquareHoleAgentTurnError {
    /// Wraps any string-like message into an error value.
    fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self { message }
    }
}

impl std::fmt::Display for SquareHoleAgentTurnError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message.as_str())
    }
}

impl std::error::Error for SquareHoleAgentTurnError {}
|
|
|
|
/// Validated and normalized form of the model's JSON reply for one turn.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SquareHoleAgentModelOutput {
    /// Trimmed, non-empty `replyText` from the model.
    reply_text: String,
    /// `progressPercent` capped at 100, or the session's progress when absent.
    progress_percent: u32,
    /// Normalized `nextConfig`, with session fallbacks applied.
    next_config: SquareHoleAgentConfigOutput,
}
|
|
|
|
/// Creator config produced by a turn; serialized (camelCase) into `config_json`.
///
/// Field order determines the key order in the serialized JSON.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SquareHoleAgentConfigOutput {
    theme_text: String,
    twist_rule: String,
    /// Clamped to `SQUARE_HOLE_MIN_SHAPE_COUNT..=SQUARE_HOLE_MAX_SHAPE_COUNT`.
    shape_count: u32,
    /// Clamped to `SQUARE_HOLE_MIN_DIFFICULTY..=SQUARE_HOLE_MAX_DIFFICULTY`.
    difficulty: u32,
    shape_options: Vec<SquareHoleAgentShapeOptionOutput>,
    hole_options: Vec<SquareHoleAgentHoleOptionOutput>,
    background_prompt: String,
    /// Copied from the session config (empty when unset); a missing key deserializes to "".
    #[serde(default)]
    cover_image_src: String,
    /// Copied from the session config (empty when unset); a missing key deserializes to "".
    #[serde(default)]
    background_image_src: String,
}
|
|
|
|
/// Serializable shape option inside [`SquareHoleAgentConfigOutput`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SquareHoleAgentShapeOptionOutput {
    option_id: String,
    shape_kind: String,
    label: String,
    /// Id of the hole this shape is meant to go into.
    target_hole_id: String,
    image_prompt: String,
    /// Empty string when no image has been generated; a missing key deserializes to "".
    #[serde(default)]
    image_src: String,
}
|
|
|
|
/// Serializable hole option inside [`SquareHoleAgentConfigOutput`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SquareHoleAgentHoleOptionOutput {
    hole_id: String,
    hole_kind: String,
    label: String,
    image_prompt: String,
    /// Empty string when no image has been generated; a missing key deserializes to "".
    #[serde(default)]
    image_src: String,
}
|
|
|
|
pub(crate) async fn run_square_hole_agent_turn<F>(
|
|
request: SquareHoleAgentTurnRequest<'_>,
|
|
on_reply_update: F,
|
|
) -> Result<SquareHoleAgentTurnResult, SquareHoleAgentTurnError>
|
|
where
|
|
F: FnMut(&str),
|
|
{
|
|
let prompt = build_square_hole_agent_prompt(request.session, request.quick_fill_requested);
|
|
let turn_output = stream_creation_agent_json_turn(
|
|
request.llm_client,
|
|
format!("{SQUARE_HOLE_AGENT_SYSTEM_PROMPT}\n\n{prompt}"),
|
|
SQUARE_HOLE_AGENT_JSON_TURN_USER_PROMPT,
|
|
request.enable_web_search,
|
|
CreationAgentLlmTurnErrorMessages {
|
|
model_unavailable: "当前模型不可用,请稍后重试。",
|
|
generation_failed: "方洞挑战聊天生成失败,请稍后重试。",
|
|
parse_failed: "方洞挑战聊天结果解析失败,请稍后重试。",
|
|
},
|
|
on_reply_update,
|
|
SquareHoleAgentTurnError::new,
|
|
)
|
|
.await?;
|
|
let output = parse_model_output(&turn_output.parsed, request.session)?;
|
|
let progress_percent = if request.quick_fill_requested {
|
|
100
|
|
} else {
|
|
output.progress_percent.min(100)
|
|
};
|
|
|
|
Ok(SquareHoleAgentTurnResult {
|
|
assistant_reply_text: output.reply_text,
|
|
stage: resolve_stage(progress_percent),
|
|
progress_percent,
|
|
config_json: serde_json::to_string(&output.next_config)
|
|
.map_err(|_| SquareHoleAgentTurnError::new("方洞挑战配置序列化失败。"))?,
|
|
error_message: None,
|
|
})
|
|
}
|
|
|
|
pub(crate) fn build_finalize_record_input(
|
|
session_id: String,
|
|
owner_user_id: String,
|
|
assistant_message_id: String,
|
|
result: SquareHoleAgentTurnResult,
|
|
updated_at_micros: i64,
|
|
) -> SquareHoleAgentMessageFinalizeRecordInput {
|
|
SquareHoleAgentMessageFinalizeRecordInput {
|
|
session_id,
|
|
owner_user_id,
|
|
assistant_message_id: Some(assistant_message_id),
|
|
assistant_reply_text: Some(result.assistant_reply_text),
|
|
config_json: Some(result.config_json),
|
|
progress_percent: result.progress_percent,
|
|
stage: result.stage,
|
|
updated_at_micros,
|
|
error_message: result.error_message,
|
|
}
|
|
}
|
|
|
|
fn parse_model_output(
|
|
parsed: &JsonValue,
|
|
session: &SquareHoleAgentSessionRecord,
|
|
) -> Result<SquareHoleAgentModelOutput, SquareHoleAgentTurnError> {
|
|
let reply_text = parsed
|
|
.get("replyText")
|
|
.and_then(JsonValue::as_str)
|
|
.map(str::trim)
|
|
.filter(|value| !value.is_empty())
|
|
.ok_or_else(|| SquareHoleAgentTurnError::new("方洞挑战聊天结果缺少有效回复,请稍后重试。"))?
|
|
.to_string();
|
|
let progress_percent = parsed
|
|
.get("progressPercent")
|
|
.and_then(JsonValue::as_u64)
|
|
.map(|value| value.min(100) as u32)
|
|
.unwrap_or(session.progress_percent);
|
|
let next_config_value = parsed
|
|
.get("nextConfig")
|
|
.ok_or_else(|| SquareHoleAgentTurnError::new("方洞挑战聊天结果缺少 nextConfig。"))?;
|
|
let next_config = parse_model_config(next_config_value, session)?;
|
|
Ok(SquareHoleAgentModelOutput {
|
|
reply_text,
|
|
progress_percent,
|
|
next_config,
|
|
})
|
|
}
|
|
|
|
fn parse_model_config(
|
|
value: &JsonValue,
|
|
session: &SquareHoleAgentSessionRecord,
|
|
) -> Result<SquareHoleAgentConfigOutput, SquareHoleAgentTurnError> {
|
|
if !value.is_object() {
|
|
return Err(SquareHoleAgentTurnError::new(
|
|
"方洞挑战聊天结果中的 nextConfig 必须是对象。",
|
|
));
|
|
}
|
|
|
|
let theme_text =
|
|
read_text_field(value, "themeText").unwrap_or_else(|| session.config.theme_text.clone());
|
|
let twist_rule =
|
|
read_text_field(value, "twistRule").unwrap_or_else(|| session.config.twist_rule.clone());
|
|
let hole_options = parse_hole_options(value, session, &theme_text);
|
|
let shape_options = parse_shape_options(value, session, &theme_text, hole_options.as_slice());
|
|
let background_prompt = read_text_field(value, "backgroundPrompt")
|
|
.or_else(|| {
|
|
session
|
|
.config
|
|
.background_prompt
|
|
.trim()
|
|
.is_empty()
|
|
.then(|| default_background_prompt(&theme_text))
|
|
})
|
|
.unwrap_or_else(|| session.config.background_prompt.clone());
|
|
|
|
Ok(SquareHoleAgentConfigOutput {
|
|
theme_text,
|
|
twist_rule,
|
|
shape_count: read_u32_field(value, "shapeCount")
|
|
.unwrap_or(session.config.shape_count)
|
|
.clamp(SQUARE_HOLE_MIN_SHAPE_COUNT, SQUARE_HOLE_MAX_SHAPE_COUNT),
|
|
difficulty: read_u32_field(value, "difficulty")
|
|
.unwrap_or(session.config.difficulty)
|
|
.clamp(SQUARE_HOLE_MIN_DIFFICULTY, SQUARE_HOLE_MAX_DIFFICULTY),
|
|
shape_options: shape_options
|
|
.into_iter()
|
|
.map(SquareHoleAgentShapeOptionOutput::from)
|
|
.collect(),
|
|
hole_options: hole_options
|
|
.into_iter()
|
|
.map(SquareHoleAgentHoleOptionOutput::from)
|
|
.collect(),
|
|
background_prompt,
|
|
cover_image_src: session.config.cover_image_src.clone().unwrap_or_default(),
|
|
background_image_src: session
|
|
.config
|
|
.background_image_src
|
|
.clone()
|
|
.unwrap_or_default(),
|
|
})
|
|
}
|
|
|
|
fn parse_shape_options(
|
|
value: &JsonValue,
|
|
session: &SquareHoleAgentSessionRecord,
|
|
theme_text: &str,
|
|
hole_options: &[SquareHoleHoleOption],
|
|
) -> Vec<SquareHoleShapeOption> {
|
|
let parsed = value
|
|
.get("shapeOptions")
|
|
.and_then(JsonValue::as_array)
|
|
.map(|items| {
|
|
items
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(index, item)| SquareHoleShapeOption {
|
|
option_id: read_text_field(item, "optionId")
|
|
.unwrap_or_else(|| format!("shape-option-{index}")),
|
|
shape_kind: read_text_field(item, "shapeKind")
|
|
.unwrap_or_else(|| fallback_shape_kind(index).to_string()),
|
|
label: read_text_field(item, "label")
|
|
.unwrap_or_else(|| fallback_shape_label(index).to_string()),
|
|
target_hole_id: read_text_field(item, "targetHoleId")
|
|
.filter(|value| hole_options.iter().any(|option| option.hole_id == *value))
|
|
.unwrap_or_else(|| {
|
|
hole_options
|
|
.get(index % hole_options.len().max(1))
|
|
.map(|option| option.hole_id.clone())
|
|
.unwrap_or_else(|| fallback_target_hole_id(index).to_string())
|
|
}),
|
|
image_prompt: read_text_field(item, "imagePrompt").unwrap_or_else(|| {
|
|
format!(
|
|
"{theme_text}主题的{}贴纸图,透明背景,明亮游戏资产",
|
|
fallback_shape_label(index)
|
|
)
|
|
}),
|
|
image_src: read_text_field(item, "imageSrc"),
|
|
})
|
|
.collect::<Vec<_>>()
|
|
})
|
|
.unwrap_or_else(|| {
|
|
session
|
|
.config
|
|
.shape_options
|
|
.iter()
|
|
.map(|option| SquareHoleShapeOption {
|
|
option_id: option.option_id.clone(),
|
|
shape_kind: option.shape_kind.clone(),
|
|
label: option.label.clone(),
|
|
target_hole_id: option.target_hole_id.clone(),
|
|
image_prompt: option.image_prompt.clone(),
|
|
image_src: option.image_src.clone(),
|
|
})
|
|
.collect()
|
|
});
|
|
|
|
normalize_shape_options(parsed, theme_text, hole_options)
|
|
}
|
|
|
|
fn parse_hole_options(
|
|
value: &JsonValue,
|
|
session: &SquareHoleAgentSessionRecord,
|
|
theme_text: &str,
|
|
) -> Vec<SquareHoleHoleOption> {
|
|
let parsed = value
|
|
.get("holeOptions")
|
|
.and_then(JsonValue::as_array)
|
|
.map(|items| {
|
|
items
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(index, item)| SquareHoleHoleOption {
|
|
hole_id: read_text_field(item, "holeId")
|
|
.unwrap_or_else(|| format!("hole-option-{index}")),
|
|
hole_kind: read_text_field(item, "holeKind")
|
|
.unwrap_or_else(|| format!("hole-{}", index + 1)),
|
|
label: read_text_field(item, "label")
|
|
.unwrap_or_else(|| fallback_hole_label(index).to_string()),
|
|
image_prompt: read_text_field(item, "imagePrompt").unwrap_or_else(|| {
|
|
format!(
|
|
"{theme_text}主题的{}贴纸图,透明背景,明亮游戏资产",
|
|
fallback_hole_label(index)
|
|
)
|
|
}),
|
|
image_src: read_text_field(item, "imageSrc"),
|
|
})
|
|
.collect::<Vec<_>>()
|
|
})
|
|
.unwrap_or_else(|| {
|
|
session
|
|
.config
|
|
.hole_options
|
|
.iter()
|
|
.map(|option| SquareHoleHoleOption {
|
|
hole_id: option.hole_id.clone(),
|
|
hole_kind: option.hole_kind.clone(),
|
|
label: option.label.clone(),
|
|
image_prompt: option.image_prompt.clone(),
|
|
image_src: option.image_src.clone(),
|
|
})
|
|
.collect()
|
|
});
|
|
|
|
normalize_hole_options(parsed, theme_text)
|
|
}
|
|
|
|
fn read_text_field(value: &JsonValue, field_name: &str) -> Option<String> {
|
|
value
|
|
.get(field_name)
|
|
.and_then(JsonValue::as_str)
|
|
.map(str::trim)
|
|
.filter(|text| !text.is_empty())
|
|
.map(str::to_string)
|
|
}
|
|
|
|
fn read_u32_field(value: &JsonValue, field_name: &str) -> Option<u32> {
|
|
value
|
|
.get(field_name)
|
|
.and_then(JsonValue::as_u64)
|
|
.and_then(|number| u32::try_from(number).ok())
|
|
}
|
|
|
|
/// Picks a default shape kind for the option at `index`, cycling through six kinds.
fn fallback_shape_kind(index: usize) -> &'static str {
    const KINDS: [&str; 6] = ["square", "circle", "triangle", "diamond", "star", "arch"];
    KINDS[index % KINDS.len()]
}
|
|
|
|
fn fallback_shape_label(index: usize) -> &'static str {
|
|
match fallback_shape_kind(index) {
|
|
"square" => "方块",
|
|
"circle" => "圆块",
|
|
"triangle" => "三角块",
|
|
"diamond" => "菱形块",
|
|
"star" => "星形块",
|
|
_ => "拱形块",
|
|
}
|
|
}
|
|
|
|
/// Returns a 1-based default hole label such as "洞口 1".
fn fallback_hole_label(index: usize) -> String {
    let ordinal = index + 1;
    format!("洞口 {ordinal}")
}
|
|
|
|
/// Picks a default target hole id, cycling through "hole-1".."hole-3".
fn fallback_target_hole_id(index: usize) -> &'static str {
    const HOLE_IDS: [&str; 3] = ["hole-1", "hole-2", "hole-3"];
    HOLE_IDS[index % HOLE_IDS.len()]
}
|
|
|
|
impl From<SquareHoleShapeOption> for SquareHoleAgentShapeOptionOutput {
|
|
fn from(option: SquareHoleShapeOption) -> Self {
|
|
Self {
|
|
option_id: option.option_id,
|
|
shape_kind: option.shape_kind,
|
|
label: option.label,
|
|
target_hole_id: option.target_hole_id,
|
|
image_prompt: option.image_prompt,
|
|
image_src: option.image_src.unwrap_or_default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<SquareHoleHoleOption> for SquareHoleAgentHoleOptionOutput {
|
|
fn from(option: SquareHoleHoleOption) -> Self {
|
|
Self {
|
|
hole_id: option.hole_id,
|
|
hole_kind: option.hole_kind,
|
|
label: option.label,
|
|
image_prompt: option.image_prompt,
|
|
image_src: option.image_src.unwrap_or_default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Maps progress to a session stage name.
///
/// Anything below 100 is still "Collecting"; 100 or more is "ReadyToCompile".
fn resolve_stage(progress_percent: u32) -> String {
    let stage = if progress_percent < 100 {
        "Collecting"
    } else {
        "ReadyToCompile"
    };
    stage.to_owned()
}
|
|
|
|
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{parse_model_output, resolve_stage};

    /// Session fixture whose config supplies the fallbacks for fields the model omits.
    fn session_record() -> spacetime_client::SquareHoleAgentSessionRecord {
        spacetime_client::SquareHoleAgentSessionRecord {
            session_id: "square-hole-session-test".to_string(),
            current_turn: 1,
            progress_percent: 25,
            stage: "collecting_config".to_string(),
            anchor_pack: spacetime_client::SquareHoleAnchorPackRecord {
                theme: anchor("theme", "题材主题", "纸箱"),
                twist_rule: anchor("twistRule", "反差规则", "方洞万能"),
                shape_count: anchor("shapeCount", "形状数量", "12"),
                difficulty: anchor("difficulty", "难度", "4"),
            },
            config: spacetime_client::SquareHoleCreatorConfigRecord {
                theme_text: "纸箱".to_string(),
                twist_rule: "方洞万能".to_string(),
                shape_count: 12,
                difficulty: 4,
                shape_options: Vec::new(),
                hole_options: Vec::new(),
                background_prompt: "纸箱玩具桌面背景".to_string(),
                cover_image_src: None,
                background_image_src: None,
            },
            draft: None,
            messages: Vec::new(),
            last_assistant_reply: None,
            published_profile_id: None,
            updated_at: "2026-05-04T10:00:00.000Z".to_string(),
        }
    }

    /// Builds an anchor item whose status is "missing" for an empty value, otherwise "confirmed".
    fn anchor(key: &str, label: &str, value: &str) -> spacetime_client::SquareHoleAnchorItemRecord {
        spacetime_client::SquareHoleAnchorItemRecord {
            key: key.to_string(),
            label: label.to_string(),
            value: value.to_string(),
            status: if value.is_empty() {
                "missing"
            } else {
                "confirmed"
            }
            .to_string(),
        }
    }

    /// A fully populated camelCase reply is taken over field-for-field.
    #[test]
    fn parse_model_output_accepts_camel_case_config_contract() {
        let model_output = json!({
            "replyText": "可以,把办公室文具都做成会被方洞吞进去的挑战。",
            "progressPercent": 86,
            "nextConfig": {
                "themeText": "办公室文具",
                "twistRule": "所有文具最终都优先进入方洞",
                "shapeCount": 14,
                "difficulty": 6,
                "shapeOptions": [
                    {
                        "optionId": "stamp",
                        "shapeKind": "circle",
                        "label": "圆形印章",
                        "targetHoleId": "folder",
                        "imagePrompt": "办公室圆形印章贴纸"
                    }
                ],
                "holeOptions": [
                    {
                        "holeId": "folder",
                        "holeKind": "folder",
                        "label": "档案盒方洞",
                        "imagePrompt": "办公室档案盒洞口贴纸"
                    }
                ],
                "backgroundPrompt": "办公室桌面纸箱玩具背景"
            }
        });

        let output =
            parse_model_output(&model_output, &session_record()).expect("模型输出应能解析");

        assert_eq!(
            output.reply_text,
            "可以,把办公室文具都做成会被方洞吞进去的挑战。"
        );
        assert_eq!(output.progress_percent, 86);
        assert_eq!(output.next_config.theme_text, "办公室文具");
        assert_eq!(output.next_config.twist_rule, "所有文具最终都优先进入方洞");
        assert_eq!(output.next_config.shape_count, 14);
        assert_eq!(output.next_config.difficulty, 6);
        // Only one shape was supplied; presumably normalize_shape_options pads to a minimum — confirm.
        assert!(output.next_config.shape_options.len() >= 6);
        assert_eq!(output.next_config.shape_options[0].label, "圆形印章");
        assert_eq!(output.next_config.shape_options[0].target_hole_id, "folder");
        assert_eq!(output.next_config.hole_options[0].label, "档案盒方洞");
        assert_eq!(
            output.next_config.hole_options[0].image_prompt,
            "办公室档案盒洞口贴纸"
        );
        assert_eq!(
            output.next_config.background_prompt,
            "办公室桌面纸箱玩具背景"
        );
    }

    /// Out-of-range numbers are clamped rather than rejected.
    #[test]
    fn parse_model_output_clamps_numeric_config() {
        let model_output = json!({
            "replyText": "我先把数字压到可试玩范围里。",
            "progressPercent": 120,
            "nextConfig": {
                "themeText": "霓虹积木",
                "twistRule": "方洞优先",
                "shapeCount": 99,
                "difficulty": 0,
                "shapeOptions": [],
                "holeOptions": [],
                "backgroundPrompt": ""
            }
        });

        let output =
            parse_model_output(&model_output, &session_record()).expect("模型输出应能解析");

        assert_eq!(output.progress_percent, 100);
        // Assumes SQUARE_HOLE_MAX_SHAPE_COUNT == 24 and SQUARE_HOLE_MIN_DIFFICULTY == 1.
        assert_eq!(output.next_config.shape_count, 24);
        assert_eq!(output.next_config.difficulty, 1);
    }

    /// The stage flips only once progress reaches exactly 100.
    #[test]
    fn resolve_stage_switches_to_compile_only_at_complete_progress() {
        assert_eq!(resolve_stage(99), "Collecting");
        assert_eq!(resolve_stage(100), "ReadyToCompile");
    }
}
|