Files
Genarrative/server-rs/crates/spacetime-module/src/ai/tasks.rs

331 lines
10 KiB
Rust

use crate::*;
use module_ai::{INITIAL_AI_TASK_VERSION, normalize_optional_text, validate_task_create_input};
// Source-of-truth row for a single AI task, reached through the `ai_task` accessor.
// The table is declared without a `public` flag, so it is private to the module.
// Per-task stages are not stored here; they are written separately through
// `replace_ai_task_stages`.
// Secondary btree indexes support lookups by owner, by status and by task kind.
// NOTE(review): rows are produced from an `AiTaskSnapshot` by `build_ai_task_row`
// (defined elsewhere). The field notes below describe the snapshot values that this file
// writes and assume that mapping is 1:1 — TODO confirm.
#[spacetimedb::table(
    accessor = ai_task,
    index(accessor = by_ai_task_owner_user_id, btree(columns = [owner_user_id])),
    index(accessor = by_ai_task_status, btree(columns = [status])),
    index(accessor = by_ai_task_kind, btree(columns = [task_kind]))
)]
pub struct AiTask {
    // Primary key. It is trimmed when the task is created, and lookups trim their
    // argument too.
    #[primary_key]
    pub(crate) task_id: String,
    // Kind of AI task; indexed by `by_ai_task_kind`.
    pub(crate) task_kind: AiTaskKind,
    // Owning user id (trimmed on creation); indexed by `by_ai_task_owner_user_id`.
    pub(crate) owner_user_id: String,
    // Label supplied at creation (trimmed).
    pub(crate) request_label: String,
    // Module that requested the task (trimmed on creation).
    pub(crate) source_module: String,
    // Optional id of the originating entity, normalized via `normalize_optional_text`.
    pub(crate) source_entity_id: Option<String>,
    // Optional request payload, normalized via `normalize_optional_text`.
    // Presumably JSON text, going by the field name.
    pub(crate) request_payload_json: Option<String>,
    // Lifecycle status; indexed by `by_ai_task_status`. This file drives it from
    // Pending to Running, and from there to Completed, Failed or Cancelled.
    pub(crate) status: AiTaskStatus,
    // Trimmed, non-empty reason. It is set only when the task is failed.
    pub(crate) failure_message: Option<String>,
    // Most recent text output. It is `None` at creation and not written by this file.
    pub(crate) latest_text_output: Option<String>,
    // Most recent structured output. It is `None` at creation and not written by this file.
    pub(crate) latest_structured_payload_json: Option<String>,
    // Revision counter. It starts at `INITIAL_AI_TASK_VERSION` and is incremented by
    // every status transition in this file.
    pub(crate) version: u32,
    // Creation time. Snapshots carry it as micros; the conversion happens in
    // `build_ai_task_row`.
    pub(crate) created_at: Timestamp,
    // Time of the first start. Later starts do not overwrite it.
    pub(crate) started_at: Option<Timestamp>,
    // Set when the task is completed, failed or cancelled.
    pub(crate) completed_at: Option<Timestamp>,
    // Refreshed on creation and on every transition.
    pub(crate) updated_at: Timestamp,
}
// For now the AI task lives in a private source-of-truth table. Axum / platform-llm are
// expected to wrap it later with an HTTP and SSE protocol layer.

