Files
Genarrative/server-rs/crates/api-server/src/ai_tasks.rs

795 lines
27 KiB
Rust

use axum::{
Json,
extract::{Extension, Path, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use module_ai::{
AiResultReferenceInput, AiResultReferenceKind, AiStageCompletionInput, AiTaskCancelInput,
AiTaskCreateInput, AiTaskFailureInput, AiTaskFinishInput, AiTaskKind, AiTaskStageBlueprint,
AiTaskStageKind, AiTaskStageStartInput, AiTaskStartInput, AiTextChunkAppendInput,
generate_ai_task_id,
};
use serde_json::{Value, json};
use shared_contracts::ai::{
AiResultReferencePayload, AiTaskAcceptedResponse, AiTaskMutationResponse, AiTaskPayload,
AiTaskStagePayload, AiTextChunkPayload, AppendAiTextChunkRequest,
AttachAiResultReferenceRequest, CompleteAiStageRequest, CreateAiTaskRequest, FailAiTaskRequest,
};
use spacetime_client::{AiTaskMutationRecord, SpacetimeClientError};
use crate::{
api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
request_context::RequestContext, state::AppState,
};
pub async fn create_ai_task(
State(state): State<AppState>,
Extension(request_context): Extension<RequestContext>,
Extension(authenticated): Extension<AuthenticatedAccessToken>,
Json(payload): Json<CreateAiTaskRequest>,
) -> Result<Json<Value>, Response> {
let now_micros = current_utc_micros();
let task_kind = parse_ai_task_kind_strict(&payload.task_kind).ok_or_else(|| {
ai_tasks_error_response(
&request_context,
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
"provider": "ai-task",
"message": "taskKind 非法",
})),
)
})?;
let stages = build_stage_blueprints(task_kind, payload.stage_kinds, &request_context)?;
let owner_user_id = authenticated.claims().user_id().to_string();
let result = state
.spacetime_client()
.create_ai_task(AiTaskCreateInput {
task_id: generate_ai_task_id(now_micros),
task_kind,
owner_user_id,
request_label: payload.request_label,
source_module: payload.source_module,
source_entity_id: payload.source_entity_id,
request_payload_json: payload.request_payload_json,
stages,
created_at_micros: now_micros,
})
.await
.map_err(|error| {
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
})?;
Ok(json_success_body(
Some(&request_context),
build_ai_task_mutation_response(result),
))
}
pub async fn start_ai_task(
State(state): State<AppState>,
Path(task_id): Path<String>,
Extension(request_context): Extension<RequestContext>,
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
) -> Result<Response, Response> {
state
.spacetime_client()
.start_ai_task(AiTaskStartInput {
task_id: task_id.clone(),
started_at_micros: current_utc_micros(),
})
.await
.map_err(|error| {
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
})?;
Ok(ai_task_accepted_response(
&request_context,
AiTaskAcceptedResponse {
accepted: true,
task_id,
action: "start_task".to_string(),
stage_kind: None,
},
))
}
pub async fn start_ai_task_stage(
State(state): State<AppState>,
Path((task_id, stage_kind_text)): Path<(String, String)>,
Extension(request_context): Extension<RequestContext>,
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
) -> Result<Response, Response> {
let stage_kind = parse_ai_task_stage_kind_strict(&stage_kind_text).ok_or_else(|| {
ai_tasks_error_response(
&request_context,
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
"provider": "ai-task-stage",
"message": "stageKind 非法",
})),
)
})?;
state
.spacetime_client()
.start_ai_task_stage(AiTaskStageStartInput {
task_id: task_id.clone(),
stage_kind,
started_at_micros: current_utc_micros(),
})
.await
.map_err(|error| {
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
})?;
Ok(ai_task_accepted_response(
&request_context,
AiTaskAcceptedResponse {
accepted: true,
task_id,
action: "start_stage".to_string(),
stage_kind: Some(stage_kind.as_str().to_string()),
},
))
}
pub async fn append_ai_text_chunk(
State(state): State<AppState>,
Path(task_id): Path<String>,
Extension(request_context): Extension<RequestContext>,
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
Json(payload): Json<AppendAiTextChunkRequest>,
) -> Result<Json<Value>, Response> {
let stage_kind = parse_ai_task_stage_kind_strict(&payload.stage_kind).ok_or_else(|| {
ai_tasks_error_response(
&request_context,
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
"provider": "ai-task-stage",
"message": "stageKind 非法",
})),
)
})?;
let result = state
.spacetime_client()
.append_ai_text_chunk(AiTextChunkAppendInput {
task_id,
stage_kind,
sequence: payload.sequence,
delta_text: payload.delta_text,
created_at_micros: current_utc_micros(),
})
.await
.map_err(|error| {
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
})?;
Ok(json_success_body(
Some(&request_context),
build_ai_task_mutation_response(result),
))
}
pub async fn complete_ai_stage(
State(state): State<AppState>,
Path((task_id, stage_kind_text)): Path<(String, String)>,
Extension(request_context): Extension<RequestContext>,
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
Json(payload): Json<CompleteAiStageRequest>,
) -> Result<Json<Value>, Response> {
let stage_kind = parse_ai_task_stage_kind_strict(&stage_kind_text).ok_or_else(|| {
ai_tasks_error_response(
&request_context,
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
"provider": "ai-task-stage",
"message": "stageKind 非法",
})),
)
})?;
let result = state
.spacetime_client()
.complete_ai_stage(AiStageCompletionInput {
task_id,
stage_kind,
text_output: payload.text_output,
structured_payload_json: payload.structured_payload_json,
warning_messages: payload.warning_messages,
completed_at_micros: current_utc_micros(),
})
.await
.map_err(|error| {
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
})?;
Ok(json_success_body(
Some(&request_context),
build_ai_task_mutation_response(result),
))
}
pub async fn attach_ai_result_reference(
State(state): State<AppState>,
Path(task_id): Path<String>,
Extension(request_context): Extension<RequestContext>,
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
Json(payload): Json<AttachAiResultReferenceRequest>,
) -> Result<Json<Value>, Response> {
let reference_kind = parse_ai_result_reference_kind_strict(&payload.reference_kind)
.ok_or_else(|| {
ai_tasks_error_response(
&request_context,
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
"provider": "ai-task-reference",
"message": "referenceKind 非法",
})),
)
})?;
let result = state
.spacetime_client()
.attach_ai_result_reference(AiResultReferenceInput {
task_id,
reference_kind,
reference_id: payload.reference_id,
label: payload.label,
created_at_micros: current_utc_micros(),
})
.await
.map_err(|error| {
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
})?;
Ok(json_success_body(
Some(&request_context),
build_ai_task_mutation_response(result),
))
}
pub async fn complete_ai_task(
State(state): State<AppState>,
Path(task_id): Path<String>,
Extension(request_context): Extension<RequestContext>,
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
) -> Result<Json<Value>, Response> {
let result = state
.spacetime_client()
.complete_ai_task(AiTaskFinishInput {
task_id,
completed_at_micros: current_utc_micros(),
})
.await
.map_err(|error| {
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
})?;
Ok(json_success_body(
Some(&request_context),
build_ai_task_mutation_response(result),
))
}
pub async fn fail_ai_task(
State(state): State<AppState>,
Path(task_id): Path<String>,
Extension(request_context): Extension<RequestContext>,
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
Json(payload): Json<FailAiTaskRequest>,
) -> Result<Json<Value>, Response> {
let result = state
.spacetime_client()
.fail_ai_task(AiTaskFailureInput {
task_id,
failure_message: payload.failure_message,
completed_at_micros: current_utc_micros(),
})
.await
.map_err(|error| {
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
})?;
Ok(json_success_body(
Some(&request_context),
build_ai_task_mutation_response(result),
))
}
pub async fn cancel_ai_task(
State(state): State<AppState>,
Path(task_id): Path<String>,
Extension(request_context): Extension<RequestContext>,
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
) -> Result<Json<Value>, Response> {
let result = state
.spacetime_client()
.cancel_ai_task(AiTaskCancelInput {
task_id,
completed_at_micros: current_utc_micros(),
})
.await
.map_err(|error| {
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
})?;
Ok(json_success_body(
Some(&request_context),
build_ai_task_mutation_response(result),
))
}
/// Turns the requested stage identifiers into ordered stage blueprints.
///
/// An empty list means "use the defaults for this task kind". Otherwise each entry is
/// parsed strictly. It gets its kind's default label/detail, and its `order` is its
/// position in the request. The first unknown entry aborts with a `400` naming its index.
fn build_stage_blueprints(
    task_kind: AiTaskKind,
    stage_kinds: Vec<String>,
    request_context: &RequestContext,
) -> Result<Vec<AiTaskStageBlueprint>, Response> {
    if stage_kinds.is_empty() {
        return Ok(task_kind.default_stage_blueprints());
    }

    let mut blueprints = Vec::with_capacity(stage_kinds.len());
    for (position, raw_kind) in stage_kinds.iter().enumerate() {
        let Some(stage_kind) = parse_ai_task_stage_kind_strict(raw_kind) else {
            return Err(ai_tasks_error_response(
                request_context,
                AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                    "provider": "ai-task-stage",
                    "message": format!("stageKinds[{position}] 非法"),
                })),
            ));
        };
        blueprints.push(AiTaskStageBlueprint {
            stage_kind,
            label: stage_kind.default_label().to_string(),
            detail: stage_kind.default_detail().to_string(),
            // Request bodies are far too small for the index to exceed u32::MAX.
            order: position as u32,
        });
    }
    Ok(blueprints)
}
fn build_ai_task_mutation_response(record: AiTaskMutationRecord) -> AiTaskMutationResponse {
AiTaskMutationResponse {
ai_task: build_ai_task_payload(record.task),
ai_text_chunk: record.text_chunk.map(build_ai_text_chunk_payload),
}
}
/// Maps a client-side task record field-for-field into the API task payload.
///
/// Nested stages and result references are converted with their own builders.
fn build_ai_task_payload(record: spacetime_client::AiTaskRecord) -> AiTaskPayload {
    let spacetime_client::AiTaskRecord {
        task_id,
        task_kind,
        owner_user_id,
        request_label,
        source_module,
        source_entity_id,
        request_payload_json,
        status,
        failure_message,
        stages,
        result_references,
        latest_text_output,
        latest_structured_payload_json,
        version,
        created_at,
        started_at,
        completed_at,
        updated_at,
        ..
    } = record;

    AiTaskPayload {
        task_id,
        task_kind,
        owner_user_id,
        request_label,
        source_module,
        source_entity_id,
        request_payload_json,
        status,
        failure_message,
        stages: stages.into_iter().map(build_ai_task_stage_payload).collect(),
        result_references: result_references
            .into_iter()
            .map(build_ai_result_reference_payload)
            .collect(),
        latest_text_output,
        latest_structured_payload_json,
        version,
        created_at,
        started_at,
        completed_at,
        updated_at,
    }
}
/// Maps a client-side stage record field-for-field into the API stage payload.
fn build_ai_task_stage_payload(record: spacetime_client::AiTaskStageRecord) -> AiTaskStagePayload {
    let spacetime_client::AiTaskStageRecord {
        stage_kind,
        label,
        detail,
        order,
        status,
        text_output,
        structured_payload_json,
        warning_messages,
        started_at,
        completed_at,
        ..
    } = record;

    AiTaskStagePayload {
        stage_kind,
        label,
        detail,
        order,
        status,
        text_output,
        structured_payload_json,
        warning_messages,
        started_at,
        completed_at,
    }
}
/// Maps a client-side result-reference record field-for-field into the API payload.
fn build_ai_result_reference_payload(
    record: spacetime_client::AiResultReferenceRecord,
) -> AiResultReferencePayload {
    let spacetime_client::AiResultReferenceRecord {
        result_ref_id,
        task_id,
        reference_kind,
        reference_id,
        label,
        created_at,
        ..
    } = record;

    AiResultReferencePayload {
        result_ref_id,
        task_id,
        reference_kind,
        reference_id,
        label,
        created_at,
    }
}
/// Maps a client-side text-chunk record field-for-field into the API chunk payload.
fn build_ai_text_chunk_payload(record: spacetime_client::AiTextChunkRecord) -> AiTextChunkPayload {
    let spacetime_client::AiTextChunkRecord {
        chunk_id,
        task_id,
        stage_kind,
        sequence,
        delta_text,
        created_at,
        ..
    } = record;

    AiTextChunkPayload {
        chunk_id,
        task_id,
        stage_kind,
        sequence,
        delta_text,
        created_at,
    }
}
/// Parses a snake_case task-kind identifier, ignoring surrounding whitespace.
///
/// Returns `None` for any identifier not listed below; matching is case-sensitive.
fn parse_ai_task_kind_strict(value: &str) -> Option<AiTaskKind> {
    let kind = match value.trim() {
        "story_generation" => AiTaskKind::StoryGeneration,
        "character_chat" => AiTaskKind::CharacterChat,
        "npc_chat" => AiTaskKind::NpcChat,
        "custom_world_generation" => AiTaskKind::CustomWorldGeneration,
        "quest_intent" => AiTaskKind::QuestIntent,
        "runtime_item_intent" => AiTaskKind::RuntimeItemIntent,
        _ => return None,
    };
    Some(kind)
}
/// Parses a snake_case stage-kind identifier, ignoring surrounding whitespace.
///
/// Returns `None` for any identifier not listed below; matching is case-sensitive.
fn parse_ai_task_stage_kind_strict(value: &str) -> Option<AiTaskStageKind> {
    let kind = match value.trim() {
        "prepare_prompt" => AiTaskStageKind::PreparePrompt,
        "request_model" => AiTaskStageKind::RequestModel,
        "repair_response" => AiTaskStageKind::RepairResponse,
        "normalize_result" => AiTaskStageKind::NormalizeResult,
        "persist_result" => AiTaskStageKind::PersistResult,
        _ => return None,
    };
    Some(kind)
}
/// Parses a snake_case result-reference-kind identifier, ignoring surrounding whitespace.
///
/// Returns `None` for any identifier not listed below; matching is case-sensitive.
fn parse_ai_result_reference_kind_strict(value: &str) -> Option<AiResultReferenceKind> {
    let kind = match value.trim() {
        "story_session" => AiResultReferenceKind::StorySession,
        "story_event" => AiResultReferenceKind::StoryEvent,
        "custom_world_profile" => AiResultReferenceKind::CustomWorldProfile,
        "quest_record" => AiResultReferenceKind::QuestRecord,
        "runtime_item_record" => AiResultReferenceKind::RuntimeItemRecord,
        "asset_object" => AiResultReferenceKind::AssetObject,
        _ => return None,
    };
    Some(kind)
}
fn map_ai_task_client_error(error: SpacetimeClientError) -> AppError {
let status = match &error {
SpacetimeClientError::Runtime(_) => StatusCode::BAD_REQUEST,
_ => StatusCode::BAD_GATEWAY,
};
AppError::from_status(status).with_details(json!({
"provider": "spacetimedb",
"message": error.to_string(),
}))
}
/// Renders an `AppError` as an HTTP response carrying the current request context.
fn ai_tasks_error_response(request_context: &RequestContext, error: AppError) -> Response {
    let context = Some(request_context);
    error.into_response_with_context(context)
}
/// Wraps an accepted-action payload in the standard success body, with status
/// `202 Accepted` instead of the default.
fn ai_task_accepted_response(
    request_context: &RequestContext,
    payload: AiTaskAcceptedResponse,
) -> Response {
    // axum's `(StatusCode, R)` impl converts `R` first, then overwrites its status.
    let body = json_success_body(Some(request_context), payload);
    (StatusCode::ACCEPTED, body).into_response()
}
/// Returns the current wall-clock time as microseconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch, or if the
/// microsecond count does not fit in an `i64`.
fn current_utc_micros() -> i64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock should be after unix epoch");
    let micros = since_epoch.as_micros();
    i64::try_from(micros).expect("current unix micros should fit in i64")
}
#[cfg(test)]
mod tests {
    // Router-level tests for the AI task endpoints: authentication is required, and a
    // failing SpacetimeDB client call surfaces as 502 with provider `spacetimedb`.
    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use platform_auth::{
        AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus, sign_access_token,
    };
    use serde_json::{Value, json};
    use time::OffsetDateTime;
    use tower::ServiceExt;

