1051 lines
36 KiB
Rust
1051 lines
36 KiB
Rust
use std::{
|
|
collections::HashMap,
|
|
error::Error,
|
|
fmt,
|
|
sync::{Arc, Mutex},
|
|
};
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use shared_kernel::{
|
|
build_prefixed_seed_id, normalize_optional_string as normalize_shared_optional_string,
|
|
normalize_required_string, normalize_string_list as normalize_shared_string_list,
|
|
};
|
|
#[cfg(feature = "spacetime-types")]
|
|
use spacetimedb::SpacetimeType;
|
|
|
|
/// Prefix of ids built by [`generate_ai_task_id`].
pub const AI_TASK_ID_PREFIX: &str = "aitask_";
/// Prefix of ids built by [`generate_ai_task_stage_id`].
pub const AI_TASK_STAGE_ID_PREFIX: &str = "aistage_";
/// Prefix of ids built by [`generate_ai_result_ref_id`].
pub const AI_RESULT_REF_ID_PREFIX: &str = "aires_";
/// Prefix of ids built by [`generate_ai_text_chunk_id`].
pub const AI_TEXT_CHUNK_ID_PREFIX: &str = "aichunk_";

/// Version stamped on a freshly created task snapshot; each successful
/// mutation performed by `AiTaskService` increments it by one.
pub const INITIAL_AI_TASK_VERSION: u32 = 1;
|
|
|
|
// AI 编排类型与当前 Node 正式运行时主链保持一致,避免后续接线时重新发明命名。
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub enum AiTaskKind {
|
|
StoryGeneration,
|
|
CharacterChat,
|
|
NpcChat,
|
|
CustomWorldGeneration,
|
|
QuestIntent,
|
|
RuntimeItemIntent,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub enum AiTaskStatus {
|
|
Pending,
|
|
Running,
|
|
Completed,
|
|
Failed,
|
|
Cancelled,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
|
pub enum AiTaskStageKind {
|
|
PreparePrompt,
|
|
RequestModel,
|
|
RepairResponse,
|
|
NormalizeResult,
|
|
PersistResult,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub enum AiTaskStageStatus {
|
|
Pending,
|
|
Running,
|
|
Completed,
|
|
Skipped,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub enum AiResultReferenceKind {
|
|
StorySession,
|
|
StoryEvent,
|
|
CustomWorldProfile,
|
|
QuestRecord,
|
|
RuntimeItemRecord,
|
|
AssetObject,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskStageBlueprint {
|
|
pub stage_kind: AiTaskStageKind,
|
|
pub label: String,
|
|
pub detail: String,
|
|
pub order: u32,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskStageSnapshot {
|
|
pub stage_kind: AiTaskStageKind,
|
|
pub label: String,
|
|
pub detail: String,
|
|
pub order: u32,
|
|
pub status: AiTaskStageStatus,
|
|
pub text_output: Option<String>,
|
|
pub structured_payload_json: Option<String>,
|
|
pub warning_messages: Vec<String>,
|
|
pub started_at_micros: Option<i64>,
|
|
pub completed_at_micros: Option<i64>,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskCreateInput {
|
|
pub task_id: String,
|
|
pub task_kind: AiTaskKind,
|
|
pub owner_user_id: String,
|
|
pub request_label: String,
|
|
pub source_module: String,
|
|
pub source_entity_id: Option<String>,
|
|
pub request_payload_json: Option<String>,
|
|
pub stages: Vec<AiTaskStageBlueprint>,
|
|
pub created_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskStartInput {
|
|
pub task_id: String,
|
|
pub started_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskStageStartInput {
|
|
pub task_id: String,
|
|
pub stage_kind: AiTaskStageKind,
|
|
pub started_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskSnapshot {
|
|
pub task_id: String,
|
|
pub task_kind: AiTaskKind,
|
|
pub owner_user_id: String,
|
|
pub request_label: String,
|
|
pub source_module: String,
|
|
pub source_entity_id: Option<String>,
|
|
pub request_payload_json: Option<String>,
|
|
pub status: AiTaskStatus,
|
|
pub failure_message: Option<String>,
|
|
pub stages: Vec<AiTaskStageSnapshot>,
|
|
pub result_references: Vec<AiResultReferenceSnapshot>,
|
|
pub latest_text_output: Option<String>,
|
|
pub latest_structured_payload_json: Option<String>,
|
|
pub version: u32,
|
|
pub created_at_micros: i64,
|
|
pub started_at_micros: Option<i64>,
|
|
pub completed_at_micros: Option<i64>,
|
|
pub updated_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTextChunkSnapshot {
|
|
pub chunk_id: String,
|
|
pub task_id: String,
|
|
pub stage_kind: AiTaskStageKind,
|
|
pub sequence: u32,
|
|
pub delta_text: String,
|
|
pub created_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTextChunkAppendInput {
|
|
pub task_id: String,
|
|
pub stage_kind: AiTaskStageKind,
|
|
pub sequence: u32,
|
|
pub delta_text: String,
|
|
pub created_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiStageCompletionInput {
|
|
pub task_id: String,
|
|
pub stage_kind: AiTaskStageKind,
|
|
pub text_output: Option<String>,
|
|
pub structured_payload_json: Option<String>,
|
|
pub warning_messages: Vec<String>,
|
|
pub completed_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiResultReferenceInput {
|
|
pub task_id: String,
|
|
pub reference_kind: AiResultReferenceKind,
|
|
pub reference_id: String,
|
|
pub label: Option<String>,
|
|
pub created_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiResultReferenceSnapshot {
|
|
pub result_ref_id: String,
|
|
pub task_id: String,
|
|
pub reference_kind: AiResultReferenceKind,
|
|
pub reference_id: String,
|
|
pub label: Option<String>,
|
|
pub created_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskFinishInput {
|
|
pub task_id: String,
|
|
pub completed_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskCancelInput {
|
|
pub task_id: String,
|
|
pub completed_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskFailureInput {
|
|
pub task_id: String,
|
|
pub failure_message: String,
|
|
pub completed_at_micros: i64,
|
|
}
|
|
|
|
#[cfg_attr(feature = "spacetime-types", derive(SpacetimeType))]
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct AiTaskProcedureResult {
|
|
pub ok: bool,
|
|
pub task: Option<AiTaskSnapshot>,
|
|
pub text_chunk: Option<AiTextChunkSnapshot>,
|
|
pub error_message: Option<String>,
|
|
}
|
|
|
|
/// Validation and state errors tied to specific task fields.
///
/// Rendered through its `Display` impl.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AiTaskFieldError {
    /// `task_id` is blank.
    MissingTaskId,
    /// `owner_user_id` is blank.
    MissingOwnerUserId,
    /// `request_label` is blank.
    MissingRequestLabel,
    /// `source_module` is blank.
    MissingSourceModule,
    /// No stages were supplied, or a stage has a blank label/detail.
    MissingStageBlueprints,
    /// The same stage kind appears more than once.
    DuplicateStageBlueprint,
    /// A result reference id is blank.
    MissingReferenceId,
    /// A text chunk has no text.
    MissingChunkText,
    /// A text chunk sequence is zero.
    InvalidSequence,
    /// A failure message is blank.
    MissingFailureMessage,
    /// NOTE(review): not produced by the code in this file; stage lookups
    /// report `AiTaskServiceError::StageNotFound` instead.
    MissingStage,
    /// The task is in a terminal status and cannot be changed.
    InvalidTaskState,
}
|
|
|
|
#[derive(Clone, Debug, PartialEq, Eq)]
|
|
pub enum AiTaskServiceError {
|
|
Field(AiTaskFieldError),
|
|
TaskAlreadyExists,
|
|
TaskNotFound,
|
|
StageNotFound,
|
|
Store(String),
|
|
}
|
|
|
|
#[derive(Clone, Debug, Default)]
|
|
pub struct InMemoryAiTaskStore {
|
|
inner: Arc<Mutex<InMemoryAiTaskStoreState>>,
|
|
}
|
|
|
|
#[derive(Debug, Default)]
|
|
struct InMemoryAiTaskStoreState {
|
|
tasks: HashMap<String, AiTaskSnapshot>,
|
|
text_chunks: HashMap<String, Vec<AiTextChunkSnapshot>>,
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct AiTaskService {
|
|
store: InMemoryAiTaskStore,
|
|
}
|
|
|
|
impl AiTaskKind {
|
|
// 默认阶段蓝图只冻结通用语义,具体 prompt 内容与供应商策略仍由上层模块决定。
|
|
pub fn default_stage_blueprints(self) -> Vec<AiTaskStageBlueprint> {
|
|
let ordered_kinds = match self {
|
|
Self::StoryGeneration => vec![
|
|
AiTaskStageKind::PreparePrompt,
|
|
AiTaskStageKind::RequestModel,
|
|
AiTaskStageKind::RepairResponse,
|
|
AiTaskStageKind::NormalizeResult,
|
|
],
|
|
Self::CharacterChat | Self::NpcChat | Self::QuestIntent | Self::RuntimeItemIntent => {
|
|
vec![
|
|
AiTaskStageKind::PreparePrompt,
|
|
AiTaskStageKind::RequestModel,
|
|
AiTaskStageKind::NormalizeResult,
|
|
]
|
|
}
|
|
Self::CustomWorldGeneration => vec![
|
|
AiTaskStageKind::PreparePrompt,
|
|
AiTaskStageKind::RequestModel,
|
|
AiTaskStageKind::RepairResponse,
|
|
AiTaskStageKind::NormalizeResult,
|
|
AiTaskStageKind::PersistResult,
|
|
],
|
|
};
|
|
|
|
ordered_kinds
|
|
.into_iter()
|
|
.enumerate()
|
|
.map(|(index, stage_kind)| AiTaskStageBlueprint {
|
|
stage_kind,
|
|
label: stage_kind.default_label().to_string(),
|
|
detail: stage_kind.default_detail().to_string(),
|
|
order: index as u32,
|
|
})
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
impl AiTaskStageKind {
|
|
pub fn as_str(self) -> &'static str {
|
|
match self {
|
|
Self::PreparePrompt => "prepare_prompt",
|
|
Self::RequestModel => "request_model",
|
|
Self::RepairResponse => "repair_response",
|
|
Self::NormalizeResult => "normalize_result",
|
|
Self::PersistResult => "persist_result",
|
|
}
|
|
}
|
|
|
|
pub fn default_label(self) -> &'static str {
|
|
match self {
|
|
Self::PreparePrompt => "整理提示词",
|
|
Self::RequestModel => "请求模型",
|
|
Self::RepairResponse => "修复响应",
|
|
Self::NormalizeResult => "归一结果",
|
|
Self::PersistResult => "回写结果",
|
|
}
|
|
}
|
|
|
|
pub fn default_detail(self) -> &'static str {
|
|
match self {
|
|
Self::PreparePrompt => "整理输入上下文并构建本轮提示词。",
|
|
Self::RequestModel => "向上游模型发起正式推理请求。",
|
|
Self::RepairResponse => "对非严格输出做补救修复或二次编排。",
|
|
Self::NormalizeResult => "把模型输出归一成模块可消费结构。",
|
|
Self::PersistResult => "把结果引用或聚合状态回写到下游模块。",
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AiTaskStatus {
|
|
fn is_terminal(self) -> bool {
|
|
matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
|
|
}
|
|
}
|
|
|
|
impl AiTaskService {
|
|
pub fn new(store: InMemoryAiTaskStore) -> Self {
|
|
Self { store }
|
|
}
|
|
|
|
pub fn create_task(
|
|
&self,
|
|
input: AiTaskCreateInput,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
validate_task_create_input(&input).map_err(AiTaskServiceError::Field)?;
|
|
|
|
let snapshot = AiTaskSnapshot {
|
|
task_id: input.task_id.clone(),
|
|
task_kind: input.task_kind,
|
|
owner_user_id: normalize_required_string(input.owner_user_id).unwrap_or_default(),
|
|
request_label: normalize_required_string(input.request_label).unwrap_or_default(),
|
|
source_module: normalize_required_string(input.source_module).unwrap_or_default(),
|
|
source_entity_id: normalize_optional_text(input.source_entity_id),
|
|
request_payload_json: normalize_optional_text(input.request_payload_json),
|
|
status: AiTaskStatus::Pending,
|
|
failure_message: None,
|
|
stages: input
|
|
.stages
|
|
.into_iter()
|
|
.map(|stage| AiTaskStageSnapshot {
|
|
stage_kind: stage.stage_kind,
|
|
label: normalize_required_string(stage.label).unwrap_or_default(),
|
|
detail: normalize_required_string(stage.detail).unwrap_or_default(),
|
|
order: stage.order,
|
|
status: AiTaskStageStatus::Pending,
|
|
text_output: None,
|
|
structured_payload_json: None,
|
|
warning_messages: Vec::new(),
|
|
started_at_micros: None,
|
|
completed_at_micros: None,
|
|
})
|
|
.collect(),
|
|
result_references: Vec::new(),
|
|
latest_text_output: None,
|
|
latest_structured_payload_json: None,
|
|
version: INITIAL_AI_TASK_VERSION,
|
|
created_at_micros: input.created_at_micros,
|
|
started_at_micros: None,
|
|
completed_at_micros: None,
|
|
updated_at_micros: input.created_at_micros,
|
|
};
|
|
|
|
self.store.insert_task(snapshot)
|
|
}
|
|
|
|
pub fn start_task(
|
|
&self,
|
|
task_id: &str,
|
|
started_at_micros: i64,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
self.store.update_task(task_id, |task| {
|
|
if task.status.is_terminal() {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::InvalidTaskState,
|
|
));
|
|
}
|
|
|
|
task.status = AiTaskStatus::Running;
|
|
task.started_at_micros.get_or_insert(started_at_micros);
|
|
task.updated_at_micros = started_at_micros;
|
|
task.version += 1;
|
|
Ok(())
|
|
})
|
|
}
|
|
|
|
pub fn start_stage(
|
|
&self,
|
|
task_id: &str,
|
|
stage_kind: AiTaskStageKind,
|
|
started_at_micros: i64,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
self.store.update_task(task_id, |task| {
|
|
if task.status.is_terminal() {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::InvalidTaskState,
|
|
));
|
|
}
|
|
|
|
task.status = AiTaskStatus::Running;
|
|
task.started_at_micros.get_or_insert(started_at_micros);
|
|
let stage = task
|
|
.stages
|
|
.iter_mut()
|
|
.find(|stage| stage.stage_kind == stage_kind)
|
|
.ok_or(AiTaskServiceError::StageNotFound)?;
|
|
stage.status = AiTaskStageStatus::Running;
|
|
stage.started_at_micros.get_or_insert(started_at_micros);
|
|
task.updated_at_micros = started_at_micros;
|
|
task.version += 1;
|
|
Ok(())
|
|
})
|
|
}
|
|
|
|
pub fn append_text_chunk(
|
|
&self,
|
|
task_id: &str,
|
|
stage_kind: AiTaskStageKind,
|
|
sequence: u32,
|
|
delta_text: String,
|
|
created_at_micros: i64,
|
|
) -> Result<(AiTaskSnapshot, AiTextChunkSnapshot), AiTaskServiceError> {
|
|
if delta_text.trim().is_empty() {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::MissingChunkText,
|
|
));
|
|
}
|
|
if sequence == 0 {
|
|
return Err(AiTaskServiceError::Field(AiTaskFieldError::InvalidSequence));
|
|
}
|
|
|
|
let chunk = AiTextChunkSnapshot {
|
|
chunk_id: generate_ai_text_chunk_id(created_at_micros, sequence),
|
|
task_id: normalize_required_string(task_id).unwrap_or_default(),
|
|
stage_kind,
|
|
sequence,
|
|
delta_text: normalize_required_string(delta_text).unwrap_or_default(),
|
|
created_at_micros,
|
|
};
|
|
|
|
let task = self.store.append_text_chunk(chunk.clone())?;
|
|
Ok((task, chunk))
|
|
}
|
|
|
|
pub fn complete_stage(
|
|
&self,
|
|
input: AiStageCompletionInput,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
self.store.update_task(&input.task_id, |task| {
|
|
if task.status.is_terminal() {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::InvalidTaskState,
|
|
));
|
|
}
|
|
|
|
let stage = task
|
|
.stages
|
|
.iter_mut()
|
|
.find(|stage| stage.stage_kind == input.stage_kind)
|
|
.ok_or(AiTaskServiceError::StageNotFound)?;
|
|
stage.status = AiTaskStageStatus::Completed;
|
|
stage.completed_at_micros = Some(input.completed_at_micros);
|
|
stage.text_output = normalize_optional_text(input.text_output.clone());
|
|
stage.structured_payload_json =
|
|
normalize_optional_text(input.structured_payload_json.clone());
|
|
stage.warning_messages = normalize_string_list(input.warning_messages.clone());
|
|
|
|
task.latest_text_output = stage.text_output.clone();
|
|
task.latest_structured_payload_json = stage.structured_payload_json.clone();
|
|
task.updated_at_micros = input.completed_at_micros;
|
|
task.version += 1;
|
|
Ok(())
|
|
})
|
|
}
|
|
|
|
pub fn attach_result_reference(
|
|
&self,
|
|
task_id: &str,
|
|
reference_kind: AiResultReferenceKind,
|
|
reference_id: String,
|
|
label: Option<String>,
|
|
created_at_micros: i64,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
let Some(reference_id) = normalize_required_string(reference_id) else {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::MissingReferenceId,
|
|
));
|
|
};
|
|
|
|
self.store.update_task(task_id, |task| {
|
|
task.result_references.push(AiResultReferenceSnapshot {
|
|
result_ref_id: generate_ai_result_ref_id(created_at_micros),
|
|
task_id: task.task_id.clone(),
|
|
reference_kind,
|
|
reference_id: reference_id.clone(),
|
|
label: normalize_optional_text(label.clone()),
|
|
created_at_micros,
|
|
});
|
|
task.updated_at_micros = created_at_micros;
|
|
task.version += 1;
|
|
Ok(())
|
|
})
|
|
}
|
|
|
|
pub fn complete_task(
|
|
&self,
|
|
task_id: &str,
|
|
completed_at_micros: i64,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
self.store.update_task(task_id, |task| {
|
|
if task.status.is_terminal() {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::InvalidTaskState,
|
|
));
|
|
}
|
|
|
|
task.status = AiTaskStatus::Completed;
|
|
task.completed_at_micros = Some(completed_at_micros);
|
|
task.updated_at_micros = completed_at_micros;
|
|
task.version += 1;
|
|
Ok(())
|
|
})
|
|
}
|
|
|
|
pub fn fail_task(
|
|
&self,
|
|
task_id: &str,
|
|
failure_message: String,
|
|
completed_at_micros: i64,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
let Some(failure_message) = normalize_required_string(failure_message) else {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::MissingFailureMessage,
|
|
));
|
|
};
|
|
|
|
self.store.update_task(task_id, |task| {
|
|
if task.status.is_terminal() {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::InvalidTaskState,
|
|
));
|
|
}
|
|
|
|
task.status = AiTaskStatus::Failed;
|
|
task.failure_message = Some(failure_message.clone());
|
|
task.completed_at_micros = Some(completed_at_micros);
|
|
task.updated_at_micros = completed_at_micros;
|
|
task.version += 1;
|
|
Ok(())
|
|
})
|
|
}
|
|
|
|
pub fn cancel_task(
|
|
&self,
|
|
task_id: &str,
|
|
completed_at_micros: i64,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
self.store.update_task(task_id, |task| {
|
|
if task.status.is_terminal() {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::InvalidTaskState,
|
|
));
|
|
}
|
|
|
|
task.status = AiTaskStatus::Cancelled;
|
|
task.completed_at_micros = Some(completed_at_micros);
|
|
task.updated_at_micros = completed_at_micros;
|
|
task.version += 1;
|
|
Ok(())
|
|
})
|
|
}
|
|
|
|
pub fn get_task(&self, task_id: &str) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
self.store.get_task(task_id)
|
|
}
|
|
}
|
|
|
|
impl InMemoryAiTaskStore {
|
|
fn insert_task(&self, task: AiTaskSnapshot) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
let mut state = self
|
|
.inner
|
|
.lock()
|
|
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
|
|
|
if state.tasks.contains_key(&task.task_id) {
|
|
return Err(AiTaskServiceError::TaskAlreadyExists);
|
|
}
|
|
|
|
state.text_chunks.insert(task.task_id.clone(), Vec::new());
|
|
state.tasks.insert(task.task_id.clone(), task.clone());
|
|
Ok(task)
|
|
}
|
|
|
|
fn update_task<F>(
|
|
&self,
|
|
task_id: &str,
|
|
mut apply: F,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError>
|
|
where
|
|
F: FnMut(&mut AiTaskSnapshot) -> Result<(), AiTaskServiceError>,
|
|
{
|
|
let mut state = self
|
|
.inner
|
|
.lock()
|
|
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
|
let task = state
|
|
.tasks
|
|
.get_mut(task_id.trim())
|
|
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
|
apply(task)?;
|
|
Ok(task.clone())
|
|
}
|
|
|
|
fn append_text_chunk(
|
|
&self,
|
|
chunk: AiTextChunkSnapshot,
|
|
) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
let mut state = self
|
|
.inner
|
|
.lock()
|
|
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
|
{
|
|
let task = state
|
|
.tasks
|
|
.get_mut(&chunk.task_id)
|
|
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
|
if task.status.is_terminal() {
|
|
return Err(AiTaskServiceError::Field(
|
|
AiTaskFieldError::InvalidTaskState,
|
|
));
|
|
}
|
|
|
|
let stage = task
|
|
.stages
|
|
.iter_mut()
|
|
.find(|stage| stage.stage_kind == chunk.stage_kind)
|
|
.ok_or(AiTaskServiceError::StageNotFound)?;
|
|
if stage.status == AiTaskStageStatus::Pending {
|
|
stage.status = AiTaskStageStatus::Running;
|
|
stage.started_at_micros = Some(chunk.created_at_micros);
|
|
}
|
|
|
|
task.status = AiTaskStatus::Running;
|
|
task.started_at_micros
|
|
.get_or_insert(chunk.created_at_micros);
|
|
}
|
|
|
|
let chunks = state
|
|
.text_chunks
|
|
.get_mut(&chunk.task_id)
|
|
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
|
chunks.push(chunk.clone());
|
|
chunks.sort_by_key(|value| value.sequence);
|
|
|
|
let aggregated_text = chunks
|
|
.iter()
|
|
.filter(|value| value.stage_kind == chunk.stage_kind)
|
|
.map(|value| value.delta_text.as_str())
|
|
.collect::<Vec<_>>()
|
|
.join("");
|
|
let normalized_output = if aggregated_text.trim().is_empty() {
|
|
None
|
|
} else {
|
|
Some(aggregated_text)
|
|
};
|
|
|
|
let task = state
|
|
.tasks
|
|
.get_mut(&chunk.task_id)
|
|
.ok_or(AiTaskServiceError::TaskNotFound)?;
|
|
let stage = task
|
|
.stages
|
|
.iter_mut()
|
|
.find(|stage| stage.stage_kind == chunk.stage_kind)
|
|
.ok_or(AiTaskServiceError::StageNotFound)?;
|
|
stage.text_output = normalized_output.clone();
|
|
task.latest_text_output = normalized_output;
|
|
task.updated_at_micros = chunk.created_at_micros;
|
|
task.version += 1;
|
|
Ok(task.clone())
|
|
}
|
|
|
|
fn get_task(&self, task_id: &str) -> Result<AiTaskSnapshot, AiTaskServiceError> {
|
|
let state = self
|
|
.inner
|
|
.lock()
|
|
.map_err(|_| AiTaskServiceError::Store("AI 任务仓储锁已中毒".to_string()))?;
|
|
state
|
|
.tasks
|
|
.get(task_id.trim())
|
|
.cloned()
|
|
.ok_or(AiTaskServiceError::TaskNotFound)
|
|
}
|
|
}
|
|
|
|
pub fn validate_task_create_input(input: &AiTaskCreateInput) -> Result<(), AiTaskFieldError> {
|
|
if normalize_required_string(&input.task_id).is_none() {
|
|
return Err(AiTaskFieldError::MissingTaskId);
|
|
}
|
|
if normalize_required_string(&input.owner_user_id).is_none() {
|
|
return Err(AiTaskFieldError::MissingOwnerUserId);
|
|
}
|
|
if normalize_required_string(&input.request_label).is_none() {
|
|
return Err(AiTaskFieldError::MissingRequestLabel);
|
|
}
|
|
if normalize_required_string(&input.source_module).is_none() {
|
|
return Err(AiTaskFieldError::MissingSourceModule);
|
|
}
|
|
if input.stages.is_empty() {
|
|
return Err(AiTaskFieldError::MissingStageBlueprints);
|
|
}
|
|
|
|
let mut seen = HashMap::new();
|
|
for stage in &input.stages {
|
|
if normalize_required_string(&stage.label).is_none()
|
|
|| normalize_required_string(&stage.detail).is_none()
|
|
{
|
|
return Err(AiTaskFieldError::MissingStageBlueprints);
|
|
}
|
|
|
|
if seen.insert(stage.stage_kind, true).is_some() {
|
|
return Err(AiTaskFieldError::DuplicateStageBlueprint);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn generate_ai_task_id(seed_micros: i64) -> String {
|
|
build_prefixed_seed_id(AI_TASK_ID_PREFIX, seed_micros)
|
|
}
|
|
|
|
pub fn generate_ai_task_stage_id(task_id: &str, stage_kind: AiTaskStageKind) -> String {
|
|
format!(
|
|
"{}{}_{}",
|
|
AI_TASK_STAGE_ID_PREFIX,
|
|
task_id.trim(),
|
|
stage_kind.as_str()
|
|
)
|
|
}
|
|
|
|
pub fn generate_ai_result_ref_id(seed_micros: i64) -> String {
|
|
build_prefixed_seed_id(AI_RESULT_REF_ID_PREFIX, seed_micros)
|
|
}
|
|
|
|
pub fn generate_ai_text_chunk_id(seed_micros: i64, sequence: u32) -> String {
|
|
format!("{}{seed_micros:x}_{sequence:x}", AI_TEXT_CHUNK_ID_PREFIX)
|
|
}
|
|
|
|
pub fn normalize_optional_text(value: Option<String>) -> Option<String> {
|
|
normalize_shared_optional_string(value)
|
|
}
|
|
|
|
pub fn normalize_string_list(values: Vec<String>) -> Vec<String> {
|
|
normalize_shared_string_list(values)
|
|
}
|
|
|
|
impl fmt::Display for AiTaskFieldError {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
Self::MissingTaskId => f.write_str("ai_task.task_id 不能为空"),
|
|
Self::MissingOwnerUserId => f.write_str("ai_task.owner_user_id 不能为空"),
|
|
Self::MissingRequestLabel => f.write_str("ai_task.request_label 不能为空"),
|
|
Self::MissingSourceModule => f.write_str("ai_task.source_module 不能为空"),
|
|
Self::MissingStageBlueprints => f.write_str("ai_task.stages 至少需要一个有效阶段"),
|
|
Self::DuplicateStageBlueprint => f.write_str("ai_task.stages 不能包含重复阶段"),
|
|
Self::MissingReferenceId => f.write_str("ai_result_reference.reference_id 不能为空"),
|
|
Self::MissingChunkText => f.write_str("ai_text_chunk.delta_text 不能为空"),
|
|
Self::InvalidSequence => f.write_str("ai_text_chunk.sequence 必须大于 0"),
|
|
Self::MissingFailureMessage => f.write_str("ai_task.failure_message 不能为空"),
|
|
Self::MissingStage => f.write_str("ai_task.stage 不存在"),
|
|
Self::InvalidTaskState => f.write_str("当前 ai_task 状态不允许执行该操作"),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Error for AiTaskFieldError {}
|
|
|
|
impl fmt::Display for AiTaskServiceError {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
Self::Field(error) => write!(f, "{error}"),
|
|
Self::TaskAlreadyExists => f.write_str("ai_task 已存在,不能重复创建"),
|
|
Self::TaskNotFound => f.write_str("ai_task 不存在"),
|
|
Self::StageNotFound => f.write_str("ai_task.stage 不存在"),
|
|
Self::Store(message) => f.write_str(message),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Error for AiTaskServiceError {}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Creation timestamp shared by the fixture tasks.
    const SEED_MICROS: i64 = 1_713_680_000_000_000;

    fn build_service() -> AiTaskService {
        AiTaskService::new(InMemoryAiTaskStore::default())
    }

    fn build_create_input(task_kind: AiTaskKind) -> AiTaskCreateInput {
        AiTaskCreateInput {
            task_id: generate_ai_task_id(SEED_MICROS),
            task_kind,
            owner_user_id: "user_001".to_string(),
            request_label: "首轮故事生成".to_string(),
            source_module: "story".to_string(),
            source_entity_id: Some("storysess_001".to_string()),
            request_payload_json: Some("{\"scene\":\"camp\"}".to_string()),
            stages: task_kind.default_stage_blueprints(),
            created_at_micros: SEED_MICROS,
        }
    }

    #[test]
    fn default_stage_blueprints_match_story_baseline() {
        let stages = AiTaskKind::StoryGeneration.default_stage_blueprints();

        assert_eq!(stages.len(), 4);
        assert_eq!(stages[0].stage_kind, AiTaskStageKind::PreparePrompt);
        assert_eq!(stages[1].stage_kind, AiTaskStageKind::RequestModel);
        assert_eq!(stages[2].stage_kind, AiTaskStageKind::RepairResponse);
        assert_eq!(stages[3].stage_kind, AiTaskStageKind::NormalizeResult);
    }

    #[test]
    fn create_task_rejects_duplicate_stage_blueprints() {
        let mut input = build_create_input(AiTaskKind::StoryGeneration);
        input.stages.push(AiTaskStageBlueprint {
            stage_kind: AiTaskStageKind::PreparePrompt,
            label: "重复阶段".to_string(),
            detail: "重复阶段".to_string(),
            order: 99,
        });

        let error = validate_task_create_input(&input).expect_err("duplicate stages should fail");
        assert_eq!(error, AiTaskFieldError::DuplicateStageBlueprint);
    }

    #[test]
    fn generate_ai_task_stage_id_contains_task_and_stage_slug() {
        let stage_id = generate_ai_task_stage_id("aitask_demo", AiTaskStageKind::NormalizeResult);

        assert_eq!(stage_id, "aistage_aitask_demo_normalize_result");
    }

    #[test]
    fn create_and_start_task_updates_status() {
        let service = build_service();
        let created = service
            .create_task(build_create_input(AiTaskKind::QuestIntent))
            .expect("task should create");
        let started = service
            .start_task(&created.task_id, created.created_at_micros + 1)
            .expect("task should start");

        assert_eq!(created.status, AiTaskStatus::Pending);
        assert_eq!(started.status, AiTaskStatus::Running);
        assert_eq!(
            started.started_at_micros,
            Some(created.created_at_micros + 1)
        );
        assert_eq!(started.version, INITIAL_AI_TASK_VERSION + 1);
    }

    #[test]
    fn create_task_rejects_duplicate_task_id() {
        let service = build_service();
        service
            .create_task(build_create_input(AiTaskKind::NpcChat))
            .expect("first create should succeed");

        let error = service
            .create_task(build_create_input(AiTaskKind::NpcChat))
            .expect_err("second create with the same id should fail");
        assert_eq!(error, AiTaskServiceError::TaskAlreadyExists);
    }

    #[test]
    fn get_task_trims_lookup_id() {
        let service = build_service();
        let created = service
            .create_task(build_create_input(AiTaskKind::NpcChat))
            .expect("task should create");

        let fetched = service
            .get_task(&format!("  {}  ", created.task_id))
            .expect("padded id should resolve");
        assert_eq!(fetched, created);
        assert_eq!(
            service.get_task("aitask_missing"),
            Err(AiTaskServiceError::TaskNotFound)
        );
    }

    #[test]
    fn start_stage_reports_unknown_stage() {
        let service = build_service();
        // QuestIntent pipelines have no RepairResponse stage.
        let task = service
            .create_task(build_create_input(AiTaskKind::QuestIntent))
            .expect("task should create");

        let error = service
            .start_stage(
                &task.task_id,
                AiTaskStageKind::RepairResponse,
                task.created_at_micros + 1,
            )
            .expect_err("missing stage should fail");
        assert_eq!(error, AiTaskServiceError::StageNotFound);
    }

    #[test]
    fn append_text_chunk_aggregates_stream_output_by_stage() {
        let service = build_service();
        let task = service
            .create_task(build_create_input(AiTaskKind::CharacterChat))
            .expect("task should create");
        service
            .start_stage(
                &task.task_id,
                AiTaskStageKind::RequestModel,
                task.created_at_micros + 10,
            )
            .expect("stage should start");

        let (after_first, _) = service
            .append_text_chunk(
                &task.task_id,
                AiTaskStageKind::RequestModel,
                1,
                "你".to_string(),
                task.created_at_micros + 20,
            )
            .expect("first chunk should append");
        let (after_second, second_chunk) = service
            .append_text_chunk(
                &task.task_id,
                AiTaskStageKind::RequestModel,
                2,
                "好。".to_string(),
                task.created_at_micros + 30,
            )
            .expect("second chunk should append");

        assert_eq!(after_first.latest_text_output.as_deref(), Some("你"));
        assert_eq!(after_second.latest_text_output.as_deref(), Some("你好。"));
        assert_eq!(second_chunk.sequence, 2);
    }

    #[test]
    fn append_text_chunk_rejects_zero_sequence() {
        let service = build_service();
        let task = service
            .create_task(build_create_input(AiTaskKind::CharacterChat))
            .expect("task should create");

        let error = service
            .append_text_chunk(
                &task.task_id,
                AiTaskStageKind::RequestModel,
                0,
                "x".to_string(),
                task.created_at_micros + 1,
            )
            .expect_err("sequence 0 should fail");
        assert_eq!(
            error,
            AiTaskServiceError::Field(AiTaskFieldError::InvalidSequence)
        );
    }

    #[test]
    fn complete_stage_updates_latest_outputs() {
        let service = build_service();
        let task = service
            .create_task(build_create_input(AiTaskKind::StoryGeneration))
            .expect("task should create");

        let completed = service
            .complete_stage(AiStageCompletionInput {
                task_id: task.task_id.clone(),
                stage_kind: AiTaskStageKind::NormalizeResult,
                text_output: Some("营地前的篝火重新亮了起来。".to_string()),
                structured_payload_json: Some("{\"choices\":3}".to_string()),
                warning_messages: vec!["使用了 fallback 选项池".to_string()],
                completed_at_micros: task.created_at_micros + 50,
            })
            .expect("stage should complete");

        let stage = completed
            .stages
            .iter()
            .find(|stage| stage.stage_kind == AiTaskStageKind::NormalizeResult)
            .expect("normalize stage should exist");
        assert_eq!(stage.status, AiTaskStageStatus::Completed);
        assert_eq!(
            completed.latest_text_output.as_deref(),
            Some("营地前的篝火重新亮了起来。")
        );
        assert_eq!(
            completed.latest_structured_payload_json.as_deref(),
            Some("{\"choices\":3}")
        );
        assert_eq!(stage.warning_messages, vec!["使用了 fallback 选项池"]);
    }

    #[test]
    fn attach_result_reference_appends_binding() {
        let service = build_service();
        let task = service
            .create_task(build_create_input(AiTaskKind::CustomWorldGeneration))
            .expect("task should create");

        let updated = service
            .attach_result_reference(
                &task.task_id,
                AiResultReferenceKind::CustomWorldProfile,
                "profile_001".to_string(),
                Some("主世界档案".to_string()),
                task.created_at_micros + 10,
            )
            .expect("reference should attach");

        assert_eq!(updated.result_references.len(), 1);
        assert_eq!(
            updated.result_references[0].reference_kind,
            AiResultReferenceKind::CustomWorldProfile
        );
        assert_eq!(updated.result_references[0].reference_id, "profile_001");
    }

    #[test]
    fn fail_and_cancel_task_move_into_terminal_states() {
        let service = build_service();
        let first = service
            .create_task(build_create_input(AiTaskKind::NpcChat))
            .expect("task should create");
        let failed = service
            .fail_task(
                &first.task_id,
                "上游模型超时".to_string(),
                first.created_at_micros + 10,
            )
            .expect("task should fail");

        assert_eq!(failed.status, AiTaskStatus::Failed);
        assert_eq!(failed.failure_message.as_deref(), Some("上游模型超时"));

        let second = service
            .create_task(AiTaskCreateInput {
                task_id: generate_ai_task_id(1_713_680_000_000_999),
                ..build_create_input(AiTaskKind::RuntimeItemIntent)
            })
            .expect("second task should create");
        let cancelled = service
            .cancel_task(&second.task_id, second.created_at_micros + 20)
            .expect("task should cancel");

        assert_eq!(cancelled.status, AiTaskStatus::Cancelled);
        assert_eq!(
            cancelled.completed_at_micros,
            Some(second.created_at_micros + 20)
        );
    }

    #[test]
    fn complete_task_marks_terminal_success() {
        let service = build_service();
        let task = service
            .create_task(build_create_input(AiTaskKind::QuestIntent))
            .expect("task should create");

        let completed = service
            .complete_task(&task.task_id, task.created_at_micros + 100)
            .expect("task should complete");

        assert_eq!(completed.status, AiTaskStatus::Completed);
        assert_eq!(
            completed.completed_at_micros,
            Some(task.created_at_micros + 100)
        );
    }

    #[test]
    fn terminal_task_rejects_further_transitions() {
        let service = build_service();
        let task = service
            .create_task(build_create_input(AiTaskKind::QuestIntent))
            .expect("task should create");
        service
            .complete_task(&task.task_id, task.created_at_micros + 1)
            .expect("task should complete");

        let invalid = AiTaskServiceError::Field(AiTaskFieldError::InvalidTaskState);
        assert_eq!(
            service
                .start_task(&task.task_id, task.created_at_micros + 2)
                .expect_err("terminal task must not restart"),
            invalid
        );
        assert_eq!(
            service
                .cancel_task(&task.task_id, task.created_at_micros + 3)
                .expect_err("terminal task must not cancel"),
            invalid
        );
        assert_eq!(
            service
                .append_text_chunk(
                    &task.task_id,
                    AiTaskStageKind::RequestModel,
                    1,
                    "x".to_string(),
                    task.created_at_micros + 4,
                )
                .expect_err("terminal task must not accept chunks"),
            invalid
        );
    }
}
|