Files
Genarrative/server-rs/crates/module-ai/src/tests.rs

221 lines
7.3 KiB
Rust

use super::*;
/// Creates a fresh [`AiTaskService`] whose backing store is an empty
/// [`InMemoryAiTaskStore`], so every test starts from a clean slate.
fn build_service() -> AiTaskService {
    let empty_store = InMemoryAiTaskStore::default();
    AiTaskService::new(empty_store)
}
/// Builds a baseline create request for `task_kind`.
///
/// The request is owned by `user_001`, originates from the `story` module,
/// and carries the kind's default stage blueprints. Its id and creation
/// timestamp are both derived from the same fixed microsecond value, so
/// every call produces the same task id.
fn build_create_input(task_kind: AiTaskKind) -> AiTaskCreateInput {
    let default_stages = task_kind.default_stage_blueprints();
    let payload = String::from("{\"scene\":\"camp\"}");

    AiTaskCreateInput {
        task_id: generate_ai_task_id(1_713_680_000_000_000),
        task_kind,
        owner_user_id: String::from("user_001"),
        request_label: String::from("首轮故事生成"),
        source_module: String::from("story"),
        source_entity_id: Some(String::from("storysess_001")),
        request_payload_json: Some(payload),
        stages: default_stages,
        created_at_micros: 1_713_680_000_000_000,
    }
}
/// The story-generation pipeline must expose exactly four stages, in the order
/// prepare prompt -> request model -> repair response -> normalize result.
#[test]
fn default_stage_blueprints_match_story_baseline() {
    let observed_kinds: Vec<AiTaskStageKind> = AiTaskKind::StoryGeneration
        .default_stage_blueprints()
        .into_iter()
        .map(|blueprint| blueprint.stage_kind)
        .collect();

    let expected_kinds = vec![
        AiTaskStageKind::PreparePrompt,
        AiTaskStageKind::RequestModel,
        AiTaskStageKind::RepairResponse,
        AiTaskStageKind::NormalizeResult,
    ];
    assert_eq!(observed_kinds, expected_kinds);
}
/// Appending a second `PreparePrompt` blueprint to the default story stages
/// must make input validation fail with `DuplicateStageBlueprint`.
#[test]
fn create_task_rejects_duplicate_stage_blueprints() {
    let duplicate_blueprint = AiTaskStageBlueprint {
        stage_kind: AiTaskStageKind::PreparePrompt,
        label: "重复阶段".to_string(),
        detail: "重复阶段".to_string(),
        order: 99,
    };

    let mut input = build_create_input(AiTaskKind::StoryGeneration);
    input.stages.push(duplicate_blueprint);

    let validation_error =
        validate_task_create_input(&input).expect_err("duplicate stages should fail");
    assert_eq!(validation_error, AiTaskFieldError::DuplicateStageBlueprint);
}
/// Stage ids are `aistage_` + the task id + `_` + the snake-case stage slug.
#[test]
fn generate_ai_task_stage_id_contains_task_and_stage_slug() {
    let expected_id = "aistage_aitask_demo_normalize_result";
    assert_eq!(
        generate_ai_task_stage_id("aitask_demo", AiTaskStageKind::NormalizeResult),
        expected_id
    );
}
/// A newly created task is `Pending`. Starting it moves it to `Running`,
/// records the start timestamp, and bumps the version once.
#[test]
fn create_and_start_task_updates_status() {
    let service = build_service();

    let created = service
        .create_task(build_create_input(AiTaskKind::QuestIntent))
        .expect("task should create");
    assert_eq!(created.status, AiTaskStatus::Pending);

    let start_at = created.created_at_micros + 1;
    let started = service
        .start_task(&created.task_id, start_at)
        .expect("task should start");

    assert_eq!(started.status, AiTaskStatus::Running);
    assert_eq!(started.started_at_micros, Some(start_at));
    assert_eq!(started.version, INITIAL_AI_TASK_VERSION + 1);
}
/// Streams two text chunks into the `RequestModel` stage. After each append,
/// the task's `latest_text_output` must equal the concatenation of every chunk
/// so far. Each returned chunk must also echo the sequence number it was
/// appended with.
#[test]
fn append_text_chunk_aggregates_stream_output_by_stage() {
    let service = build_service();
    let task = service
        .create_task(build_create_input(AiTaskKind::CharacterChat))
        .expect("task should create");
    service
        .start_stage(
            &task.task_id,
            AiTaskStageKind::RequestModel,
            task.created_at_micros + 10,
        )
        .expect("stage should start");

    // The first chunk must be "你": the aggregate after the second chunk
    // ("好。") is asserted to be "你好。", which only holds if this chunk
    // contributes the leading "你".
    let (after_first, first_chunk) = service
        .append_text_chunk(
            &task.task_id,
            AiTaskStageKind::RequestModel,
            1,
            "你".to_string(),
            task.created_at_micros + 20,
        )
        .expect("first chunk should append");
    let (after_second, second_chunk) = service
        .append_text_chunk(
            &task.task_id,
            AiTaskStageKind::RequestModel,
            2,
            "好。".to_string(),
            task.created_at_micros + 30,
        )
        .expect("second chunk should append");

    assert_eq!(after_first.latest_text_output.as_deref(), Some("你"));
    assert_eq!(first_chunk.sequence, 1);
    assert_eq!(after_second.latest_text_output.as_deref(), Some("你好。"));
    assert_eq!(second_chunk.sequence, 2);
}
/// Completing the `NormalizeResult` stage marks that stage `Completed` and
/// stores its warnings. It also promotes its text and structured payload to
/// the task's `latest_*` output fields.
#[test]
fn complete_stage_updates_latest_outputs() {
    let service = build_service();
    let task = service
        .create_task(build_create_input(AiTaskKind::StoryGeneration))
        .expect("task should create");

    let completion = AiStageCompletionInput {
        task_id: task.task_id.clone(),
        stage_kind: AiTaskStageKind::NormalizeResult,
        text_output: Some("营地前的篝火重新亮了起来。".to_string()),
        structured_payload_json: Some("{\"choices\":3}".to_string()),
        warning_messages: vec!["使用了 fallback 选项池".to_string()],
        completed_at_micros: task.created_at_micros + 50,
    };
    let completed = service
        .complete_stage(completion)
        .expect("stage should complete");

    let normalize_stage = completed
        .stages
        .iter()
        .find(|candidate| candidate.stage_kind == AiTaskStageKind::NormalizeResult)
        .expect("normalize stage should exist");
    assert_eq!(normalize_stage.status, AiTaskStageStatus::Completed);
    assert_eq!(
        normalize_stage.warning_messages,
        vec!["使用了 fallback 选项池"]
    );

    assert_eq!(
        completed.latest_text_output.as_deref(),
        Some("营地前的篝火重新亮了起来。")
    );
    assert_eq!(
        completed.latest_structured_payload_json.as_deref(),
        Some("{\"choices\":3}")
    );
}
/// Attaching a custom-world profile reference to a fresh task yields exactly
/// one result reference carrying the supplied kind and id.
#[test]
fn attach_result_reference_appends_binding() {
    let service = build_service();
    let task = service
        .create_task(build_create_input(AiTaskKind::CustomWorldGeneration))
        .expect("task should create");

    let attach_at = task.created_at_micros + 10;
    let updated = service
        .attach_result_reference(
            &task.task_id,
            AiResultReferenceKind::CustomWorldProfile,
            "profile_001".to_string(),
            Some("主世界档案".to_string()),
            attach_at,
        )
        .expect("reference should attach");

    assert_eq!(updated.result_references.len(), 1);
    let only_reference = &updated.result_references[0];
    assert_eq!(
        only_reference.reference_kind,
        AiResultReferenceKind::CustomWorldProfile
    );
    assert_eq!(only_reference.reference_id, "profile_001");
}
/// Covers both non-success terminal transitions on a single service.
/// Failing a task records `Failed` plus the failure message. Cancelling a
/// second task records `Cancelled` plus the completion timestamp.
#[test]
fn fail_and_cancel_task_move_into_terminal_states() {
    let service = build_service();

    // Failure path.
    let doomed = service
        .create_task(build_create_input(AiTaskKind::NpcChat))
        .expect("task should create");
    let failed = service
        .fail_task(
            &doomed.task_id,
            "上游模型超时".to_string(),
            doomed.created_at_micros + 10,
        )
        .expect("task should fail");
    assert_eq!(failed.status, AiTaskStatus::Failed);
    assert_eq!(failed.failure_message.as_deref(), Some("上游模型超时"));

    // Cancellation path. build_create_input derives the task id from a fixed
    // timestamp, so override it to give the second task its own id.
    let mut cancel_input = build_create_input(AiTaskKind::RuntimeItemIntent);
    cancel_input.task_id = generate_ai_task_id(1_713_680_000_000_999);
    let to_cancel = service
        .create_task(cancel_input)
        .expect("second task should create");

    let cancel_at = to_cancel.created_at_micros + 20;
    let cancelled = service
        .cancel_task(&to_cancel.task_id, cancel_at)
        .expect("task should cancel");
    assert_eq!(cancelled.status, AiTaskStatus::Cancelled);
    assert_eq!(cancelled.completed_at_micros, Some(cancel_at));
}
/// Completing a task moves it to `Completed` and records the supplied
/// completion timestamp.
#[test]
fn complete_task_marks_terminal_success() {
    let service = build_service();
    let task = service
        .create_task(build_create_input(AiTaskKind::QuestIntent))
        .expect("task should create");

    let finished_at = task.created_at_micros + 100;
    let completed = service
        .complete_task(&task.task_id, finished_at)
        .expect("task should complete");

    assert_eq!(completed.status, AiTaskStatus::Completed);
    assert_eq!(completed.completed_at_micros, Some(finished_at));
}