    use crate::{app::build_router, config::AppConfig, state::AppState};

    /// Creating a task without an `authorization` header must be rejected with 401.
    #[tokio::test]
    async fn create_ai_task_requires_authentication() {
        let app = build_router(AppState::new(AppConfig::default()).expect("state should build"));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/ai/tasks")
                    .header("content-type", "application/json")
                    .body(Body::from(
                        json!({
                            "taskKind": "story_generation",
                            "requestLabel": "营地开场",
                            "sourceModule": "story"
                        })
                        .to_string(),
                    ))
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// An authenticated, valid create request reaches the SpacetimeDB client. With the
    /// default config that call presumably fails (module not published / unreachable in
    /// tests — confirm), which must surface as 502 with provider `spacetimedb`.
    #[tokio::test]
    async fn create_ai_task_returns_bad_gateway_when_spacetime_not_published() {
        let state = seed_authenticated_state().await;
        let token = issue_access_token(&state);
        let app = build_router(state);

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/ai/tasks")
                    .header("authorization", format!("Bearer {token}"))
                    .header("content-type", "application/json")
                    // Appears to opt into the `{ ok, error }` envelope asserted below.
                    .header("x-genarrative-response-envelope", "v1")
                    .body(Body::from(
                        json!({
                            "taskKind": "npc_chat",
                            "requestLabel": "试探问话",
                            "sourceModule": "npc"
                        })
                        .to_string(),
                    ))
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
        let body = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let payload: Value =
            serde_json::from_slice(&body).expect("response body should be valid json");
        assert_eq!(payload["ok"], Value::Bool(false));
        assert_eq!(
            payload["error"]["details"]["provider"],
            Value::String("spacetimedb".to_string())
        );
    }

    /// Starting a task without credentials must be rejected with 401.
    #[tokio::test]
    async fn start_ai_task_requires_authentication() {
        let app = build_router(AppState::new(AppConfig::default()).expect("state should build"));

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/ai/tasks/aitask_001/start")
                    .body(Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// An authenticated start request whose client call fails must surface as 502 with
    /// provider `spacetimedb` rather than the 202 used on success.
    #[tokio::test]
    async fn start_ai_task_returns_bad_gateway_when_spacetime_not_published() {
        let state = seed_authenticated_state().await;
        let token = issue_access_token(&state);
        let app = build_router(state);

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/ai/tasks/aitask_001/start")
                    .header("authorization", format!("Bearer {token}"))
                    .header("x-genarrative-response-envelope", "v1")
                    .body(Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
        let body = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let payload: Value =
            serde_json::from_slice(&body).expect("response body should be valid json");
        assert_eq!(payload["ok"], Value::Bool(false));
        assert_eq!(
            payload["error"]["details"]["provider"],
            Value::String("spacetimedb".to_string())
        );
    }

    /// Every mutation route in `ai_task_mutation_route_cases` must reject anonymous calls.
    #[tokio::test]
    async fn ai_task_mutation_routes_require_authentication() {
        let app = build_router(AppState::new(AppConfig::default()).expect("state should build"));

        for route in ai_task_mutation_route_cases() {
            let (status, _) = post_ai_task_route(app.clone(), route.uri, None, route.body).await;
            assert_eq!(status, StatusCode::UNAUTHORIZED, "{}", route.uri);
        }
    }

    /// Every mutation route, when authenticated with valid input, must surface a failing
    /// client call as 502 with provider `spacetimedb`.
    #[tokio::test]
    async fn ai_task_mutation_routes_return_bad_gateway_when_spacetime_not_published() {
        let state = seed_authenticated_state().await;
        let token = issue_access_token(&state);
        let app = build_router(state);

        for route in ai_task_mutation_route_cases() {
            let (status, payload) =
                post_ai_task_route(app.clone(), route.uri, Some(&token), route.body).await;
            assert_eq!(status, StatusCode::BAD_GATEWAY, "{}", route.uri);
            assert_eq!(
                payload["error"]["details"]["provider"],
                Value::String("spacetimedb".to_string()),
                "{}",
                route.uri
            );
        }
    }

    /// One POST route under test plus its optional JSON body.
    struct AiTaskRouteCase {
        uri: &'static str,
        body: Option<Value>,
    }

    /// Mutation routes with bodies that pass input validation (every stage/reference kind
    /// used here is accepted by the strict parsers), so any failure comes from the client.
    fn ai_task_mutation_route_cases() -> Vec<AiTaskRouteCase> {
        vec![
            AiTaskRouteCase {
                uri: "/api/ai/tasks/aitask_001/stages/request_model/start",
                body: None,
            },
            AiTaskRouteCase {
                uri: "/api/ai/tasks/aitask_001/chunks",
                body: Some(json!({
                    "stageKind": "request_model",
                    "sequence": 1,
                    "deltaText": "你听见远处的铃声。"
                })),
            },
            AiTaskRouteCase {
                uri: "/api/ai/tasks/aitask_001/stages/request_model/complete",
                body: Some(json!({
                    "textOutput": "你听见远处的铃声。",
                    "structuredPayloadJson": "{\"scene\":\"camp\"}",
                    "warningMessages": []
                })),
            },
            AiTaskRouteCase {
                uri: "/api/ai/tasks/aitask_001/references",
                body: Some(json!({
                    "referenceKind": "story_event",
                    "referenceId": "storyevt_001",
                    "label": "营地开场"
                })),
            },
            AiTaskRouteCase {
                uri: "/api/ai/tasks/aitask_001/complete",
                body: None,
            },
            AiTaskRouteCase {
                uri: "/api/ai/tasks/aitask_001/fail",
                body: Some(json!({
                    "failureMessage": "模型返回内容为空"
                })),
            },
            AiTaskRouteCase {
                uri: "/api/ai/tasks/aitask_001/cancel",
                body: None,
            },
        ]
    }

    /// POSTs `body` (as JSON when present) to `uri`, optionally with a bearer token.
    ///
    /// Returns the status and the parsed JSON body, or `Value::Null` for an empty body.
    async fn post_ai_task_route(
        app: Router,
        uri: &str,
        bearer_token: Option<&str>,
        body: Option<Value>,
    ) -> (StatusCode, Value) {
        let mut request = Request::builder()
            .method("POST")
            .uri(uri)
            .header("x-genarrative-response-envelope", "v1");
        if let Some(token) = bearer_token {
            request = request.header("authorization", format!("Bearer {token}"));
        }
        // The content-type header is only set when there is a JSON body to send.
        let body = if let Some(payload) = body {
            request = request.header("content-type", "application/json");
            Body::from(payload.to_string())
        } else {
            Body::empty()
        };

        let response = app
            .oneshot(request.body(body).expect("request should build"))
            .await
            .expect("request should succeed");
        let status = response.status();
        let body = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let payload = if body.is_empty() {
            Value::Null
        } else {
            serde_json::from_slice(&body).expect("response body should be valid json")
        };

        (status, payload)
    }

    /// Builds a default `AppState` and seeds one phone/password test user.
    async fn seed_authenticated_state() -> AppState {
        let state = AppState::new(AppConfig::default()).expect("state should build");
        // NOTE(review): `.id` is evaluated and discarded (a no-op field access that clippy's
        // `unnecessary_operation` may flag). `issue_access_token` hard-codes
        // `user_00000001`, presumably this seeded user's id — confirm, or thread the id.
        state
            .seed_test_phone_user_with_password("13800138100", "secret123")
            .await
            .id;
        state
    }

    /// Signs an access token for `user_00000001` using the state's JWT config.
    ///
    /// NOTE(review): `token_version: 2` and the session id are fixed values; presumably
    /// they satisfy the auth middleware's checks for the seeded user — confirm.
    fn issue_access_token(state: &AppState) -> String {
        let claims = AccessTokenClaims::from_input(
            AccessTokenClaimsInput {
                user_id: "user_00000001".to_string(),
                session_id: "sess_ai_tasks".to_string(),
                provider: AuthProvider::Password,
                roles: vec!["user".to_string()],
                token_version: 2,
                phone_verified: true,
                binding_status: BindingStatus::Active,
                display_name: Some("AI 任务用户".to_string()),
            },
            state.auth_jwt_config(),
            OffsetDateTime::now_utc(),
        )
        .expect("claims should build");

        sign_access_token(&claims, state.auth_jwt_config()).expect("token should sign")
    }
}