/// Reducer entry point that creates a new AI task.
///
/// Delegates to `create_ai_task_tx`. The resulting snapshot is dropped, so callers only
/// learn whether the call succeeded, or get the error string.
#[spacetimedb::reducer]
pub fn create_ai_task(ctx: &ReducerContext, input: AiTaskCreateInput) -> Result<(), String> {
    create_ai_task_tx(ctx, input).map(drop)
}
/// Procedure variant of task creation that returns the outcome as data.
///
/// Runs `create_ai_task_tx` through `ctx.try_with_tx`.
/// - On success, `ok` is `true` and `task` carries the stored snapshot.
/// - On failure, `ok` is `false` and `error_message` carries the error.
///
/// `text_chunk` is always `None`.
#[spacetimedb::procedure]
pub fn create_ai_task_and_return(
    ctx: &mut ProcedureContext,
    input: AiTaskCreateInput,
) -> AiTaskProcedureResult {
    // The closure clones `input` on each call instead of moving it in. Presumably this is
    // because the transaction closure may be invoked more than once.
    let outcome = ctx.try_with_tx(|tx| create_ai_task_tx(tx, input.clone()));
    let (task, error_message) = match outcome {
        Ok(snapshot) => (Some(snapshot), None),
        Err(reason) => (None, Some(reason)),
    };
    AiTaskProcedureResult {
        ok: error_message.is_none(),
        task,
        text_chunk: None,
        error_message,
    }
}
/// Reducer entry point that moves an existing task into `Running`.
///
/// Delegates to `start_ai_task_tx`. The updated snapshot is discarded, and only success
/// or the error string is reported.
#[spacetimedb::reducer]
pub fn start_ai_task(ctx: &ReducerContext, input: AiTaskStartInput) -> Result<(), String> {
    start_ai_task_tx(ctx, input)?;
    Ok(())
}
/// Procedure that marks a task `Completed` and returns the outcome as data.
///
/// Runs `complete_ai_task_tx` through `ctx.try_with_tx`. The updated snapshot goes in
/// `task` on success; the error string goes in `error_message` on failure. `text_chunk`
/// is never populated.
#[spacetimedb::procedure]
pub fn complete_ai_task_and_return(
    ctx: &mut ProcedureContext,
    input: AiTaskFinishInput,
) -> AiTaskProcedureResult {
    ctx.try_with_tx(|tx| complete_ai_task_tx(tx, input.clone()))
        .map(|snapshot| AiTaskProcedureResult {
            ok: true,
            task: Some(snapshot),
            text_chunk: None,
            error_message: None,
        })
        .unwrap_or_else(|reason| AiTaskProcedureResult {
            ok: false,
            task: None,
            text_chunk: None,
            error_message: Some(reason),
        })
}
/// Procedure that marks a task `Failed` and returns the outcome as data.
///
/// Runs `fail_ai_task_tx` through `ctx.try_with_tx`. Errors short-circuit into an
/// `ok: false` result carrying the message. Success returns the updated snapshot.
#[spacetimedb::procedure]
pub fn fail_ai_task_and_return(
    ctx: &mut ProcedureContext,
    input: AiTaskFailureInput,
) -> AiTaskProcedureResult {
    let snapshot = match ctx.try_with_tx(|tx| fail_ai_task_tx(tx, input.clone())) {
        Ok(updated) => updated,
        Err(reason) => {
            return AiTaskProcedureResult {
                ok: false,
                task: None,
                text_chunk: None,
                error_message: Some(reason),
            };
        }
    };
    AiTaskProcedureResult {
        ok: true,
        task: Some(snapshot),
        text_chunk: None,
        error_message: None,
    }
}
/// Procedure that marks a task `Cancelled` and returns the outcome as data.
///
/// Runs `cancel_ai_task_tx` through `ctx.try_with_tx`. `ok` mirrors whether that call
/// succeeded. Exactly one of `task` / `error_message` is set, and `text_chunk` is
/// always `None`.
#[spacetimedb::procedure]
pub fn cancel_ai_task_and_return(
    ctx: &mut ProcedureContext,
    input: AiTaskCancelInput,
) -> AiTaskProcedureResult {
    let outcome = ctx.try_with_tx(|tx| cancel_ai_task_tx(tx, input.clone()));
    let ok = outcome.is_ok();
    let (task, error_message) = match outcome {
        Ok(updated) => (Some(updated), None),
        Err(reason) => (None, Some(reason)),
    };
    AiTaskProcedureResult {
        ok,
        task,
        text_chunk: None,
        error_message,
    }
}
/// Validates `input`, inserts the new `ai_task` row and its stages, emits a
/// `TaskCreated` event, and returns the stored snapshot.
///
/// # Errors
/// Returns an error string in two cases:
/// - `validate_task_create_input` rejects the input.
/// - A task with the same trimmed `task_id` already exists.
fn create_ai_task_tx(
    ctx: &ReducerContext,
    input: AiTaskCreateInput,
) -> Result<AiTaskSnapshot, String> {
    validate_task_create_input(&input).map_err(|error| error.to_string())?;
    // Build the snapshot first so the duplicate check uses the same trimmed `task_id`
    // that becomes the primary key (lookups elsewhere trim as well). Checking the raw
    // input would let " id " slip past an existing "id" and then hit a primary-key
    // conflict on insert instead of returning this error.
    let task_snapshot = build_ai_task_snapshot_from_create_input(&input);
    if ctx
        .db
        .ai_task()
        .task_id()
        .find(&task_snapshot.task_id)
        .is_some()
    {
        return Err("ai_task.task_id 已存在".to_string());
    }
    ctx.db.ai_task().insert(build_ai_task_row(&task_snapshot));
    replace_ai_task_stages(ctx, &task_snapshot.task_id, &task_snapshot.stages);
    emit_ai_task_event(
        ctx,
        &task_snapshot,
        AiTaskEventKind::TaskCreated,
        None,
        None,
        None,
        task_snapshot.created_at_micros,
    );
    // Re-read so the returned snapshot reflects exactly what was persisted.
    get_ai_task_snapshot_tx(ctx, &task_snapshot.task_id)
}
/// Moves a non-terminal task into `AiTaskStatus::Running` and persists it.
///
/// `started_at_micros` is recorded only the first time the task is started. A repeated
/// start of an already-`Running` task still does the following:
/// - keeps the original `started_at_micros`;
/// - bumps `version`;
/// - refreshes `updated_at_micros`;
/// - emits a `TaskStatusChanged` event.
///
/// # Errors
/// Fails when the task does not exist or is already `Completed`, `Failed` or `Cancelled`.
fn start_ai_task_tx(
    ctx: &ReducerContext,
    input: AiTaskStartInput,
) -> Result<AiTaskSnapshot, String> {
    let started_at = input.started_at_micros;
    let mut snapshot = get_ai_task_snapshot_tx(ctx, &input.task_id)?;
    ensure_ai_task_can_transition(snapshot.status)?;

    snapshot.status = AiTaskStatus::Running;
    // Keep the first start time if one was already recorded.
    snapshot.started_at_micros = snapshot.started_at_micros.or(Some(started_at));
    snapshot.updated_at_micros = started_at;
    snapshot.version += 1;

    persist_ai_task_snapshot(ctx, &snapshot)?;
    emit_ai_task_event(
        ctx,
        &snapshot,
        AiTaskEventKind::TaskStatusChanged,
        None,
        None,
        None,
        started_at,
    );
    Ok(snapshot)
}
/// Marks a non-terminal task as `Completed` at `input.completed_at_micros`.
///
/// It also bumps `version`, persists the snapshot, and emits a `TaskStatusChanged` event.
///
/// # Errors
/// Fails when the task does not exist or is already in a terminal status.
fn complete_ai_task_tx(
    ctx: &ReducerContext,
    input: AiTaskFinishInput,
) -> Result<AiTaskSnapshot, String> {
    let finished_at = input.completed_at_micros;
    let mut snapshot = get_ai_task_snapshot_tx(ctx, &input.task_id)?;
    ensure_ai_task_can_transition(snapshot.status)?;

    snapshot.status = AiTaskStatus::Completed;
    snapshot.version += 1;
    snapshot.completed_at_micros = Some(finished_at);
    snapshot.updated_at_micros = finished_at;

    persist_ai_task_snapshot(ctx, &snapshot)?;
    emit_ai_task_event(
        ctx,
        &snapshot,
        AiTaskEventKind::TaskStatusChanged,
        None,
        None,
        None,
        finished_at,
    );
    Ok(snapshot)
}
/// Marks a non-terminal task as `Failed` and records a trimmed failure reason.
///
/// The reason is validated before the task is loaded. It also sets the completion and
/// update timestamps, bumps `version`, persists the snapshot, and emits a
/// `TaskStatusChanged` event.
///
/// # Errors
/// Fails in three cases:
/// - the trimmed `failure_message` is empty;
/// - the task does not exist;
/// - the task is already in a terminal status.
fn fail_ai_task_tx(
    ctx: &ReducerContext,
    input: AiTaskFailureInput,
) -> Result<AiTaskSnapshot, String> {
    let reason = input.failure_message.trim();
    if reason.is_empty() {
        return Err("ai_task.failure_message 不能为空".to_string());
    }
    let failed_at = input.completed_at_micros;
    let mut snapshot = get_ai_task_snapshot_tx(ctx, &input.task_id)?;
    ensure_ai_task_can_transition(snapshot.status)?;

    snapshot.status = AiTaskStatus::Failed;
    snapshot.failure_message = Some(reason.to_owned());
    snapshot.version += 1;
    snapshot.completed_at_micros = Some(failed_at);
    snapshot.updated_at_micros = failed_at;

    persist_ai_task_snapshot(ctx, &snapshot)?;
    emit_ai_task_event(
        ctx,
        &snapshot,
        AiTaskEventKind::TaskStatusChanged,
        None,
        None,
        None,
        failed_at,
    );
    Ok(snapshot)
}
/// Marks a non-terminal task as `Cancelled` at `input.completed_at_micros`.
///
/// It also bumps `version`, persists the snapshot, and emits a `TaskStatusChanged` event.
///
/// # Errors
/// Fails when the task does not exist or is already in a terminal status.
fn cancel_ai_task_tx(
    ctx: &ReducerContext,
    input: AiTaskCancelInput,
) -> Result<AiTaskSnapshot, String> {
    let cancelled_at = input.completed_at_micros;
    let mut snapshot = get_ai_task_snapshot_tx(ctx, &input.task_id)?;
    ensure_ai_task_can_transition(snapshot.status)?;

    snapshot.status = AiTaskStatus::Cancelled;
    snapshot.version += 1;
    snapshot.completed_at_micros = Some(cancelled_at);
    snapshot.updated_at_micros = cancelled_at;

    persist_ai_task_snapshot(ctx, &snapshot)?;
    emit_ai_task_event(
        ctx,
        &snapshot,
        AiTaskEventKind::TaskStatusChanged,
        None,
        None,
        None,
        cancelled_at,
    );
    Ok(snapshot)
}
/// Loads the task whose primary key equals the trimmed `task_id` and returns it as a
/// snapshot built by `build_ai_task_snapshot_from_row`.
///
/// # Errors
/// Returns `"ai_task 不存在"` when no such task exists.
pub(crate) fn get_ai_task_snapshot_tx(
    ctx: &ReducerContext,
    task_id: &str,
) -> Result<AiTaskSnapshot, String> {
    let normalized_id = task_id.trim().to_string();
    match ctx.db.ai_task().task_id().find(&normalized_id) {
        Some(row) => Ok(build_ai_task_snapshot_from_row(ctx, &row)),
        None => Err("ai_task 不存在".to_string()),
    }
}
/// Writes `snapshot` back to storage.
///
/// The `ai_task` row keyed by `snapshot.task_id` is deleted and re-inserted from
/// `build_ai_task_row`, rather than updated in place. The task's stages are then
/// replaced via `replace_ai_task_stages`.
///
/// Always returns `Ok(())`.
pub(crate) fn persist_ai_task_snapshot(
    ctx: &ReducerContext,
    snapshot: &AiTaskSnapshot,
) -> Result<(), String> {
    let task_id = &snapshot.task_id;
    let replacement_row = build_ai_task_row(snapshot);
    ctx.db.ai_task().task_id().delete(task_id);
    ctx.db.ai_task().insert(replacement_row);
    replace_ai_task_stages(ctx, task_id, &snapshot.stages);
    Ok(())
}
/// Guards state transitions by rejecting tasks that have already reached a terminal
/// status.
///
/// The terminal statuses are `Completed`, `Failed` and `Cancelled`. Every other status,
/// including `Running`, is allowed through.
///
/// # Errors
/// Returns an error string when `status` is terminal.
pub(crate) fn ensure_ai_task_can_transition(status: AiTaskStatus) -> Result<(), String> {
    match status {
        AiTaskStatus::Completed | AiTaskStatus::Failed | AiTaskStatus::Cancelled => {
            Err("当前 ai_task 状态不允许执行该操作".to_string())
        }
        _ => Ok(()),
    }
}
/// Builds the initial `Pending` snapshot for a newly created task from `input`.
///
/// - Required text fields (ids, label, module, and stage label/detail) are trimmed.
/// - Optional text fields go through `normalize_optional_text`.
/// - Every stage starts `Pending`, with no output, warnings or timestamps.
/// - The snapshot starts at `INITIAL_AI_TASK_VERSION`.
/// - `updated_at_micros` equals `created_at_micros`.
fn build_ai_task_snapshot_from_create_input(input: &AiTaskCreateInput) -> AiTaskSnapshot {
    let stages = input
        .stages
        .iter()
        .map(|stage| AiTaskStageSnapshot {
            stage_kind: stage.stage_kind,
            order: stage.order,
            label: stage.label.trim().to_owned(),
            detail: stage.detail.trim().to_owned(),
            status: AiTaskStageStatus::Pending,
            text_output: None,
            structured_payload_json: None,
            warning_messages: Vec::new(),
            started_at_micros: None,
            completed_at_micros: None,
        })
        .collect();
    let created_at = input.created_at_micros;

    AiTaskSnapshot {
        task_id: input.task_id.trim().to_owned(),
        task_kind: input.task_kind,
        owner_user_id: input.owner_user_id.trim().to_owned(),
        request_label: input.request_label.trim().to_owned(),
        source_module: input.source_module.trim().to_owned(),
        source_entity_id: normalize_optional_text(input.source_entity_id.clone()),
        request_payload_json: normalize_optional_text(input.request_payload_json.clone()),
        status: AiTaskStatus::Pending,
        failure_message: None,
        stages,
        result_references: Vec::new(),
        latest_text_output: None,
        latest_structured_payload_json: None,
        version: INITIAL_AI_TASK_VERSION,
        created_at_micros: created_at,
        started_at_micros: None,
        completed_at_micros: None,
        updated_at_micros: created_at,
    }
}