Files
Genarrative/server-rs/crates/module-puzzle/src/creative_tools.rs
2026-05-14 14:21:17 +08:00

533 lines
20 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! 拼图创意 Agent 草稿工具。
//!
//! 通用 Agent 只能把模型输出交给这些工具;字段归一化、模板关卡数和可编辑
//! 字段白名单都收口在拼图模块,避免 api-server 复制草稿业务规则。
use serde::{Deserialize, Serialize};
use serde_json::Value;
use shared_kernel::{normalize_required_string, normalize_string_list};
use crate::{
application::{
build_form_anchor_pack, build_result_preview, normalize_puzzle_draft,
normalize_puzzle_levels, sync_primary_level_fields,
},
creative_templates::{
PuzzleCreativeCostRange, PuzzleCreativeDraftEditableFieldPath,
PuzzleCreativeLevelGenerationMode, PuzzleCreativeSupportedLevelMode,
PuzzleCreativeTemplateProtocol, retrieve_puzzle_template_catalog,
},
domain::{
PUZZLE_MAX_TAG_COUNT, PUZZLE_MIN_TAG_COUNT, PuzzleDraftLevel, PuzzleFormDraft,
PuzzleResultDraft,
},
errors::PuzzleFieldError,
};
/// Creative fields for a single puzzle level, as handed to the draft and image-plan tools.
///
/// Serialized and deserialized with camelCase keys.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreativePuzzleLevelDraftInput {
/// Display name of the level. The builders in this file fall back to the level's
/// 1-based position when this normalizes to nothing.
pub level_name: String,
/// Description of the level picture. The builders in this file fail with
/// `PuzzleFieldError::MissingText` when this normalizes to nothing.
pub picture_description: String,
/// Optional picture reference; deserializes as `None` when the key is absent.
#[serde(default)]
pub picture_reference: Option<String>,
}
/// Input of `build_puzzle_draft_from_creative_fields`.
///
/// Serialized and deserialized with camelCase keys.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreativePuzzleDraftToolInput {
/// Id of the catalog template the draft is based on; must resolve in the template catalog.
pub template_id: String,
/// Cost range the caller believes applies; must equal the catalog template's cost range.
pub template_cost_range: PuzzleCreativeCostRange,
/// Work title; required (must not normalize to nothing).
pub work_title: String,
/// Work description; required, and also used as the draft summary.
pub work_description: String,
/// Theme tags; normalized, de-duplicated and capped at `PUZZLE_MAX_TAG_COUNT` by the builder.
pub work_tags: Vec<String>,
/// Levels in play order; level ids are assigned from the position in this list.
pub levels: Vec<CreativePuzzleLevelDraftInput>,
}
/// A template choice to be checked by `validate_puzzle_template_selection`.
///
/// Serialized and deserialized with camelCase keys.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PuzzleCreativeTemplateSelection {
/// Id of the selected template; must resolve in the template catalog.
pub template_id: String,
/// Title attached to the selection; not inspected by the validation in this file.
pub title: String,
/// Reason attached to the selection; not inspected by the validation in this file.
pub reason: String,
/// Must equal the catalog template's cost range.
pub cost_range: PuzzleCreativeCostRange,
/// Must equal the catalog template's supported level mode.
pub supported_level_mode: PuzzleCreativeSupportedLevelMode,
/// Generation mode the planned level count is validated against.
pub selected_level_mode: PuzzleCreativeLevelGenerationMode,
/// Number of levels planned; checked against the template bounds and the selected mode.
pub planned_level_count: u32,
/// Must be `true`; a selection that does not require user confirmation is rejected.
pub requires_user_confirmation: bool,
}
/// Input of `plan_puzzle_level_images`.
///
/// Serialized and deserialized with camelCase keys.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PuzzleLevelImagePlanInput {
/// Id of the catalog template; must resolve in the template catalog.
pub template_id: String,
/// Generation mode the number of `levels` is validated against; echoed into the plan.
pub selected_level_mode: PuzzleCreativeLevelGenerationMode,
/// Levels to plan images for, in order.
pub levels: Vec<CreativePuzzleLevelDraftInput>,
/// Must equal the catalog template's cost range; echoed into the plan as the estimate.
pub cost_range: PuzzleCreativeCostRange,
/// Requested candidates per level; deserializes as `None` when absent.
/// `plan_puzzle_level_images` currently clamps whatever is requested to exactly 1.
#[serde(default)]
pub candidate_count_per_level: Option<u32>,
}
/// One level entry of a `PuzzleImageGenerationPlan`.
///
/// Serialized and deserialized with camelCase keys.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PuzzleImageGenerationPlanLevel {
/// Level id; `plan_puzzle_level_images` emits `puzzle-level-<n>` with a 1-based `n`.
pub level_id: String,
/// Normalized level name, or the 1-based position when the input name was blank.
pub level_name: String,
/// Normalized picture description of the level.
pub picture_description: String,
/// Prompt text derived from the level name and picture description.
pub image_prompt: String,
/// Optional picture reference; deserializes as `None` when the key is absent.
#[serde(default)]
pub picture_reference: Option<String>,
/// Number of image candidates to generate for this level.
pub candidate_count: u32,
}
/// Result of `plan_puzzle_level_images`: what to generate for each level.
///
/// Serialized and deserialized with camelCase keys.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PuzzleImageGenerationPlan {
/// Generation mode copied from the plan input.
pub mode: PuzzleCreativeLevelGenerationMode,
/// Template id copied from the plan input.
pub template_id: String,
/// Cost range copied from the plan input (already verified to equal the template's).
pub estimated_cost_range: PuzzleCreativeCostRange,
/// One entry per input level, in input order.
pub levels: Vec<PuzzleImageGenerationPlanLevel>,
}
/// Operation requested by a `PuzzleDraftFieldPatch`.
///
/// Serialized as snake_case strings. `apply_puzzle_draft_field_patch` accepts only `Set` and
/// `Replace` (treated identically) and rejects `Append` and `Remove` with
/// `PuzzleFieldError::InvalidOperation`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PuzzleDraftFieldPatchOperation {
Set,
Append,
Replace,
Remove,
}
/// A single-field edit to apply to a draft via `apply_puzzle_draft_field_patch`.
///
/// Serialized and deserialized with camelCase keys.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PuzzleDraftFieldPatch {
/// Which editable draft field the patch targets.
pub field_path: PuzzleCreativeDraftEditableFieldPath,
/// Requested operation; only `Set` and `Replace` are applied.
pub operation: PuzzleDraftFieldPatchOperation,
/// Target level for level-scoped fields; deserializes as `None` when absent.
/// When `None` or blank, the first level of the draft is targeted.
#[serde(default)]
pub level_id: Option<String>,
/// New value as raw JSON; the expected JSON type depends on `field_path`.
pub value: Value,
/// Explanation attached to the patch; not read when the patch is applied in this file.
pub rationale: String,
}
/// Checks a template selection against the template catalog.
///
/// The selection is accepted only when its template id resolves in the catalog, its
/// `cost_range` and `supported_level_mode` equal the catalog entry's values, it requires
/// user confirmation, and its planned level count is valid for the selected mode.
///
/// # Errors
///
/// Returns [`PuzzleFieldError::InvalidOperation`] when any of those checks fails.
pub fn validate_puzzle_template_selection(
    selection: &PuzzleCreativeTemplateSelection,
) -> Result<(), PuzzleFieldError> {
    let catalog_entry = resolve_phase1_template(&selection.template_id)?;

    let agrees_with_catalog = selection.cost_range == catalog_entry.cost_range
        && selection.supported_level_mode == catalog_entry.supported_level_mode;
    if !agrees_with_catalog || !selection.requires_user_confirmation {
        return Err(PuzzleFieldError::InvalidOperation);
    }

    validate_level_count(
        selection.planned_level_count,
        &selection.selected_level_mode,
        &catalog_entry,
    )
}
/// Builds a normalized [`PuzzleResultDraft`] from the creative fields supplied by the agent.
///
/// The template id must resolve in the catalog and `template_cost_range` must equal the
/// catalog entry's cost range. The generation mode is inferred from the number of levels
/// (more than one level means multi-level) and validated against the template's bounds.
/// Levels receive sequential ids `puzzle-level-1`, `puzzle-level-2`, ... and a level whose
/// name normalizes to nothing is named after its 1-based position.
///
/// # Errors
///
/// * [`PuzzleFieldError::InvalidOperation`] - unknown template, cost range mismatch, or a
///   level count that is out of bounds for the template or does not fit in `u32`.
/// * [`PuzzleFieldError::MissingText`] - the work title, the work description, or any
///   level's picture description normalizes to nothing.
/// * [`PuzzleFieldError::InvalidTagCount`] - fewer than `PUZZLE_MIN_TAG_COUNT` distinct tags
///   remain after normalization.
pub fn build_puzzle_draft_from_creative_fields(
    input: CreativePuzzleDraftToolInput,
) -> Result<PuzzleResultDraft, PuzzleFieldError> {
    let template = resolve_phase1_template(&input.template_id)?;
    if input.template_cost_range != template.cost_range {
        return Err(PuzzleFieldError::InvalidOperation);
    }

    // Checked conversion: an `as u32` cast would silently wrap an oversized list into a
    // small count that could then pass the range validation below.
    let level_count =
        u32::try_from(input.levels.len()).map_err(|_| PuzzleFieldError::InvalidOperation)?;
    let inferred_mode = if level_count > 1 {
        PuzzleCreativeLevelGenerationMode::MultiLevel
    } else {
        PuzzleCreativeLevelGenerationMode::SingleLevel
    };
    validate_level_count(level_count, &inferred_mode, &template)?;

    let work_title =
        normalize_required_string(&input.work_title).ok_or(PuzzleFieldError::MissingText)?;
    let work_description =
        normalize_required_string(&input.work_description).ok_or(PuzzleFieldError::MissingText)?;
    let tags = normalize_theme_tags_for_creative(input.work_tags)?;

    // The anchor pack is seeded with the first level's picture description exactly as
    // supplied; the work description is only a fallback for an empty level list.
    let anchor_pack = build_form_anchor_pack(
        work_title.as_str(),
        input
            .levels
            .first()
            .map(|level| level.picture_description.as_str())
            .unwrap_or(work_description.as_str()),
    );

    let levels = input
        .levels
        .into_iter()
        .enumerate()
        .map(|(index, level)| {
            let picture_description = normalize_required_string(&level.picture_description)
                .ok_or(PuzzleFieldError::MissingText)?;
            Ok(PuzzleDraftLevel {
                level_id: format!("puzzle-level-{}", index + 1),
                // A blank level name falls back to the 1-based position.
                level_name: normalize_required_string(&level.level_name)
                    .unwrap_or_else(|| (index + 1).to_string()),
                picture_description,
                picture_reference: level.picture_reference.and_then(normalize_required_string),
                ui_background_prompt: None,
                ui_background_image_src: None,
                ui_background_image_object_key: None,
                background_music: None,
                candidates: Vec::new(),
                selected_candidate_id: None,
                cover_image_src: None,
                cover_asset_id: None,
                generation_status: "idle".to_string(),
            })
        })
        .collect::<Result<Vec<_>, PuzzleFieldError>>()?;

    let mut draft = PuzzleResultDraft {
        work_title: work_title.clone(),
        work_description: work_description.clone(),
        level_name: levels
            .first()
            .map(|level| level.level_name.clone())
            .unwrap_or_default(),
        // The summary starts out as a copy of the work description.
        summary: work_description.clone(),
        theme_tags: tags,
        forbidden_directives: Vec::new(),
        creator_intent: None,
        anchor_pack,
        candidates: Vec::new(),
        selected_candidate_id: None,
        cover_image_src: None,
        cover_asset_id: None,
        generation_status: "idle".to_string(),
        levels,
        form_draft: Some(PuzzleFormDraft {
            work_title: Some(work_title),
            work_description: Some(work_description),
            picture_description: None,
        }),
    };
    sync_primary_level_fields(&mut draft);
    Ok(normalize_puzzle_draft(draft))
}
/// Produces the per-level image generation plan for a set of level inputs.
///
/// The template id must resolve in the catalog, the number of levels must be valid for the
/// template and for `selected_level_mode`, and `cost_range` must equal the catalog entry's
/// cost range. Each level gets the id `puzzle-level-<n>` (1-based) and an image prompt
/// derived from its name and picture description. Whatever `candidate_count_per_level`
/// requests, every level is currently planned with exactly one candidate.
///
/// # Errors
///
/// * [`PuzzleFieldError::InvalidOperation`] - unknown template, cost range mismatch, or a
///   level count that is invalid for the template/mode or does not fit in `u32`.
/// * [`PuzzleFieldError::MissingText`] - a level's picture description normalizes to nothing.
pub fn plan_puzzle_level_images(
    input: PuzzleLevelImagePlanInput,
) -> Result<PuzzleImageGenerationPlan, PuzzleFieldError> {
    // Both bounds are 1, so the requested candidate count is always forced to 1.
    const MIN_CANDIDATES_PER_LEVEL: u32 = 1;
    const MAX_CANDIDATES_PER_LEVEL: u32 = 1;

    let template = resolve_phase1_template(&input.template_id)?;

    // Checked conversion: an `as u32` cast would silently wrap an oversized list into a
    // small count that could then pass the range validation.
    let level_count =
        u32::try_from(input.levels.len()).map_err(|_| PuzzleFieldError::InvalidOperation)?;
    validate_level_count(level_count, &input.selected_level_mode, &template)?;
    if input.cost_range != template.cost_range {
        return Err(PuzzleFieldError::InvalidOperation);
    }

    let candidate_count = input
        .candidate_count_per_level
        .unwrap_or(MIN_CANDIDATES_PER_LEVEL)
        .clamp(MIN_CANDIDATES_PER_LEVEL, MAX_CANDIDATES_PER_LEVEL);

    let levels = input
        .levels
        .into_iter()
        .enumerate()
        .map(|(index, level)| {
            let picture_description = normalize_required_string(&level.picture_description)
                .ok_or(PuzzleFieldError::MissingText)?;
            // A blank level name falls back to the 1-based position.
            let level_name = normalize_required_string(&level.level_name)
                .unwrap_or_else(|| (index + 1).to_string());
            Ok(PuzzleImageGenerationPlanLevel {
                level_id: format!("puzzle-level-{}", index + 1),
                image_prompt: build_level_image_prompt(&level_name, &picture_description),
                level_name,
                picture_description,
                picture_reference: level.picture_reference.and_then(normalize_required_string),
                candidate_count,
            })
        })
        .collect::<Result<Vec<_>, PuzzleFieldError>>()?;

    Ok(PuzzleImageGenerationPlan {
        mode: input.selected_level_mode,
        template_id: input.template_id,
        estimated_cost_range: input.cost_range,
        levels,
    })
}
/// Applies a single whitelisted field patch to a draft and returns the updated draft.
///
/// Only `Set` and `Replace` are accepted, and both overwrite the targeted field. The draft
/// is normalized before the patch is applied; afterwards its levels are re-normalized
/// against the (possibly updated) theme tags and the primary level fields are re-synced.
///
/// Value handling per field:
/// * work title, work description, level name, level picture description - the value must be
///   a JSON string that does not normalize to nothing;
/// * work tags - the value must be a JSON array; non-string elements are skipped;
/// * level picture reference - any value that is not a non-blank JSON string clears the
///   reference.
///
/// Level-scoped fields target `patch.level_id`, or the first level when it is absent or blank.
///
/// # Errors
///
/// * [`PuzzleFieldError::InvalidOperation`] - unsupported operation, unknown level id, a draft
///   without levels, or a tags value that is not an array.
/// * [`PuzzleFieldError::MissingText`] - a required text value is missing or blank.
/// * [`PuzzleFieldError::InvalidTagCount`] - too few distinct tags after normalization.
/// * Whatever `normalize_puzzle_levels` reports for the patched levels.
pub fn apply_puzzle_draft_field_patch(
    draft: PuzzleResultDraft,
    patch: PuzzleDraftFieldPatch,
) -> Result<PuzzleResultDraft, PuzzleFieldError> {
    let is_supported = matches!(
        patch.operation,
        PuzzleDraftFieldPatchOperation::Set | PuzzleDraftFieldPatchOperation::Replace
    );
    if !is_supported {
        return Err(PuzzleFieldError::InvalidOperation);
    }

    // All edits happen on an owned copy, so a failed patch never leaks partial changes.
    let mut next_draft = normalize_puzzle_draft(draft);
    match patch.field_path {
        PuzzleCreativeDraftEditableFieldPath::WorkTitle => {
            next_draft.work_title = value_as_required_string(&patch.value)?;
        }
        PuzzleCreativeDraftEditableFieldPath::WorkDescription => {
            next_draft.work_description = value_as_required_string(&patch.value)?;
        }
        PuzzleCreativeDraftEditableFieldPath::WorkTags => {
            next_draft.theme_tags =
                normalize_theme_tags_for_creative(value_as_string_list(&patch.value)?)?;
        }
        PuzzleCreativeDraftEditableFieldPath::LevelName => {
            let level = mutable_level_for_patch(&mut next_draft, patch.level_id.as_deref())?;
            level.level_name = value_as_required_string(&patch.value)?;
        }
        PuzzleCreativeDraftEditableFieldPath::LevelPictureDescription => {
            let level = mutable_level_for_patch(&mut next_draft, patch.level_id.as_deref())?;
            level.picture_description = value_as_required_string(&patch.value)?;
        }
        PuzzleCreativeDraftEditableFieldPath::LevelPictureReference => {
            let level = mutable_level_for_patch(&mut next_draft, patch.level_id.as_deref())?;
            level.picture_reference = value_as_optional_string(&patch.value);
        }
    }

    // Move the levels out instead of deep-cloning them; on error the whole draft is
    // dropped anyway, and on success the normalized list is put straight back.
    let levels = normalize_puzzle_levels(
        std::mem::take(&mut next_draft.levels),
        &next_draft.theme_tags,
    )?;
    next_draft.levels = levels;
    sync_primary_level_fields(&mut next_draft);

    // NOTE(review): the preview is built and its result discarded. If
    // `build_result_preview` returns a `Result`, a failure is ignored here; confirm
    // whether this call is meant to validate the patched draft.
    // NOTE(review): `summary` and `form_draft` are not updated by this function when the
    // work title or description changes; verify the helpers above keep them in sync.
    let _ = build_result_preview(&next_draft, Some("陶泥儿主"));
    Ok(next_draft)
}
/// Looks up a template in the catalog by id.
///
/// The id is normalized first. Despite the function name, any template present in the
/// catalog returned by `retrieve_puzzle_template_catalog` can be resolved.
///
/// # Errors
///
/// Returns [`PuzzleFieldError::InvalidOperation`] when the id normalizes to nothing or no
/// catalog entry carries it.
fn resolve_phase1_template(
    template_id: &str,
) -> Result<PuzzleCreativeTemplateProtocol, PuzzleFieldError> {
    let Some(wanted_id) = normalize_required_string(template_id) else {
        return Err(PuzzleFieldError::InvalidOperation);
    };

    for candidate in retrieve_puzzle_template_catalog() {
        if candidate.template_id == wanted_id {
            return Ok(candidate);
        }
    }
    Err(PuzzleFieldError::InvalidOperation)
}
/// Validates a level count against a template's bounds and a generation mode.
///
/// The count must lie within `min_level_count..=max_level_count` of the template; in
/// single-level mode it must be exactly 1, and in multi-level mode at least 2.
///
/// # Errors
///
/// Returns [`PuzzleFieldError::InvalidOperation`] when any rule is violated.
fn validate_level_count(
    count: u32,
    mode: &PuzzleCreativeLevelGenerationMode,
    template: &PuzzleCreativeTemplateProtocol,
) -> Result<(), PuzzleFieldError> {
    let within_template_bounds =
        (template.min_level_count..=template.max_level_count).contains(&count);

    let single_level = matches!(mode, PuzzleCreativeLevelGenerationMode::SingleLevel);
    let multi_level = matches!(mode, PuzzleCreativeLevelGenerationMode::MultiLevel);
    let conflicts_with_mode = (single_level && count != 1) || (multi_level && count < 2);

    if within_template_bounds && !conflicts_with_mode {
        Ok(())
    } else {
        Err(PuzzleFieldError::InvalidOperation)
    }
}
/// Normalizes theme tags: runs them through `normalize_string_list`, drops duplicates while
/// keeping first-seen order, and stops collecting once `PUZZLE_MAX_TAG_COUNT` tags are held.
///
/// # Errors
///
/// Returns [`PuzzleFieldError::InvalidTagCount`] when fewer than `PUZZLE_MIN_TAG_COUNT` tags
/// remain.
fn normalize_theme_tags_for_creative(values: Vec<String>) -> Result<Vec<String>, PuzzleFieldError> {
    let mut unique_tags: Vec<String> = Vec::new();
    for candidate in normalize_string_list(values) {
        let is_new = unique_tags.iter().all(|kept| *kept != candidate);
        if is_new {
            unique_tags.push(candidate);
        }
        // Checked after every element, whether or not it was kept.
        if unique_tags.len() >= PUZZLE_MAX_TAG_COUNT {
            break;
        }
    }

    if unique_tags.len() >= PUZZLE_MIN_TAG_COUNT {
        Ok(unique_tags)
    } else {
        Err(PuzzleFieldError::InvalidTagCount)
    }
}
/// Composes the image prompt for one level: the level name, immediately followed by the
/// picture description, followed by a fixed suffix asking for a clear subject that suits
/// being cut into puzzle pieces.
fn build_level_image_prompt(level_name: &str, picture_description: &str) -> String {
    const PROMPT_SUFFIX: &str = "。清晰主体,适合拼图切块。";
    [level_name, picture_description, PROMPT_SUFFIX].concat()
}
/// Selects the level a patch should modify.
///
/// When `level_id` normalizes to a value, the level with exactly that id is returned;
/// otherwise (absent or blank id) the first level of the draft is returned.
///
/// # Errors
///
/// Returns [`PuzzleFieldError::InvalidOperation`] when no level carries the requested id or
/// the draft has no levels.
fn mutable_level_for_patch<'a>(
    draft: &'a mut PuzzleResultDraft,
    level_id: Option<&str>,
) -> Result<&'a mut PuzzleDraftLevel, PuzzleFieldError> {
    let requested_id = level_id.and_then(normalize_required_string);
    let target = match requested_id {
        Some(wanted) => draft
            .levels
            .iter_mut()
            .find(|level| level.level_id == wanted),
        None => draft.levels.first_mut(),
    };
    target.ok_or(PuzzleFieldError::InvalidOperation)
}
fn value_as_required_string(value: &Value) -> Result<String, PuzzleFieldError> {
value
.as_str()
.and_then(normalize_required_string)
.ok_or(PuzzleFieldError::MissingText)
}
fn value_as_optional_string(value: &Value) -> Option<String> {
value.as_str().and_then(normalize_required_string)
}
fn value_as_string_list(value: &Value) -> Result<Vec<String>, PuzzleFieldError> {
value
.as_array()
.map(|values| {
values
.iter()
.filter_map(|value| value.as_str().map(ToString::to_string))
.collect::<Vec<_>>()
})
.ok_or(PuzzleFieldError::InvalidOperation)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creative_templates::{PUZZLE_PHASE1_TEMPLATE_ID, PuzzleCreativePricingUnit};

    /// Cost range the tests submit together with `PUZZLE_PHASE1_TEMPLATE_ID`.
    fn cost_range() -> PuzzleCreativeCostRange {
        PuzzleCreativeCostRange {
            min_points: 2,
            max_points: 12,
            pricing_unit: PuzzleCreativePricingUnit::Point,
            reason: "按关卡数和每关图片生成次数估算,实际扣费以后端任务结算为准".to_string(),
        }
    }

    /// Shorthand for building one level input.
    fn level(
        name: &str,
        description: &str,
        reference: Option<&str>,
    ) -> CreativePuzzleLevelDraftInput {
        CreativePuzzleLevelDraftInput {
            level_name: String::from(name),
            picture_description: String::from(description),
            picture_reference: reference.map(String::from),
        }
    }

    /// Converts string literals into the owned tag list the tool input expects.
    fn owned_tags(values: &[&str]) -> Vec<String> {
        values.iter().copied().map(String::from).collect()
    }

    /// The two travel levels shared by the multi-level tests.
    fn travel_levels() -> Vec<CreativePuzzleLevelDraftInput> {
        vec![
            level("第一站", "海边合影", Some("asset-1")),
            level("第二站", "山顶日落", Some("asset-2")),
        ]
    }

    /// Travel-themed draft input; only the template, cost range, title and levels vary.
    fn travel_input(
        template_id: String,
        template_cost_range: PuzzleCreativeCostRange,
        work_title: &str,
        levels: Vec<CreativePuzzleLevelDraftInput>,
    ) -> CreativePuzzleDraftToolInput {
        CreativePuzzleDraftToolInput {
            template_id,
            template_cost_range,
            work_title: String::from(work_title),
            work_description: String::from("把旅行照片做成系列拼图。"),
            work_tags: owned_tags(&["旅行", "照片", "纪念"]),
            levels,
        }
    }

    #[test]
    fn creative_draft_builds_single_level_with_summary_and_plan() {
        let birthday_level = level(
            "第一关",
            "生日蛋糕和朋友合影。",
            Some("https://assets.example.test/birthday.png"),
        );

        let draft_input = CreativePuzzleDraftToolInput {
            template_id: PUZZLE_PHASE1_TEMPLATE_ID.to_string(),
            template_cost_range: cost_range(),
            work_title: String::from("生日拼图"),
            work_description: String::from("把生日照片做成一关拼图。"),
            work_tags: owned_tags(&["生日", "朋友", "纪念"]),
            levels: vec![birthday_level.clone()],
        };
        let draft = build_puzzle_draft_from_creative_fields(draft_input)
            .expect("single level draft should build");

        // Three candidates are requested on purpose; the plan is expected to cap them at one.
        let plan_input = PuzzleLevelImagePlanInput {
            template_id: PUZZLE_PHASE1_TEMPLATE_ID.to_string(),
            selected_level_mode: PuzzleCreativeLevelGenerationMode::SingleLevel,
            levels: vec![birthday_level],
            cost_range: cost_range(),
            candidate_count_per_level: Some(3),
        };
        let plan =
            plan_puzzle_level_images(plan_input).expect("single level image plan should build");

        assert_eq!(draft.work_title, "生日拼图");
        assert_eq!(draft.work_description, "把生日照片做成一关拼图。");
        assert_eq!(draft.summary, "把生日照片做成一关拼图。");
        assert_eq!(draft.level_name, "第一关");
        assert_eq!(draft.levels.len(), 1);
        assert_eq!(
            draft.levels[0].picture_reference.as_deref(),
            Some("https://assets.example.test/birthday.png")
        );
        assert_eq!(plan.mode, PuzzleCreativeLevelGenerationMode::SingleLevel);
        assert_eq!(plan.levels.len(), 1);
        assert_eq!(plan.levels[0].candidate_count, 1);
        assert!(plan.levels[0].image_prompt.contains("生日蛋糕和朋友合影"));
    }

    #[test]
    fn creative_draft_builds_multi_level_picture_references() {
        let input = travel_input(
            PUZZLE_PHASE1_TEMPLATE_ID.to_string(),
            cost_range(),
            "旅行拼图",
            travel_levels(),
        );
        let draft = build_puzzle_draft_from_creative_fields(input).expect("draft should build");

        assert_eq!(draft.work_title, "旅行拼图");
        assert_eq!(draft.theme_tags, vec!["旅行", "照片", "纪念"]);
        assert_eq!(draft.levels.len(), 2);
        assert_eq!(
            draft.levels[1].picture_reference.as_deref(),
            Some("asset-2")
        );
    }

    #[test]
    fn creative_draft_accepts_catalog_subtemplate_id() {
        let travel_memory_cost_range = PuzzleCreativeCostRange {
            min_points: 4,
            max_points: 16,
            pricing_unit: PuzzleCreativePricingUnit::Point,
            reason: "按旅行节点和每关图片生成次数估算,实际扣费以后端任务结算为准".to_string(),
        };
        let input = travel_input(
            crate::creative_templates::PUZZLE_TRAVEL_MEMORY_TEMPLATE_ID.to_string(),
            travel_memory_cost_range,
            "旅行记忆",
            travel_levels(),
        );
        let draft =
            build_puzzle_draft_from_creative_fields(input).expect("subtemplate draft should build");

        assert_eq!(draft.work_title, "旅行记忆");
        assert_eq!(draft.levels.len(), 2);
    }

    #[test]
    fn draft_patch_rejects_non_whitelisted_operation() {
        let input = travel_input(
            PUZZLE_PHASE1_TEMPLATE_ID.to_string(),
            cost_range(),
            "旅行拼图",
            vec![level("第一站", "海边合影", None)],
        );
        let draft = build_puzzle_draft_from_creative_fields(input).expect("draft should build");

        let remove_title = PuzzleDraftFieldPatch {
            field_path: PuzzleCreativeDraftEditableFieldPath::WorkTitle,
            operation: PuzzleDraftFieldPatchOperation::Remove,
            level_id: None,
            value: Value::Null,
            rationale: "测试".to_string(),
        };
        let error = apply_puzzle_draft_field_patch(draft, remove_title)
            .expect_err("remove should be rejected");

        assert_eq!(error, PuzzleFieldError::InvalidOperation);
    }
}