331 lines
10 KiB
Rust
331 lines
10 KiB
Rust
use crate::*;
|
|
use module_ai::{INITIAL_AI_TASK_VERSION, normalize_optional_text, validate_task_create_input};
|
|
|
|
// Row type of the `ai_task` table: one row per AI task, keyed by `task_id`.
//
// The table is declared without `public`, so clients cannot subscribe to it directly.
// Secondary btree indexes support lookups by owner, status and task kind.
// Per-stage data is not stored in this row; it is written separately via
// `replace_ai_task_stages` whenever a task is created or persisted in this file.
#[spacetimedb::table(
    accessor = ai_task,
    index(accessor = by_ai_task_owner_user_id, btree(columns = [owner_user_id])),
    index(accessor = by_ai_task_status, btree(columns = [status])),
    index(accessor = by_ai_task_kind, btree(columns = [task_kind]))
)]
pub struct AiTask {
    // Primary key. The create path stores it trimmed, and lookups trim their argument.
    #[primary_key]
    pub(crate) task_id: String,
    pub(crate) task_kind: AiTaskKind,
    // Stored trimmed on creation.
    pub(crate) owner_user_id: String,
    // Stored trimmed on creation.
    pub(crate) request_label: String,
    // Stored trimmed on creation.
    pub(crate) source_module: String,
    // Passed through `normalize_optional_text` on creation
    // (presumably blank -> None; confirm in module_ai).
    pub(crate) source_entity_id: Option<String>,
    // Raw JSON text supplied by the caller, normalized like `source_entity_id`.
    pub(crate) request_payload_json: Option<String>,
    // `Pending` on creation. The transition helpers in this file move it to
    // `Running`, `Completed`, `Failed` or `Cancelled`; the last three are terminal.
    pub(crate) status: AiTaskStatus,
    // Set (trimmed, non-empty) only when the task is failed.
    pub(crate) failure_message: Option<String>,
    // `None` on creation. NOTE(review): not written anywhere in this file;
    // presumably filled by streaming/stage code elsewhere.
    pub(crate) latest_text_output: Option<String>,
    // `None` on creation. NOTE(review): same as `latest_text_output`.
    pub(crate) latest_structured_payload_json: Option<String>,
    // Starts at `INITIAL_AI_TASK_VERSION` and is incremented by one on every
    // start/complete/fail/cancel performed here.
    pub(crate) version: u32,
    // The snapshot side tracks these as `*_micros` values. Conversion happens in
    // `build_ai_task_row` (not shown here).
    pub(crate) created_at: Timestamp,
    // Set on the first start only; later starts keep the original value.
    pub(crate) started_at: Option<Timestamp>,
    // Set when the task is completed, failed or cancelled.
    pub(crate) completed_at: Option<Timestamp>,
    // Time of the most recent create/transition.
    pub(crate) updated_at: Timestamp,
}
|
|
|
|
// AI 任务当前先固定成 private 真相表,后续由 Axum / platform-llm 再往外包一层 HTTP 与 SSE 协议。
|
|
#[spacetimedb::reducer]
|
|
pub fn create_ai_task(ctx: &ReducerContext, input: AiTaskCreateInput) -> Result<(), String> {
|
|
create_ai_task_tx(ctx, input).map(|_| ())
|
|
}
|
|
|
|
#[spacetimedb::procedure]
|
|
pub fn create_ai_task_and_return(
|
|
ctx: &mut ProcedureContext,
|
|
input: AiTaskCreateInput,
|
|
) -> AiTaskProcedureResult {
|
|
match ctx.try_with_tx(|tx| create_ai_task_tx(tx, input.clone())) {
|
|
Ok(task) => AiTaskProcedureResult {
|
|
ok: true,
|
|
task: Some(task),
|
|
text_chunk: None,
|
|
error_message: None,
|
|
},
|
|
Err(message) => AiTaskProcedureResult {
|
|
ok: false,
|
|
task: None,
|
|
text_chunk: None,
|
|
error_message: Some(message),
|
|
},
|
|
}
|
|
}
|
|
|
|
#[spacetimedb::reducer]
|
|
pub fn start_ai_task(ctx: &ReducerContext, input: AiTaskStartInput) -> Result<(), String> {
|
|
start_ai_task_tx(ctx, input).map(|_| ())
|
|
}
|
|
|
|
#[spacetimedb::procedure]
|
|
pub fn complete_ai_task_and_return(
|
|
ctx: &mut ProcedureContext,
|
|
input: AiTaskFinishInput,
|
|
) -> AiTaskProcedureResult {
|
|
match ctx.try_with_tx(|tx| complete_ai_task_tx(tx, input.clone())) {
|
|
Ok(task) => AiTaskProcedureResult {
|
|
ok: true,
|
|
task: Some(task),
|
|
text_chunk: None,
|
|
error_message: None,
|
|
},
|
|
Err(message) => AiTaskProcedureResult {
|
|
ok: false,
|
|
task: None,
|
|
text_chunk: None,
|
|
error_message: Some(message),
|
|
},
|
|
}
|
|
}
|
|
|
|
#[spacetimedb::procedure]
|
|
pub fn fail_ai_task_and_return(
|
|
ctx: &mut ProcedureContext,
|
|
input: AiTaskFailureInput,
|
|
) -> AiTaskProcedureResult {
|
|
match ctx.try_with_tx(|tx| fail_ai_task_tx(tx, input.clone())) {
|
|
Ok(task) => AiTaskProcedureResult {
|
|
ok: true,
|
|
task: Some(task),
|
|
text_chunk: None,
|
|
error_message: None,
|
|
},
|
|
Err(message) => AiTaskProcedureResult {
|
|
ok: false,
|
|
task: None,
|
|
text_chunk: None,
|
|
error_message: Some(message),
|
|
},
|
|
}
|
|
}
|
|
|
|
#[spacetimedb::procedure]
|
|
pub fn cancel_ai_task_and_return(
|
|
ctx: &mut ProcedureContext,
|
|
input: AiTaskCancelInput,
|
|
) -> AiTaskProcedureResult {
|
|
match ctx.try_with_tx(|tx| cancel_ai_task_tx(tx, input.clone())) {
|
|
Ok(task) => AiTaskProcedureResult {
|
|
ok: true,
|
|
task: Some(task),
|
|
text_chunk: None,
|
|
error_message: None,
|
|
},
|
|
Err(message) => AiTaskProcedureResult {
|
|
ok: false,
|
|
task: None,
|
|
text_chunk: None,
|
|
error_message: Some(message),
|
|
},
|
|
}
|
|
}
|
|
|
|
fn create_ai_task_tx(
|
|
ctx: &ReducerContext,
|
|
input: AiTaskCreateInput,
|
|
) -> Result<AiTaskSnapshot, String> {
|
|
validate_task_create_input(&input).map_err(|error| error.to_string())?;
|
|
|
|
if ctx.db.ai_task().task_id().find(&input.task_id).is_some() {
|
|
return Err("ai_task.task_id 已存在".to_string());
|
|
}
|
|
|
|
let task_snapshot = build_ai_task_snapshot_from_create_input(&input);
|
|
ctx.db.ai_task().insert(build_ai_task_row(&task_snapshot));
|
|
replace_ai_task_stages(ctx, &task_snapshot.task_id, &task_snapshot.stages);
|
|
emit_ai_task_event(
|
|
ctx,
|
|
&task_snapshot,
|
|
AiTaskEventKind::TaskCreated,
|
|
None,
|
|
None,
|
|
None,
|
|
task_snapshot.created_at_micros,
|
|
);
|
|
|
|
get_ai_task_snapshot_tx(ctx, &task_snapshot.task_id)
|
|
}
|
|
|
|
fn start_ai_task_tx(
|
|
ctx: &ReducerContext,
|
|
input: AiTaskStartInput,
|
|
) -> Result<AiTaskSnapshot, String> {
|
|
let mut snapshot = get_ai_task_snapshot_tx(ctx, &input.task_id)?;
|
|
ensure_ai_task_can_transition(snapshot.status)?;
|
|
|
|
snapshot.status = AiTaskStatus::Running;
|
|
if snapshot.started_at_micros.is_none() {
|
|
snapshot.started_at_micros = Some(input.started_at_micros);
|
|
}
|
|
snapshot.updated_at_micros = input.started_at_micros;
|
|
snapshot.version += 1;
|
|
|
|
persist_ai_task_snapshot(ctx, &snapshot)?;
|
|
emit_ai_task_event(
|
|
ctx,
|
|
&snapshot,
|
|
AiTaskEventKind::TaskStatusChanged,
|
|
None,
|
|
None,
|
|
None,
|
|
input.started_at_micros,
|
|
);
|
|
Ok(snapshot)
|
|
}
|
|
|
|
fn complete_ai_task_tx(
|
|
ctx: &ReducerContext,
|
|
input: AiTaskFinishInput,
|
|
) -> Result<AiTaskSnapshot, String> {
|
|
let mut snapshot = get_ai_task_snapshot_tx(ctx, &input.task_id)?;
|
|
ensure_ai_task_can_transition(snapshot.status)?;
|
|
|
|
snapshot.status = AiTaskStatus::Completed;
|
|
snapshot.completed_at_micros = Some(input.completed_at_micros);
|
|
snapshot.updated_at_micros = input.completed_at_micros;
|
|
snapshot.version += 1;
|
|
|
|
persist_ai_task_snapshot(ctx, &snapshot)?;
|
|
emit_ai_task_event(
|
|
ctx,
|
|
&snapshot,
|
|
AiTaskEventKind::TaskStatusChanged,
|
|
None,
|
|
None,
|
|
None,
|
|
input.completed_at_micros,
|
|
);
|
|
Ok(snapshot)
|
|
}
|
|
|
|
fn fail_ai_task_tx(
|
|
ctx: &ReducerContext,
|
|
input: AiTaskFailureInput,
|
|
) -> Result<AiTaskSnapshot, String> {
|
|
let failure_message = input.failure_message.trim().to_string();
|
|
if failure_message.is_empty() {
|
|
return Err("ai_task.failure_message 不能为空".to_string());
|
|
}
|
|
|
|
let mut snapshot = get_ai_task_snapshot_tx(ctx, &input.task_id)?;
|
|
ensure_ai_task_can_transition(snapshot.status)?;
|
|
|
|
snapshot.status = AiTaskStatus::Failed;
|
|
snapshot.failure_message = Some(failure_message);
|
|
snapshot.completed_at_micros = Some(input.completed_at_micros);
|
|
snapshot.updated_at_micros = input.completed_at_micros;
|
|
snapshot.version += 1;
|
|
|
|
persist_ai_task_snapshot(ctx, &snapshot)?;
|
|
emit_ai_task_event(
|
|
ctx,
|
|
&snapshot,
|
|
AiTaskEventKind::TaskStatusChanged,
|
|
None,
|
|
None,
|
|
None,
|
|
input.completed_at_micros,
|
|
);
|
|
Ok(snapshot)
|
|
}
|
|
|
|
fn cancel_ai_task_tx(
|
|
ctx: &ReducerContext,
|
|
input: AiTaskCancelInput,
|
|
) -> Result<AiTaskSnapshot, String> {
|
|
let mut snapshot = get_ai_task_snapshot_tx(ctx, &input.task_id)?;
|
|
ensure_ai_task_can_transition(snapshot.status)?;
|
|
|
|
snapshot.status = AiTaskStatus::Cancelled;
|
|
snapshot.completed_at_micros = Some(input.completed_at_micros);
|
|
snapshot.updated_at_micros = input.completed_at_micros;
|
|
snapshot.version += 1;
|
|
|
|
persist_ai_task_snapshot(ctx, &snapshot)?;
|
|
emit_ai_task_event(
|
|
ctx,
|
|
&snapshot,
|
|
AiTaskEventKind::TaskStatusChanged,
|
|
None,
|
|
None,
|
|
None,
|
|
input.completed_at_micros,
|
|
);
|
|
Ok(snapshot)
|
|
}
|
|
|
|
pub(crate) fn get_ai_task_snapshot_tx(
|
|
ctx: &ReducerContext,
|
|
task_id: &str,
|
|
) -> Result<AiTaskSnapshot, String> {
|
|
let row = ctx
|
|
.db
|
|
.ai_task()
|
|
.task_id()
|
|
.find(&task_id.trim().to_string())
|
|
.ok_or_else(|| "ai_task 不存在".to_string())?;
|
|
|
|
Ok(build_ai_task_snapshot_from_row(ctx, &row))
|
|
}
|
|
|
|
pub(crate) fn persist_ai_task_snapshot(
|
|
ctx: &ReducerContext,
|
|
snapshot: &AiTaskSnapshot,
|
|
) -> Result<(), String> {
|
|
ctx.db.ai_task().task_id().delete(&snapshot.task_id);
|
|
ctx.db.ai_task().insert(build_ai_task_row(snapshot));
|
|
replace_ai_task_stages(ctx, &snapshot.task_id, &snapshot.stages);
|
|
Ok(())
|
|
}
|
|
|
|
pub(crate) fn ensure_ai_task_can_transition(status: AiTaskStatus) -> Result<(), String> {
|
|
if matches!(
|
|
status,
|
|
AiTaskStatus::Completed | AiTaskStatus::Failed | AiTaskStatus::Cancelled
|
|
) {
|
|
Err("当前 ai_task 状态不允许执行该操作".to_string())
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
fn build_ai_task_snapshot_from_create_input(input: &AiTaskCreateInput) -> AiTaskSnapshot {
|
|
AiTaskSnapshot {
|
|
task_id: input.task_id.trim().to_string(),
|
|
task_kind: input.task_kind,
|
|
owner_user_id: input.owner_user_id.trim().to_string(),
|
|
request_label: input.request_label.trim().to_string(),
|
|
source_module: input.source_module.trim().to_string(),
|
|
source_entity_id: normalize_optional_text(input.source_entity_id.clone()),
|
|
request_payload_json: normalize_optional_text(input.request_payload_json.clone()),
|
|
status: AiTaskStatus::Pending,
|
|
failure_message: None,
|
|
stages: input
|
|
.stages
|
|
.iter()
|
|
.map(|stage| AiTaskStageSnapshot {
|
|
stage_kind: stage.stage_kind,
|
|
label: stage.label.trim().to_string(),
|
|
detail: stage.detail.trim().to_string(),
|
|
order: stage.order,
|
|
status: AiTaskStageStatus::Pending,
|
|
text_output: None,
|
|
structured_payload_json: None,
|
|
warning_messages: Vec::new(),
|
|
started_at_micros: None,
|
|
completed_at_micros: None,
|
|
})
|
|
.collect(),
|
|
result_references: Vec::new(),
|
|
latest_text_output: None,
|
|
latest_structured_payload_json: None,
|
|
version: INITIAL_AI_TASK_VERSION,
|
|
created_at_micros: input.created_at_micros,
|
|
started_at_micros: None,
|
|
completed_at_micros: None,
|
|
updated_at_micros: input.created_at_micros,
|
|
}
|
|
}
|