use std::{ collections::HashMap, error::Error, fmt, sync::{Arc, Mutex}, }; use serde::{Deserialize, Serialize}; use shared_kernel::{ build_prefixed_seed_id, normalize_optional_string as normalize_shared_optional_string, normalize_required_string, normalize_string_list as normalize_shared_string_list, }; #[cfg(feature = "spacetime-types")] use spacetimedb::SpacetimeType; pub const AI_TASK_ID_PREFIX: &str = "aitask_"; pub const AI_TASK_STAGE_ID_PREFIX: &str = "aistage_"; pub const AI_RESULT_REF_ID_PREFIX: &str = "aires_"; pub const AI_TEXT_CHUNK_ID_PREFIX: &str = "aichunk_"; pub const INITIAL_AI_TASK_VERSION: u32 = 1; // AI 编排类型与当前 Node 正式运行时主链保持一致,避免后续接线时重新发明命名。 #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum AiTaskKind { StoryGeneration, CharacterChat, NpcChat, CustomWorldGeneration, QuestIntent, RuntimeItemIntent, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum AiTaskStatus { Pending, Running, Completed, Failed, Cancelled, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum AiTaskStageKind { PreparePrompt, RequestModel, RepairResponse, NormalizeResult, PersistResult, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum AiTaskStageStatus { Pending, Running, Completed, Skipped, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum AiResultReferenceKind { StorySession, StoryEvent, CustomWorldProfile, QuestRecord, RuntimeItemRecord, AssetObject, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskStageBlueprint { pub stage_kind: AiTaskStageKind, pub label: String, pub detail: String, pub order: u32, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskStageSnapshot { pub stage_kind: AiTaskStageKind, pub label: String, pub detail: String, pub order: u32, pub status: AiTaskStageStatus, pub text_output: Option, pub structured_payload_json: Option, pub warning_messages: Vec, pub started_at_micros: Option, pub completed_at_micros: Option, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskCreateInput { pub task_id: String, pub task_kind: AiTaskKind, pub owner_user_id: String, pub request_label: String, pub source_module: String, pub source_entity_id: Option, pub request_payload_json: Option, pub stages: Vec, pub created_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskStartInput { pub task_id: String, pub started_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskStageStartInput { pub task_id: String, pub stage_kind: AiTaskStageKind, pub started_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskSnapshot { pub task_id: String, pub task_kind: AiTaskKind, pub owner_user_id: String, pub request_label: String, pub source_module: String, pub source_entity_id: Option, pub request_payload_json: Option, pub status: AiTaskStatus, pub failure_message: Option, pub stages: Vec, pub result_references: Vec, pub latest_text_output: Option, pub latest_structured_payload_json: Option, pub version: u32, pub created_at_micros: i64, pub started_at_micros: Option, pub completed_at_micros: Option, pub updated_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTextChunkSnapshot { pub chunk_id: String, pub task_id: String, pub stage_kind: AiTaskStageKind, pub sequence: u32, pub delta_text: String, pub created_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTextChunkAppendInput { pub task_id: String, pub stage_kind: AiTaskStageKind, pub sequence: u32, pub delta_text: String, pub created_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiStageCompletionInput { pub task_id: String, pub stage_kind: AiTaskStageKind, pub text_output: Option, pub structured_payload_json: Option, pub warning_messages: Vec, pub completed_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiResultReferenceInput { pub task_id: String, pub reference_kind: AiResultReferenceKind, pub reference_id: String, pub label: Option, pub created_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiResultReferenceSnapshot { pub result_ref_id: String, pub task_id: String, pub reference_kind: AiResultReferenceKind, pub reference_id: String, pub label: Option, pub created_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskFinishInput { pub task_id: String, pub completed_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskCancelInput { pub task_id: String, pub completed_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskFailureInput { pub task_id: String, pub failure_message: String, pub completed_at_micros: i64, } #[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AiTaskProcedureResult { pub ok: bool, pub task: Option, pub text_chunk: Option, pub error_message: Option, } #[derive(Clone, Debug, PartialEq, Eq)] pub enum AiTaskFieldError { MissingTaskId, MissingOwnerUserId, MissingRequestLabel, MissingSourceModule, MissingStageBlueprints, DuplicateStageBlueprint, MissingReferenceId, MissingChunkText, InvalidSequence, MissingFailureMessage, MissingStage, InvalidTaskState, } #[derive(Clone, Debug, PartialEq, Eq)] pub enum AiTaskServiceError { Field(AiTaskFieldError), TaskAlreadyExists, TaskNotFound, StageNotFound, Store(String), } #[derive(Clone, Debug, Default)] pub struct InMemoryAiTaskStore { inner: Arc>, } #[derive(Debug, Default)] struct InMemoryAiTaskStoreState { tasks: HashMap, text_chunks: HashMap>, } #[derive(Clone, Debug)] pub struct AiTaskService { store: InMemoryAiTaskStore, } impl AiTaskKind { // 默认阶段蓝图只冻结通用语义,具体 prompt 内容与供应商策略仍由上层模块决定。 pub fn default_stage_blueprints(self) -> Vec { let ordered_kinds = match self { Self::StoryGeneration => vec![ AiTaskStageKind::PreparePrompt, AiTaskStageKind::RequestModel, AiTaskStageKind::RepairResponse, AiTaskStageKind::NormalizeResult, ], Self::CharacterChat | Self::NpcChat | Self::QuestIntent | Self::RuntimeItemIntent => { vec![ AiTaskStageKind::PreparePrompt, AiTaskStageKind::RequestModel, AiTaskStageKind::NormalizeResult, ] } Self::CustomWorldGeneration => vec![ AiTaskStageKind::PreparePrompt, AiTaskStageKind::RequestModel, AiTaskStageKind::RepairResponse, AiTaskStageKind::NormalizeResult, AiTaskStageKind::PersistResult, ], }; ordered_kinds .into_iter() .enumerate() .map(|(index, stage_kind)| AiTaskStageBlueprint { stage_kind, label: stage_kind.default_label().to_string(), detail: stage_kind.default_detail().to_string(), order: index as u32, }) .collect() } } impl AiTaskStageKind { pub fn as_str(self) -> &'static str { match self { Self::PreparePrompt => "prepare_prompt", Self::RequestModel => "request_model", Self::RepairResponse => "repair_response", Self::NormalizeResult => "normalize_result", Self::PersistResult => "persist_result", } } pub fn default_label(self) -> &'static str { match self { Self::PreparePrompt => "整理提示词", Self::RequestModel => "请求模型", Self::RepairResponse => "修复响应", Self::NormalizeResult => "归一结果", Self::PersistResult => "回写结果", } } pub fn default_detail(self) -> &'static str { match self { Self::PreparePrompt => "整理输入上下文并构建本轮提示词。", Self::RequestModel => "向上游模型发起正式推理请求。", Self::RepairResponse => "对非严格输出做补救修复或二次编排。", Self::NormalizeResult => "把模型输出归一成模块可消费结构。", Self::PersistResult => "把结果引用或聚合状态回写到下游模块。", } } } impl AiTaskStatus { fn is_terminal(self) -> bool { matches!(self, Self::Completed | Self::Failed | Self::Cancelled) } } impl AiTaskService { pub fn new(store: InMemoryAiTaskStore) -> Self { Self { store } } pub fn create_task( &self, input: AiTaskCreateInput, ) -> Result { validate_task_create_input(&input).map_err(AiTaskServiceError::Field)?; let snapshot = AiTaskSnapshot { task_id: input.task_id.clone(), task_kind: input.task_kind, owner_user_id: normalize_required_string(input.owner_user_id).unwrap_or_default(), request_label: normalize_required_string(input.request_label).unwrap_or_default(), source_module: normalize_required_string(input.source_module).unwrap_or_default(), source_entity_id: normalize_optional_text(input.source_entity_id), request_payload_json: normalize_optional_text(input.request_payload_json), status: AiTaskStatus::Pending, failure_message: None, stages: input .stages .into_iter() .map(|stage| AiTaskStageSnapshot { stage_kind: stage.stage_kind, label: normalize_required_string(stage.label).unwrap_or_default(), detail: normalize_required_string(stage.detail).unwrap_or_default(), order: stage.order, status: AiTaskStageStatus::Pending, text_output: None, structured_payload_json: None, warning_messages: Vec::new(), started_at_micros: None, completed_at_micros: None, }) .collect(), result_references: Vec::new(), latest_text_output: None, latest_structured_payload_json: None, version: INITIAL_AI_TASK_VERSION, created_at_micros: input.created_at_micros, started_at_micros: None, completed_at_micros: None, updated_at_micros: input.created_at_micros, }; self.store.insert_task(snapshot) } pub fn start_task( &self, task_id: &str, started_at_micros: i64, ) -> Result { self.store.update_task(task_id, |task| { if task.status.is_terminal() { return Err(AiTaskServiceError::Field( AiTaskFieldError::InvalidTaskState, )); } task.status = AiTaskStatus::Running; task.started_at_micros.get_or_insert(started_at_micros); task.updated_at_micros = started_at_micros; task.version += 1; Ok(()) }) } pub fn start_stage( &self, task_id: &str, stage_kind: AiTaskStageKind, started_at_micros: i64, ) -> Result { self.store.update_task(task_id, |task| { if task.status.is_terminal() { return Err(AiTaskServiceError::Field( AiTaskFieldError::InvalidTaskState, )); } task.status = AiTaskStatus::Running; task.started_at_micros.get_or_insert(started_at_micros); let stage = task .stages .iter_mut() .find(|stage| stage.stage_kind == stage_kind) .ok_or(AiTaskServiceError::StageNotFound)?; stage.status = AiTaskStageStatus::Running; stage.started_at_micros.get_or_insert(started_at_micros); task.updated_at_micros = started_at_micros; task.version += 1; Ok(()) }) } pub fn append_text_chunk( &self, task_id: &str, stage_kind: AiTaskStageKind, sequence: u32, delta_text: String, created_at_micros: i64, ) -> Result<(AiTaskSnapshot, AiTextChunkSnapshot), AiTaskServiceError> { if delta_text.trim().is_empty() { return Err(AiTaskServiceError::Field( AiTaskFieldError::MissingChunkText, )); } if sequence == 0 { return Err(AiTaskServiceError::Field(AiTaskFieldError::InvalidSequence)); } let chunk = AiTextChunkSnapshot { chunk_id: generate_ai_text_chunk_id(created_at_micros, sequence), task_id: normalize_required_string(task_id).unwrap_or_default(), stage_kind, sequence, delta_text: normalize_required_string(delta_text).unwrap_or_default(), created_at_micros, }; let task = self.store.append_text_chunk(chunk.clone())?; Ok((task, chunk)) } pub fn complete_stage( &self, input: AiStageCompletionInput, ) -> Result { self.store.update_task(&input.task_id, |task| { if task.status.is_terminal() { return Err(AiTaskServiceError::Field( AiTaskFieldError::InvalidTaskState, )); } let stage = task .stages .iter_mut() .find(|stage| stage.stage_kind == input.stage_kind) .ok_or(AiTaskServiceError::StageNotFound)?; stage.status = AiTaskStageStatus::Completed; stage.completed_at_micros = Some(input.completed_at_micros); stage.text_output = normalize_optional_text(input.text_output.clone()); stage.structured_payload_json = normalize_optional_text(input.structured_payload_json.clone()); stage.warning_messages = normalize_string_list(input.warning_messages.clone()); task.latest_text_output = stage.text_output.clone(); task.latest_structured_payload_json = stage.structured_payload_json.clone(); task.updated_at_micros = input.completed_at_micros; task.version += 1; Ok(()) }) } pub fn attach_result_reference( &self, task_id: &str, reference_kind: AiResultReferenceKind, reference_id: String, label: Option, created_at_micros: i64, ) -> Result { let Some(reference_id) = normalize_required_string(reference_id) else { return Err(AiTaskServiceError::Field( AiTaskFieldError::MissingReferenceId, )); }; self.store.update_task(task_id, |task| { task.result_references.push(AiResultReferenceSnapshot { result_ref_id: generate_ai_result_ref_id(created_at_micros), task_id: task.task_id.clone(), reference_kind, reference_id: reference_id.clone(), label: normalize_optional_text(label.clone()), created_at_micros, }); task.updated_at_micros = created_at_micros; task.version += 1; Ok(()) }) } pub fn complete_task( &self, task_id: &str, completed_at_micros: i64, ) -> Result { self.store.update_task(task_id, |task| { if task.status.is_terminal() { return Err(AiTaskServiceError::Field( AiTaskFieldError::InvalidTaskState, )); } task.status = AiTaskStatus::Completed; task.completed_at_micros = Some(completed_at_micros); task.updated_at_micros = completed_at_micros; task.version += 1; Ok(()) }) } pub fn fail_task( &self, task_id: &str, failure_message: String, completed_at_micros: i64, ) -> Result { let Some(failure_message) = normalize_required_string(failure_message) else { return Err(AiTaskServiceError::Field( AiTaskFieldError::MissingFailureMessage, )); }; self.store.update_task(task_id, |task| { if task.status.is_terminal() { return Err(AiTaskServiceError::Field( AiTaskFieldError::InvalidTaskState, )); } task.status = AiTaskStatus::Failed; task.failure_message = Some(failure_message.clone()); task.completed_at_micros = Some(completed_at_micros); task.updated_at_micros = completed_at_micros; task.version += 1; Ok(()) }) } pub fn cancel_task( &self, task_id: &str, completed_at_micros: i64, ) -> Result { self.store.update_task(task_id, |task| { if task.status.is_terminal() { return Err(AiTaskServiceError::Field( AiTaskFieldError::InvalidTaskState, )); } task.status = AiTaskStatus::Cancelled; task.completed_at_micros = Some(completed_at_micros); task.updated_at_micros = completed_at_micros; task.version += 1; Ok(()) }) } pub fn get_task(&self, task_id: &str) -> Result { self.store.get_task(task_id) } } impl InMemoryAiTaskStore { fn insert_task(&self, task: AiTaskSnapshot) -> Result { let mut state = self .inner .lock() .map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?; if state.tasks.contains_key(&task.task_id) { return Err(AiTaskServiceError::TaskAlreadyExists); } state.text_chunks.insert(task.task_id.clone(), Vec::new()); state.tasks.insert(task.task_id.clone(), task.clone()); Ok(task) } fn update_task( &self, task_id: &str, mut apply: F, ) -> Result where F: FnMut(&mut AiTaskSnapshot) -> Result<(), AiTaskServiceError>, { let mut state = self .inner .lock() .map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?; let task = state .tasks .get_mut(task_id.trim()) .ok_or(AiTaskServiceError::TaskNotFound)?; apply(task)?; Ok(task.clone()) } fn append_text_chunk( &self, chunk: AiTextChunkSnapshot, ) -> Result { let mut state = self .inner .lock() .map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?; { let task = state .tasks .get_mut(&chunk.task_id) .ok_or(AiTaskServiceError::TaskNotFound)?; if task.status.is_terminal() { return Err(AiTaskServiceError::Field( AiTaskFieldError::InvalidTaskState, )); } let stage = task .stages .iter_mut() .find(|stage| stage.stage_kind == chunk.stage_kind) .ok_or(AiTaskServiceError::StageNotFound)?; if stage.status == AiTaskStageStatus::Pending { stage.status = AiTaskStageStatus::Running; stage.started_at_micros = Some(chunk.created_at_micros); } task.status = AiTaskStatus::Running; task.started_at_micros .get_or_insert(chunk.created_at_micros); } let chunks = state .text_chunks .get_mut(&chunk.task_id) .ok_or(AiTaskServiceError::TaskNotFound)?; chunks.push(chunk.clone()); chunks.sort_by_key(|value| value.sequence); let aggregated_text = chunks .iter() .filter(|value| value.stage_kind == chunk.stage_kind) .map(|value| value.delta_text.as_str()) .collect::>() .join(""); let normalized_output = if aggregated_text.trim().is_empty() { None } else { Some(aggregated_text) }; let task = state .tasks .get_mut(&chunk.task_id) .ok_or(AiTaskServiceError::TaskNotFound)?; let stage = task .stages .iter_mut() .find(|stage| stage.stage_kind == chunk.stage_kind) .ok_or(AiTaskServiceError::StageNotFound)?; stage.text_output = normalized_output.clone(); task.latest_text_output = normalized_output; task.updated_at_micros = chunk.created_at_micros; task.version += 1; Ok(task.clone()) } fn get_task(&self, task_id: &str) -> Result { let state = self .inner .lock() .map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?; state .tasks .get(task_id.trim()) .cloned() .ok_or(AiTaskServiceError::TaskNotFound) } } pub fn validate_task_create_input(input: &AiTaskCreateInput) -> Result<(), AiTaskFieldError> { if normalize_required_string(&input.task_id).is_none() { return Err(AiTaskFieldError::MissingTaskId); } if normalize_required_string(&input.owner_user_id).is_none() { return Err(AiTaskFieldError::MissingOwnerUserId); } if normalize_required_string(&input.request_label).is_none() { return Err(AiTaskFieldError::MissingRequestLabel); } if normalize_required_string(&input.source_module).is_none() { return Err(AiTaskFieldError::MissingSourceModule); } if input.stages.is_empty() { return Err(AiTaskFieldError::MissingStageBlueprints); } let mut seen = HashMap::new(); for stage in &input.stages { if normalize_required_string(&stage.label).is_none() || normalize_required_string(&stage.detail).is_none() { return Err(AiTaskFieldError::MissingStageBlueprints); } if seen.insert(stage.stage_kind, true).is_some() { return Err(AiTaskFieldError::DuplicateStageBlueprint); } } Ok(()) } pub fn generate_ai_task_id(seed_micros: i64) -> String { build_prefixed_seed_id(AI_TASK_ID_PREFIX, seed_micros) } pub fn generate_ai_task_stage_id(task_id: &str, stage_kind: AiTaskStageKind) -> String { format!( "{}{}_{}", AI_TASK_STAGE_ID_PREFIX, task_id.trim(), stage_kind.as_str() ) } pub fn generate_ai_result_ref_id(seed_micros: i64) -> String { build_prefixed_seed_id(AI_RESULT_REF_ID_PREFIX, seed_micros) } pub fn generate_ai_text_chunk_id(seed_micros: i64, sequence: u32) -> String { format!("{}{seed_micros:x}_{sequence:x}", AI_TEXT_CHUNK_ID_PREFIX) } pub fn normalize_optional_text(value: Option) -> Option { normalize_shared_optional_string(value) } pub fn normalize_string_list(values: Vec) -> Vec { normalize_shared_string_list(values) } impl fmt::Display for AiTaskFieldError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::MissingTaskId => f.write_str("ai_task.task_id 不能为空"), Self::MissingOwnerUserId => f.write_str("ai_task.owner_user_id 不能为空"), Self::MissingRequestLabel => f.write_str("ai_task.request_label 不能为空"), Self::MissingSourceModule => f.write_str("ai_task.source_module 不能为空"), Self::MissingStageBlueprints => f.write_str("ai_task.stages 至少需要一个有效阶段"), Self::DuplicateStageBlueprint => f.write_str("ai_task.stages 不能包含重复阶段"), Self::MissingReferenceId => f.write_str("ai_result_reference.reference_id 不能为空"), Self::MissingChunkText => f.write_str("ai_text_chunk.delta_text 不能为空"), Self::InvalidSequence => f.write_str("ai_text_chunk.sequence 必须大于 0"), Self::MissingFailureMessage => f.write_str("ai_task.failure_message 不能为空"), Self::MissingStage => f.write_str("ai_task.stage 不存在"), Self::InvalidTaskState => f.write_str("当前 ai_task 状态不允许执行该操作"), } } } impl Error for AiTaskFieldError {} impl fmt::Display for AiTaskServiceError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Field(error) => write!(f, "{error}"), Self::TaskAlreadyExists => f.write_str("ai_task 已存在,不能重复创建"), Self::TaskNotFound => f.write_str("ai_task 不存在"), Self::StageNotFound => f.write_str("ai_task.stage 不存在"), Self::Store(message) => f.write_str(message), } } } impl Error for AiTaskServiceError {} #[cfg(test)] mod tests { use super::*; fn build_service() -> AiTaskService { AiTaskService::new(InMemoryAiTaskStore::default()) } fn build_create_input(task_kind: AiTaskKind) -> AiTaskCreateInput { AiTaskCreateInput { task_id: generate_ai_task_id(1_713_680_000_000_000), task_kind, owner_user_id: "user_001".to_string(), request_label: "首轮故事生成".to_string(), source_module: "story".to_string(), source_entity_id: Some("storysess_001".to_string()), request_payload_json: Some("{\"scene\":\"camp\"}".to_string()), stages: task_kind.default_stage_blueprints(), created_at_micros: 1_713_680_000_000_000, } } #[test] fn default_stage_blueprints_match_story_baseline() { let stages = AiTaskKind::StoryGeneration.default_stage_blueprints(); assert_eq!(stages.len(), 4); assert_eq!(stages[0].stage_kind, AiTaskStageKind::PreparePrompt); assert_eq!(stages[1].stage_kind, AiTaskStageKind::RequestModel); assert_eq!(stages[2].stage_kind, AiTaskStageKind::RepairResponse); assert_eq!(stages[3].stage_kind, AiTaskStageKind::NormalizeResult); } #[test] fn create_task_rejects_duplicate_stage_blueprints() { let mut input = build_create_input(AiTaskKind::StoryGeneration); input.stages.push(AiTaskStageBlueprint { stage_kind: AiTaskStageKind::PreparePrompt, label: "重复阶段".to_string(), detail: "重复阶段".to_string(), order: 99, }); let error = validate_task_create_input(&input).expect_err("duplicate stages should fail"); assert_eq!(error, AiTaskFieldError::DuplicateStageBlueprint); } #[test] fn generate_ai_task_stage_id_contains_task_and_stage_slug() { let stage_id = generate_ai_task_stage_id("aitask_demo", AiTaskStageKind::NormalizeResult); assert_eq!(stage_id, "aistage_aitask_demo_normalize_result"); } #[test] fn create_and_start_task_updates_status() { let service = build_service(); let created = service .create_task(build_create_input(AiTaskKind::QuestIntent)) .expect("task should create"); let started = service .start_task(&created.task_id, created.created_at_micros + 1) .expect("task should start"); assert_eq!(created.status, AiTaskStatus::Pending); assert_eq!(started.status, AiTaskStatus::Running); assert_eq!( started.started_at_micros, Some(created.created_at_micros + 1) ); assert_eq!(started.version, INITIAL_AI_TASK_VERSION + 1); } #[test] fn append_text_chunk_aggregates_stream_output_by_stage() { let service = build_service(); let task = service .create_task(build_create_input(AiTaskKind::CharacterChat)) .expect("task should create"); service .start_stage( &task.task_id, AiTaskStageKind::RequestModel, task.created_at_micros + 10, ) .expect("stage should start"); let (after_first, _) = service .append_text_chunk( &task.task_id, AiTaskStageKind::RequestModel, 1, "你".to_string(), task.created_at_micros + 20, ) .expect("first chunk should append"); let (after_second, second_chunk) = service .append_text_chunk( &task.task_id, AiTaskStageKind::RequestModel, 2, "好。".to_string(), task.created_at_micros + 30, ) .expect("second chunk should append"); assert_eq!(after_first.latest_text_output.as_deref(), Some("你")); assert_eq!(after_second.latest_text_output.as_deref(), Some("你好。")); assert_eq!(second_chunk.sequence, 2); } #[test] fn complete_stage_updates_latest_outputs() { let service = build_service(); let task = service .create_task(build_create_input(AiTaskKind::StoryGeneration)) .expect("task should create"); let completed = service .complete_stage(AiStageCompletionInput { task_id: task.task_id.clone(), stage_kind: AiTaskStageKind::NormalizeResult, text_output: Some("营地前的篝火重新亮了起来。".to_string()), structured_payload_json: Some("{\"choices\":3}".to_string()), warning_messages: vec!["使用了 fallback 选项池".to_string()], completed_at_micros: task.created_at_micros + 50, }) .expect("stage should complete"); let stage = completed .stages .iter() .find(|stage| stage.stage_kind == AiTaskStageKind::NormalizeResult) .expect("normalize stage should exist"); assert_eq!(stage.status, AiTaskStageStatus::Completed); assert_eq!( completed.latest_text_output.as_deref(), Some("营地前的篝火重新亮了起来。") ); assert_eq!( completed.latest_structured_payload_json.as_deref(), Some("{\"choices\":3}") ); assert_eq!(stage.warning_messages, vec!["使用了 fallback 选项池"]); } #[test] fn attach_result_reference_appends_binding() { let service = build_service(); let task = service .create_task(build_create_input(AiTaskKind::CustomWorldGeneration)) .expect("task should create"); let updated = service .attach_result_reference( &task.task_id, AiResultReferenceKind::CustomWorldProfile, "profile_001".to_string(), Some("主世界档案".to_string()), task.created_at_micros + 10, ) .expect("reference should attach"); assert_eq!(updated.result_references.len(), 1); assert_eq!( updated.result_references[0].reference_kind, AiResultReferenceKind::CustomWorldProfile ); assert_eq!(updated.result_references[0].reference_id, "profile_001"); } #[test] fn fail_and_cancel_task_move_into_terminal_states() { let service = build_service(); let first = service .create_task(build_create_input(AiTaskKind::NpcChat)) .expect("task should create"); let failed = service .fail_task( &first.task_id, "上游模型超时".to_string(), first.created_at_micros + 10, ) .expect("task should fail"); assert_eq!(failed.status, AiTaskStatus::Failed); assert_eq!(failed.failure_message.as_deref(), Some("上游模型超时")); let second = service .create_task(AiTaskCreateInput { task_id: generate_ai_task_id(1_713_680_000_000_999), ..build_create_input(AiTaskKind::RuntimeItemIntent) }) .expect("second task should create"); let cancelled = service .cancel_task(&second.task_id, second.created_at_micros + 20) .expect("task should cancel"); assert_eq!(cancelled.status, AiTaskStatus::Cancelled); assert_eq!( cancelled.completed_at_micros, Some(second.created_at_micros + 20) ); } #[test] fn complete_task_marks_terminal_success() { let service = build_service(); let task = service .create_task(build_create_input(AiTaskKind::QuestIntent)) .expect("task should create"); let completed = service .complete_task(&task.task_id, task.created_at_micros + 100) .expect("task should complete"); assert_eq!(completed.status, AiTaskStatus::Completed); assert_eq!( completed.completed_at_micros, Some(task.created_at_micros + 100) ); } }