Preserve partial creation replies on stream failure
Some checks failed
CI / verify (push) Has been cancelled

This commit is contained in:
kdletters
2026-05-05 11:31:50 +08:00
parent 100fee7e7a
commit 995661e7cc
299 changed files with 13805 additions and 1429 deletions

View File

@@ -0,0 +1,307 @@
use module_square_hole::{
SQUARE_HOLE_MAX_DIFFICULTY, SQUARE_HOLE_MAX_SHAPE_COUNT, SQUARE_HOLE_MIN_DIFFICULTY,
SQUARE_HOLE_MIN_SHAPE_COUNT,
};
use platform_llm::LlmClient;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use spacetime_client::{SquareHoleAgentMessageFinalizeRecordInput, SquareHoleAgentSessionRecord};
use crate::creation_agent_llm_turn::{
CreationAgentLlmTurnErrorMessages, stream_creation_agent_json_turn,
};
use crate::prompt::square_hole::{
SQUARE_HOLE_AGENT_JSON_TURN_USER_PROMPT, SQUARE_HOLE_AGENT_SYSTEM_PROMPT,
build_square_hole_agent_prompt,
};
#[derive(Clone, Debug)]
pub(crate) struct SquareHoleAgentTurnRequest<'a> {
pub llm_client: Option<&'a LlmClient>,
pub session: &'a SquareHoleAgentSessionRecord,
pub quick_fill_requested: bool,
pub enable_web_search: bool,
}
#[derive(Clone, Debug)]
pub(crate) struct SquareHoleAgentTurnResult {
pub assistant_reply_text: String,
pub stage: String,
pub progress_percent: u32,
pub config_json: String,
pub error_message: Option<String>,
}
#[derive(Clone, Debug)]
pub(crate) struct SquareHoleAgentTurnError {
message: String,
}
impl SquareHoleAgentTurnError {
fn new(message: impl Into<String>) -> Self {
Self {
message: message.into(),
}
}
}
impl std::fmt::Display for SquareHoleAgentTurnError {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str(&self.message)
}
}
impl std::error::Error for SquareHoleAgentTurnError {}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SquareHoleAgentModelOutput {
reply_text: String,
progress_percent: u32,
next_config: SquareHoleAgentConfigOutput,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SquareHoleAgentConfigOutput {
theme_text: String,
twist_rule: String,
shape_count: u32,
difficulty: u32,
}
pub(crate) async fn run_square_hole_agent_turn<F>(
request: SquareHoleAgentTurnRequest<'_>,
on_reply_update: F,
) -> Result<SquareHoleAgentTurnResult, SquareHoleAgentTurnError>
where
F: FnMut(&str),
{
let prompt = build_square_hole_agent_prompt(request.session, request.quick_fill_requested);
let turn_output = stream_creation_agent_json_turn(
request.llm_client,
format!("{SQUARE_HOLE_AGENT_SYSTEM_PROMPT}\n\n{prompt}"),
SQUARE_HOLE_AGENT_JSON_TURN_USER_PROMPT,
request.enable_web_search,
CreationAgentLlmTurnErrorMessages {
model_unavailable: "当前模型不可用,请稍后重试。",
generation_failed: "方洞挑战聊天生成失败,请稍后重试。",
parse_failed: "方洞挑战聊天结果解析失败,请稍后重试。",
},
on_reply_update,
SquareHoleAgentTurnError::new,
)
.await?;
let output = parse_model_output(&turn_output.parsed, request.session)?;
let progress_percent = if request.quick_fill_requested {
100
} else {
output.progress_percent.min(100)
};
Ok(SquareHoleAgentTurnResult {
assistant_reply_text: output.reply_text,
stage: resolve_stage(progress_percent),
progress_percent,
config_json: serde_json::to_string(&output.next_config)
.map_err(|_| SquareHoleAgentTurnError::new("方洞挑战配置序列化失败。"))?,
error_message: None,
})
}
pub(crate) fn build_finalize_record_input(
session_id: String,
owner_user_id: String,
assistant_message_id: String,
result: SquareHoleAgentTurnResult,
updated_at_micros: i64,
) -> SquareHoleAgentMessageFinalizeRecordInput {
SquareHoleAgentMessageFinalizeRecordInput {
session_id,
owner_user_id,
assistant_message_id: Some(assistant_message_id),
assistant_reply_text: Some(result.assistant_reply_text),
config_json: Some(result.config_json),
progress_percent: result.progress_percent,
stage: result.stage,
updated_at_micros,
error_message: result.error_message,
}
}
fn parse_model_output(
parsed: &JsonValue,
session: &SquareHoleAgentSessionRecord,
) -> Result<SquareHoleAgentModelOutput, SquareHoleAgentTurnError> {
let reply_text = parsed
.get("replyText")
.and_then(JsonValue::as_str)
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or_else(|| SquareHoleAgentTurnError::new("方洞挑战聊天结果缺少有效回复,请稍后重试。"))?
.to_string();
let progress_percent = parsed
.get("progressPercent")
.and_then(JsonValue::as_u64)
.map(|value| value.min(100) as u32)
.unwrap_or(session.progress_percent);
let next_config_value = parsed
.get("nextConfig")
.ok_or_else(|| SquareHoleAgentTurnError::new("方洞挑战聊天结果缺少 nextConfig。"))?;
let next_config = parse_model_config(next_config_value, session)?;
Ok(SquareHoleAgentModelOutput {
reply_text,
progress_percent,
next_config,
})
}
fn parse_model_config(
value: &JsonValue,
session: &SquareHoleAgentSessionRecord,
) -> Result<SquareHoleAgentConfigOutput, SquareHoleAgentTurnError> {
if !value.is_object() {
return Err(SquareHoleAgentTurnError::new(
"方洞挑战聊天结果中的 nextConfig 必须是对象。",
));
}
Ok(SquareHoleAgentConfigOutput {
theme_text: read_text_field(value, "themeText")
.unwrap_or_else(|| session.config.theme_text.clone()),
twist_rule: read_text_field(value, "twistRule")
.unwrap_or_else(|| session.config.twist_rule.clone()),
shape_count: read_u32_field(value, "shapeCount")
.unwrap_or(session.config.shape_count)
.clamp(SQUARE_HOLE_MIN_SHAPE_COUNT, SQUARE_HOLE_MAX_SHAPE_COUNT),
difficulty: read_u32_field(value, "difficulty")
.unwrap_or(session.config.difficulty)
.clamp(SQUARE_HOLE_MIN_DIFFICULTY, SQUARE_HOLE_MAX_DIFFICULTY),
})
}
fn read_text_field(value: &JsonValue, field_name: &str) -> Option<String> {
value
.get(field_name)
.and_then(JsonValue::as_str)
.map(str::trim)
.filter(|text| !text.is_empty())
.map(str::to_string)
}
fn read_u32_field(value: &JsonValue, field_name: &str) -> Option<u32> {
value
.get(field_name)
.and_then(JsonValue::as_u64)
.and_then(|number| u32::try_from(number).ok())
}
fn resolve_stage(progress_percent: u32) -> String {
if progress_percent >= 100 {
"ReadyToCompile"
} else {
"Collecting"
}
.to_string()
}
#[cfg(test)]
mod tests {
use serde_json::json;
use super::{parse_model_output, resolve_stage};
fn session_record() -> spacetime_client::SquareHoleAgentSessionRecord {
spacetime_client::SquareHoleAgentSessionRecord {
session_id: "square-hole-session-test".to_string(),
current_turn: 1,
progress_percent: 25,
stage: "collecting_config".to_string(),
anchor_pack: spacetime_client::SquareHoleAnchorPackRecord {
theme: anchor("theme", "题材主题", "纸箱"),
twist_rule: anchor("twistRule", "反差规则", "方洞万能"),
shape_count: anchor("shapeCount", "形状数量", "12"),
difficulty: anchor("difficulty", "难度", "4"),
},
config: spacetime_client::SquareHoleCreatorConfigRecord {
theme_text: "纸箱".to_string(),
twist_rule: "方洞万能".to_string(),
shape_count: 12,
difficulty: 4,
},
draft: None,
messages: Vec::new(),
last_assistant_reply: None,
published_profile_id: None,
updated_at: "2026-05-04T10:00:00.000Z".to_string(),
}
}
fn anchor(key: &str, label: &str, value: &str) -> spacetime_client::SquareHoleAnchorItemRecord {
spacetime_client::SquareHoleAnchorItemRecord {
key: key.to_string(),
label: label.to_string(),
value: value.to_string(),
status: if value.is_empty() {
"missing"
} else {
"confirmed"
}
.to_string(),
}
}
#[test]
fn parse_model_output_accepts_camel_case_config_contract() {
let model_output = json!({
"replyText": "可以,把办公室文具都做成会被方洞吞进去的挑战。",
"progressPercent": 86,
"nextConfig": {
"themeText": "办公室文具",
"twistRule": "所有文具最终都优先进入方洞",
"shapeCount": 14,
"difficulty": 6
}
});
let output =
parse_model_output(&model_output, &session_record()).expect("模型输出应能解析");
assert_eq!(
output.reply_text,
"可以,把办公室文具都做成会被方洞吞进去的挑战。"
);
assert_eq!(output.progress_percent, 86);
assert_eq!(output.next_config.theme_text, "办公室文具");
assert_eq!(output.next_config.twist_rule, "所有文具最终都优先进入方洞");
assert_eq!(output.next_config.shape_count, 14);
assert_eq!(output.next_config.difficulty, 6);
}
#[test]
fn parse_model_output_clamps_numeric_config() {
let model_output = json!({
"replyText": "我先把数字压到可试玩范围里。",
"progressPercent": 120,
"nextConfig": {
"themeText": "霓虹积木",
"twistRule": "方洞优先",
"shapeCount": 99,
"difficulty": 0
}
});
let output =
parse_model_output(&model_output, &session_record()).expect("模型输出应能解析");
assert_eq!(output.progress_percent, 100);
assert_eq!(output.next_config.shape_count, 24);
assert_eq!(output.next_config.difficulty, 1);
}
#[test]
fn resolve_stage_switches_to_compile_only_at_complete_progress() {
assert_eq!(resolve_stage(99), "Collecting");
assert_eq!(resolve_stage(100), "ReadyToCompile");
}
}