use axum::{ Json, extract::{Extension, Path, State}, http::StatusCode, response::{IntoResponse, Response}, }; use module_ai::{ AiResultReferenceInput, AiResultReferenceKind, AiStageCompletionInput, AiTaskCancelInput, AiTaskCreateInput, AiTaskFailureInput, AiTaskFinishInput, AiTaskKind, AiTaskStageBlueprint, AiTaskStageKind, AiTaskStageStartInput, AiTaskStartInput, AiTextChunkAppendInput, generate_ai_task_id, }; use serde_json::{Value, json}; use shared_contracts::ai::{ AiResultReferencePayload, AiTaskAcceptedResponse, AiTaskMutationResponse, AiTaskPayload, AiTaskStagePayload, AiTextChunkPayload, AppendAiTextChunkRequest, AttachAiResultReferenceRequest, CompleteAiStageRequest, CreateAiTaskRequest, FailAiTaskRequest, }; use spacetime_client::{AiTaskMutationRecord, SpacetimeClientError}; use crate::{ api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError, request_context::RequestContext, state::AppState, }; pub async fn create_ai_task( State(state): State, Extension(request_context): Extension, Extension(authenticated): Extension, Json(payload): Json, ) -> Result, Response> { let now_micros = current_utc_micros(); let task_kind = parse_ai_task_kind_strict(&payload.task_kind).ok_or_else(|| { ai_tasks_error_response( &request_context, AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({ "provider": "ai-task", "message": "taskKind 非法", })), ) })?; let stages = build_stage_blueprints(task_kind, payload.stage_kinds, &request_context)?; let owner_user_id = authenticated.claims().user_id().to_string(); let result = state .spacetime_client() .create_ai_task(AiTaskCreateInput { task_id: generate_ai_task_id(now_micros), task_kind, owner_user_id, request_label: payload.request_label, source_module: payload.source_module, source_entity_id: payload.source_entity_id, request_payload_json: payload.request_payload_json, stages, created_at_micros: now_micros, }) .await .map_err(|error| { ai_tasks_error_response(&request_context, map_ai_task_client_error(error)) })?; Ok(json_success_body( Some(&request_context), build_ai_task_mutation_response(result), )) } pub async fn start_ai_task( State(state): State, Path(task_id): Path, Extension(request_context): Extension, Extension(_authenticated): Extension, ) -> Result { state .spacetime_client() .start_ai_task(AiTaskStartInput { task_id: task_id.clone(), started_at_micros: current_utc_micros(), }) .await .map_err(|error| { ai_tasks_error_response(&request_context, map_ai_task_client_error(error)) })?; Ok(ai_task_accepted_response( &request_context, AiTaskAcceptedResponse { accepted: true, task_id, action: "start_task".to_string(), stage_kind: None, }, )) } pub async fn start_ai_task_stage( State(state): State, Path((task_id, stage_kind_text)): Path<(String, String)>, Extension(request_context): Extension, Extension(_authenticated): Extension, ) -> Result { let stage_kind = parse_ai_task_stage_kind_strict(&stage_kind_text).ok_or_else(|| { ai_tasks_error_response( &request_context, AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({ "provider": "ai-task-stage", "message": "stageKind 非法", })), ) })?; state .spacetime_client() .start_ai_task_stage(AiTaskStageStartInput { task_id: task_id.clone(), stage_kind, started_at_micros: current_utc_micros(), }) .await .map_err(|error| { ai_tasks_error_response(&request_context, map_ai_task_client_error(error)) })?; Ok(ai_task_accepted_response( &request_context, AiTaskAcceptedResponse { accepted: true, task_id, action: "start_stage".to_string(), stage_kind: Some(stage_kind.as_str().to_string()), }, )) } pub async fn append_ai_text_chunk( State(state): State, Path(task_id): Path, Extension(request_context): Extension, Extension(_authenticated): Extension, Json(payload): Json, ) -> Result, Response> { let stage_kind = parse_ai_task_stage_kind_strict(&payload.stage_kind).ok_or_else(|| { ai_tasks_error_response( &request_context, AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({ "provider": "ai-task-stage", "message": "stageKind 非法", })), ) })?; let result = state .spacetime_client() .append_ai_text_chunk(AiTextChunkAppendInput { task_id, stage_kind, sequence: payload.sequence, delta_text: payload.delta_text, created_at_micros: current_utc_micros(), }) .await .map_err(|error| { ai_tasks_error_response(&request_context, map_ai_task_client_error(error)) })?; Ok(json_success_body( Some(&request_context), build_ai_task_mutation_response(result), )) } pub async fn complete_ai_stage( State(state): State, Path((task_id, stage_kind_text)): Path<(String, String)>, Extension(request_context): Extension, Extension(_authenticated): Extension, Json(payload): Json, ) -> Result, Response> { let stage_kind = parse_ai_task_stage_kind_strict(&stage_kind_text).ok_or_else(|| { ai_tasks_error_response( &request_context, AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({ "provider": "ai-task-stage", "message": "stageKind 非法", })), ) })?; let result = state .spacetime_client() .complete_ai_stage(AiStageCompletionInput { task_id, stage_kind, text_output: payload.text_output, structured_payload_json: payload.structured_payload_json, warning_messages: payload.warning_messages, completed_at_micros: current_utc_micros(), }) .await .map_err(|error| { ai_tasks_error_response(&request_context, map_ai_task_client_error(error)) })?; Ok(json_success_body( Some(&request_context), build_ai_task_mutation_response(result), )) } pub async fn attach_ai_result_reference( State(state): State, Path(task_id): Path, Extension(request_context): Extension, Extension(_authenticated): Extension, Json(payload): Json, ) -> Result, Response> { let reference_kind = parse_ai_result_reference_kind_strict(&payload.reference_kind) .ok_or_else(|| { ai_tasks_error_response( &request_context, AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({ "provider": "ai-task-reference", "message": "referenceKind 非法", })), ) })?; let result = state .spacetime_client() .attach_ai_result_reference(AiResultReferenceInput { task_id, reference_kind, reference_id: payload.reference_id, label: payload.label, created_at_micros: current_utc_micros(), }) .await .map_err(|error| { ai_tasks_error_response(&request_context, map_ai_task_client_error(error)) })?; Ok(json_success_body( Some(&request_context), build_ai_task_mutation_response(result), )) } pub async fn complete_ai_task( State(state): State, Path(task_id): Path, Extension(request_context): Extension, Extension(_authenticated): Extension, ) -> Result, Response> { let result = state .spacetime_client() .complete_ai_task(AiTaskFinishInput { task_id, completed_at_micros: current_utc_micros(), }) .await .map_err(|error| { ai_tasks_error_response(&request_context, map_ai_task_client_error(error)) })?; Ok(json_success_body( Some(&request_context), build_ai_task_mutation_response(result), )) } pub async fn fail_ai_task( State(state): State, Path(task_id): Path, Extension(request_context): Extension, Extension(_authenticated): Extension, Json(payload): Json, ) -> Result, Response> { let result = state .spacetime_client() .fail_ai_task(AiTaskFailureInput { task_id, failure_message: payload.failure_message, completed_at_micros: current_utc_micros(), }) .await .map_err(|error| { ai_tasks_error_response(&request_context, map_ai_task_client_error(error)) })?; Ok(json_success_body( Some(&request_context), build_ai_task_mutation_response(result), )) } pub async fn cancel_ai_task( State(state): State, Path(task_id): Path, Extension(request_context): Extension, Extension(_authenticated): Extension, ) -> Result, Response> { let result = state .spacetime_client() .cancel_ai_task(AiTaskCancelInput { task_id, completed_at_micros: current_utc_micros(), }) .await .map_err(|error| { ai_tasks_error_response(&request_context, map_ai_task_client_error(error)) })?; Ok(json_success_body( Some(&request_context), build_ai_task_mutation_response(result), )) } fn build_stage_blueprints( task_kind: AiTaskKind, stage_kinds: Vec, request_context: &RequestContext, ) -> Result, Response> { if stage_kinds.is_empty() { return Ok(task_kind.default_stage_blueprints()); } stage_kinds .into_iter() .enumerate() .map(|(index, stage_kind_text)| { let stage_kind = parse_ai_task_stage_kind_strict(&stage_kind_text).ok_or_else(|| { ai_tasks_error_response( request_context, AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({ "provider": "ai-task-stage", "message": format!("stageKinds[{index}] 非法"), })), ) })?; Ok(AiTaskStageBlueprint { stage_kind, label: stage_kind.default_label().to_string(), detail: stage_kind.default_detail().to_string(), order: index as u32, }) }) .collect() } fn build_ai_task_mutation_response(record: AiTaskMutationRecord) -> AiTaskMutationResponse { AiTaskMutationResponse { ai_task: build_ai_task_payload(record.task), ai_text_chunk: record.text_chunk.map(build_ai_text_chunk_payload), } } fn build_ai_task_payload(record: spacetime_client::AiTaskRecord) -> AiTaskPayload { AiTaskPayload { task_id: record.task_id, task_kind: record.task_kind, owner_user_id: record.owner_user_id, request_label: record.request_label, source_module: record.source_module, source_entity_id: record.source_entity_id, request_payload_json: record.request_payload_json, status: record.status, failure_message: record.failure_message, stages: record .stages .into_iter() .map(build_ai_task_stage_payload) .collect(), result_references: record .result_references .into_iter() .map(build_ai_result_reference_payload) .collect(), latest_text_output: record.latest_text_output, latest_structured_payload_json: record.latest_structured_payload_json, version: record.version, created_at: record.created_at, started_at: record.started_at, completed_at: record.completed_at, updated_at: record.updated_at, } } fn build_ai_task_stage_payload(record: spacetime_client::AiTaskStageRecord) -> AiTaskStagePayload { AiTaskStagePayload { stage_kind: record.stage_kind, label: record.label, detail: record.detail, order: record.order, status: record.status, text_output: record.text_output, structured_payload_json: record.structured_payload_json, warning_messages: record.warning_messages, started_at: record.started_at, completed_at: record.completed_at, } } fn build_ai_result_reference_payload( record: spacetime_client::AiResultReferenceRecord, ) -> AiResultReferencePayload { AiResultReferencePayload { result_ref_id: record.result_ref_id, task_id: record.task_id, reference_kind: record.reference_kind, reference_id: record.reference_id, label: record.label, created_at: record.created_at, } } fn build_ai_text_chunk_payload(record: spacetime_client::AiTextChunkRecord) -> AiTextChunkPayload { AiTextChunkPayload { chunk_id: record.chunk_id, task_id: record.task_id, stage_kind: record.stage_kind, sequence: record.sequence, delta_text: record.delta_text, created_at: record.created_at, } } fn parse_ai_task_kind_strict(value: &str) -> Option { match value.trim() { "story_generation" => Some(AiTaskKind::StoryGeneration), "character_chat" => Some(AiTaskKind::CharacterChat), "npc_chat" => Some(AiTaskKind::NpcChat), "custom_world_generation" => Some(AiTaskKind::CustomWorldGeneration), "quest_intent" => Some(AiTaskKind::QuestIntent), "runtime_item_intent" => Some(AiTaskKind::RuntimeItemIntent), _ => None, } } fn parse_ai_task_stage_kind_strict(value: &str) -> Option { match value.trim() { "prepare_prompt" => Some(AiTaskStageKind::PreparePrompt), "request_model" => Some(AiTaskStageKind::RequestModel), "repair_response" => Some(AiTaskStageKind::RepairResponse), "normalize_result" => Some(AiTaskStageKind::NormalizeResult), "persist_result" => Some(AiTaskStageKind::PersistResult), _ => None, } } fn parse_ai_result_reference_kind_strict(value: &str) -> Option { match value.trim() { "story_session" => Some(AiResultReferenceKind::StorySession), "story_event" => Some(AiResultReferenceKind::StoryEvent), "custom_world_profile" => Some(AiResultReferenceKind::CustomWorldProfile), "quest_record" => Some(AiResultReferenceKind::QuestRecord), "runtime_item_record" => Some(AiResultReferenceKind::RuntimeItemRecord), "asset_object" => Some(AiResultReferenceKind::AssetObject), _ => None, } } fn map_ai_task_client_error(error: SpacetimeClientError) -> AppError { let status = match &error { SpacetimeClientError::Runtime(_) => StatusCode::BAD_REQUEST, _ => StatusCode::BAD_GATEWAY, }; AppError::from_status(status).with_details(json!({ "provider": "spacetimedb", "message": error.to_string(), })) } fn ai_tasks_error_response(request_context: &RequestContext, error: AppError) -> Response { error.into_response_with_context(Some(request_context)) } fn ai_task_accepted_response( request_context: &RequestContext, payload: AiTaskAcceptedResponse, ) -> Response { let mut response = json_success_body(Some(request_context), payload).into_response(); *response.status_mut() = StatusCode::ACCEPTED; response } fn current_utc_micros() -> i64 { use std::time::{SystemTime, UNIX_EPOCH}; let duration = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("system clock should be after unix epoch"); i64::try_from(duration.as_micros()).expect("current unix micros should fit in i64") } #[cfg(test)] mod tests { use axum::{ Router, body::Body, http::{Request, StatusCode}, }; use http_body_util::BodyExt; use platform_auth::{ AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus, sign_access_token, }; use serde_json::{Value, json}; use time::OffsetDateTime; use tower::ServiceExt; use crate::{app::build_router, config::AppConfig, state::AppState}; #[tokio::test] async fn create_ai_task_requires_authentication() { let app = build_router(AppState::new(AppConfig::default()).expect("state should build")); let response = app .oneshot( Request::builder() .method("POST") .uri("/api/ai/tasks") .header("content-type", "application/json") .body(Body::from( json!({ "taskKind": "story_generation", "requestLabel": "营地开场", "sourceModule": "story" }) .to_string(), )) .expect("request should build"), ) .await .expect("request should succeed"); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn create_ai_task_returns_bad_gateway_when_spacetime_not_published() { let state = seed_authenticated_state().await; let token = issue_access_token(&state); let app = build_router(state); let response = app .oneshot( Request::builder() .method("POST") .uri("/api/ai/tasks") .header("authorization", format!("Bearer {token}")) .header("content-type", "application/json") .header("x-genarrative-response-envelope", "v1") .body(Body::from( json!({ "taskKind": "npc_chat", "requestLabel": "试探问话", "sourceModule": "npc" }) .to_string(), )) .expect("request should build"), ) .await .expect("request should succeed"); assert_eq!(response.status(), StatusCode::BAD_GATEWAY); let body = response .into_body() .collect() .await .expect("body should collect") .to_bytes(); let payload: Value = serde_json::from_slice(&body).expect("response body should be valid json"); assert_eq!(payload["ok"], Value::Bool(false)); assert_eq!( payload["error"]["details"]["provider"], Value::String("spacetimedb".to_string()) ); } #[tokio::test] async fn start_ai_task_requires_authentication() { let app = build_router(AppState::new(AppConfig::default()).expect("state should build")); let response = app .oneshot( Request::builder() .method("POST") .uri("/api/ai/tasks/aitask_001/start") .body(Body::empty()) .expect("request should build"), ) .await .expect("request should succeed"); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn start_ai_task_returns_bad_gateway_when_spacetime_not_published() { let state = seed_authenticated_state().await; let token = issue_access_token(&state); let app = build_router(state); let response = app .oneshot( Request::builder() .method("POST") .uri("/api/ai/tasks/aitask_001/start") .header("authorization", format!("Bearer {token}")) .header("x-genarrative-response-envelope", "v1") .body(Body::empty()) .expect("request should build"), ) .await .expect("request should succeed"); assert_eq!(response.status(), StatusCode::BAD_GATEWAY); let body = response .into_body() .collect() .await .expect("body should collect") .to_bytes(); let payload: Value = serde_json::from_slice(&body).expect("response body should be valid json"); assert_eq!(payload["ok"], Value::Bool(false)); assert_eq!( payload["error"]["details"]["provider"], Value::String("spacetimedb".to_string()) ); } #[tokio::test] async fn ai_task_mutation_routes_require_authentication() { let app = build_router(AppState::new(AppConfig::default()).expect("state should build")); for route in ai_task_mutation_route_cases() { let (status, _) = post_ai_task_route(app.clone(), route.uri, None, route.body).await; assert_eq!(status, StatusCode::UNAUTHORIZED, "{}", route.uri); } } #[tokio::test] async fn ai_task_mutation_routes_return_bad_gateway_when_spacetime_not_published() { let state = seed_authenticated_state().await; let token = issue_access_token(&state); let app = build_router(state); for route in ai_task_mutation_route_cases() { let (status, payload) = post_ai_task_route(app.clone(), route.uri, Some(&token), route.body).await; assert_eq!(status, StatusCode::BAD_GATEWAY, "{}", route.uri); assert_eq!( payload["error"]["details"]["provider"], Value::String("spacetimedb".to_string()), "{}", route.uri ); } } struct AiTaskRouteCase { uri: &'static str, body: Option, } fn ai_task_mutation_route_cases() -> Vec { vec![ AiTaskRouteCase { uri: "/api/ai/tasks/aitask_001/stages/request_model/start", body: None, }, AiTaskRouteCase { uri: "/api/ai/tasks/aitask_001/chunks", body: Some(json!({ "stageKind": "request_model", "sequence": 1, "deltaText": "你听见远处的铃声。" })), }, AiTaskRouteCase { uri: "/api/ai/tasks/aitask_001/stages/request_model/complete", body: Some(json!({ "textOutput": "你听见远处的铃声。", "structuredPayloadJson": "{\"scene\":\"camp\"}", "warningMessages": [] })), }, AiTaskRouteCase { uri: "/api/ai/tasks/aitask_001/references", body: Some(json!({ "referenceKind": "story_event", "referenceId": "storyevt_001", "label": "营地开场" })), }, AiTaskRouteCase { uri: "/api/ai/tasks/aitask_001/complete", body: None, }, AiTaskRouteCase { uri: "/api/ai/tasks/aitask_001/fail", body: Some(json!({ "failureMessage": "模型返回内容为空" })), }, AiTaskRouteCase { uri: "/api/ai/tasks/aitask_001/cancel", body: None, }, ] } async fn post_ai_task_route( app: Router, uri: &str, bearer_token: Option<&str>, body: Option, ) -> (StatusCode, Value) { let mut request = Request::builder() .method("POST") .uri(uri) .header("x-genarrative-response-envelope", "v1"); if let Some(token) = bearer_token { request = request.header("authorization", format!("Bearer {token}")); } let body = if let Some(payload) = body { request = request.header("content-type", "application/json"); Body::from(payload.to_string()) } else { Body::empty() }; let response = app .oneshot(request.body(body).expect("request should build")) .await .expect("request should succeed"); let status = response.status(); let body = response .into_body() .collect() .await .expect("body should collect") .to_bytes(); let payload = if body.is_empty() { Value::Null } else { serde_json::from_slice(&body).expect("response body should be valid json") }; (status, payload) } async fn seed_authenticated_state() -> AppState { let state = AppState::new(AppConfig::default()).expect("state should build"); state .seed_test_phone_user_with_password("13800138100", "secret123") .await .id; state } fn issue_access_token(state: &AppState) -> String { let claims = AccessTokenClaims::from_input( AccessTokenClaimsInput { user_id: "user_00000001".to_string(), session_id: state .seed_test_refresh_session_for_user_id("user_00000001", "sess_ai_tasks"), provider: AuthProvider::Password, roles: vec!["user".to_string()], token_version: 2, phone_verified: true, binding_status: BindingStatus::Active, display_name: Some("AI 任务用户".to_string()), }, state.auth_jwt_config(), OffsetDateTime::now_utc(), ) .expect("claims should build"); sign_access_token(&claims, state.auth_jwt_config()).expect("token should sign") } }