795 lines
27 KiB
Rust
795 lines
27 KiB
Rust
use axum::{
|
|
Json,
|
|
extract::{Extension, Path, State},
|
|
http::StatusCode,
|
|
response::{IntoResponse, Response},
|
|
};
|
|
use module_ai::{
|
|
AiResultReferenceInput, AiResultReferenceKind, AiStageCompletionInput, AiTaskCancelInput,
|
|
AiTaskCreateInput, AiTaskFailureInput, AiTaskFinishInput, AiTaskKind, AiTaskStageBlueprint,
|
|
AiTaskStageKind, AiTaskStageStartInput, AiTaskStartInput, AiTextChunkAppendInput,
|
|
generate_ai_task_id,
|
|
};
|
|
use serde_json::{Value, json};
|
|
use shared_contracts::ai::{
|
|
AiResultReferencePayload, AiTaskAcceptedResponse, AiTaskMutationResponse, AiTaskPayload,
|
|
AiTaskStagePayload, AiTextChunkPayload, AppendAiTextChunkRequest,
|
|
AttachAiResultReferenceRequest, CompleteAiStageRequest, CreateAiTaskRequest, FailAiTaskRequest,
|
|
};
|
|
use spacetime_client::{AiTaskMutationRecord, SpacetimeClientError};
|
|
|
|
use crate::{
|
|
api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
|
|
request_context::RequestContext, state::AppState,
|
|
};
|
|
|
|
pub async fn create_ai_task(
|
|
State(state): State<AppState>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(authenticated): Extension<AuthenticatedAccessToken>,
|
|
Json(payload): Json<CreateAiTaskRequest>,
|
|
) -> Result<Json<Value>, Response> {
|
|
let now_micros = current_utc_micros();
|
|
let task_kind = parse_ai_task_kind_strict(&payload.task_kind).ok_or_else(|| {
|
|
ai_tasks_error_response(
|
|
&request_context,
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"provider": "ai-task",
|
|
"message": "taskKind 非法",
|
|
})),
|
|
)
|
|
})?;
|
|
let stages = build_stage_blueprints(task_kind, payload.stage_kinds, &request_context)?;
|
|
let owner_user_id = authenticated.claims().user_id().to_string();
|
|
|
|
let result = state
|
|
.spacetime_client()
|
|
.create_ai_task(AiTaskCreateInput {
|
|
task_id: generate_ai_task_id(now_micros),
|
|
task_kind,
|
|
owner_user_id,
|
|
request_label: payload.request_label,
|
|
source_module: payload.source_module,
|
|
source_entity_id: payload.source_entity_id,
|
|
request_payload_json: payload.request_payload_json,
|
|
stages,
|
|
created_at_micros: now_micros,
|
|
})
|
|
.await
|
|
.map_err(|error| {
|
|
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
|
|
})?;
|
|
|
|
Ok(json_success_body(
|
|
Some(&request_context),
|
|
build_ai_task_mutation_response(result),
|
|
))
|
|
}
|
|
|
|
pub async fn start_ai_task(
|
|
State(state): State<AppState>,
|
|
Path(task_id): Path<String>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
|
|
) -> Result<Response, Response> {
|
|
state
|
|
.spacetime_client()
|
|
.start_ai_task(AiTaskStartInput {
|
|
task_id: task_id.clone(),
|
|
started_at_micros: current_utc_micros(),
|
|
})
|
|
.await
|
|
.map_err(|error| {
|
|
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
|
|
})?;
|
|
|
|
Ok(ai_task_accepted_response(
|
|
&request_context,
|
|
AiTaskAcceptedResponse {
|
|
accepted: true,
|
|
task_id,
|
|
action: "start_task".to_string(),
|
|
stage_kind: None,
|
|
},
|
|
))
|
|
}
|
|
|
|
pub async fn start_ai_task_stage(
|
|
State(state): State<AppState>,
|
|
Path((task_id, stage_kind_text)): Path<(String, String)>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
|
|
) -> Result<Response, Response> {
|
|
let stage_kind = parse_ai_task_stage_kind_strict(&stage_kind_text).ok_or_else(|| {
|
|
ai_tasks_error_response(
|
|
&request_context,
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"provider": "ai-task-stage",
|
|
"message": "stageKind 非法",
|
|
})),
|
|
)
|
|
})?;
|
|
|
|
state
|
|
.spacetime_client()
|
|
.start_ai_task_stage(AiTaskStageStartInput {
|
|
task_id: task_id.clone(),
|
|
stage_kind,
|
|
started_at_micros: current_utc_micros(),
|
|
})
|
|
.await
|
|
.map_err(|error| {
|
|
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
|
|
})?;
|
|
|
|
Ok(ai_task_accepted_response(
|
|
&request_context,
|
|
AiTaskAcceptedResponse {
|
|
accepted: true,
|
|
task_id,
|
|
action: "start_stage".to_string(),
|
|
stage_kind: Some(stage_kind.as_str().to_string()),
|
|
},
|
|
))
|
|
}
|
|
|
|
pub async fn append_ai_text_chunk(
|
|
State(state): State<AppState>,
|
|
Path(task_id): Path<String>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
|
|
Json(payload): Json<AppendAiTextChunkRequest>,
|
|
) -> Result<Json<Value>, Response> {
|
|
let stage_kind = parse_ai_task_stage_kind_strict(&payload.stage_kind).ok_or_else(|| {
|
|
ai_tasks_error_response(
|
|
&request_context,
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"provider": "ai-task-stage",
|
|
"message": "stageKind 非法",
|
|
})),
|
|
)
|
|
})?;
|
|
|
|
let result = state
|
|
.spacetime_client()
|
|
.append_ai_text_chunk(AiTextChunkAppendInput {
|
|
task_id,
|
|
stage_kind,
|
|
sequence: payload.sequence,
|
|
delta_text: payload.delta_text,
|
|
created_at_micros: current_utc_micros(),
|
|
})
|
|
.await
|
|
.map_err(|error| {
|
|
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
|
|
})?;
|
|
|
|
Ok(json_success_body(
|
|
Some(&request_context),
|
|
build_ai_task_mutation_response(result),
|
|
))
|
|
}
|
|
|
|
pub async fn complete_ai_stage(
|
|
State(state): State<AppState>,
|
|
Path((task_id, stage_kind_text)): Path<(String, String)>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
|
|
Json(payload): Json<CompleteAiStageRequest>,
|
|
) -> Result<Json<Value>, Response> {
|
|
let stage_kind = parse_ai_task_stage_kind_strict(&stage_kind_text).ok_or_else(|| {
|
|
ai_tasks_error_response(
|
|
&request_context,
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"provider": "ai-task-stage",
|
|
"message": "stageKind 非法",
|
|
})),
|
|
)
|
|
})?;
|
|
|
|
let result = state
|
|
.spacetime_client()
|
|
.complete_ai_stage(AiStageCompletionInput {
|
|
task_id,
|
|
stage_kind,
|
|
text_output: payload.text_output,
|
|
structured_payload_json: payload.structured_payload_json,
|
|
warning_messages: payload.warning_messages,
|
|
completed_at_micros: current_utc_micros(),
|
|
})
|
|
.await
|
|
.map_err(|error| {
|
|
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
|
|
})?;
|
|
|
|
Ok(json_success_body(
|
|
Some(&request_context),
|
|
build_ai_task_mutation_response(result),
|
|
))
|
|
}
|
|
|
|
pub async fn attach_ai_result_reference(
|
|
State(state): State<AppState>,
|
|
Path(task_id): Path<String>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
|
|
Json(payload): Json<AttachAiResultReferenceRequest>,
|
|
) -> Result<Json<Value>, Response> {
|
|
let reference_kind = parse_ai_result_reference_kind_strict(&payload.reference_kind)
|
|
.ok_or_else(|| {
|
|
ai_tasks_error_response(
|
|
&request_context,
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"provider": "ai-task-reference",
|
|
"message": "referenceKind 非法",
|
|
})),
|
|
)
|
|
})?;
|
|
|
|
let result = state
|
|
.spacetime_client()
|
|
.attach_ai_result_reference(AiResultReferenceInput {
|
|
task_id,
|
|
reference_kind,
|
|
reference_id: payload.reference_id,
|
|
label: payload.label,
|
|
created_at_micros: current_utc_micros(),
|
|
})
|
|
.await
|
|
.map_err(|error| {
|
|
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
|
|
})?;
|
|
|
|
Ok(json_success_body(
|
|
Some(&request_context),
|
|
build_ai_task_mutation_response(result),
|
|
))
|
|
}
|
|
|
|
pub async fn complete_ai_task(
|
|
State(state): State<AppState>,
|
|
Path(task_id): Path<String>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
|
|
) -> Result<Json<Value>, Response> {
|
|
let result = state
|
|
.spacetime_client()
|
|
.complete_ai_task(AiTaskFinishInput {
|
|
task_id,
|
|
completed_at_micros: current_utc_micros(),
|
|
})
|
|
.await
|
|
.map_err(|error| {
|
|
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
|
|
})?;
|
|
|
|
Ok(json_success_body(
|
|
Some(&request_context),
|
|
build_ai_task_mutation_response(result),
|
|
))
|
|
}
|
|
|
|
pub async fn fail_ai_task(
|
|
State(state): State<AppState>,
|
|
Path(task_id): Path<String>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
|
|
Json(payload): Json<FailAiTaskRequest>,
|
|
) -> Result<Json<Value>, Response> {
|
|
let result = state
|
|
.spacetime_client()
|
|
.fail_ai_task(AiTaskFailureInput {
|
|
task_id,
|
|
failure_message: payload.failure_message,
|
|
completed_at_micros: current_utc_micros(),
|
|
})
|
|
.await
|
|
.map_err(|error| {
|
|
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
|
|
})?;
|
|
|
|
Ok(json_success_body(
|
|
Some(&request_context),
|
|
build_ai_task_mutation_response(result),
|
|
))
|
|
}
|
|
|
|
pub async fn cancel_ai_task(
|
|
State(state): State<AppState>,
|
|
Path(task_id): Path<String>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
|
|
) -> Result<Json<Value>, Response> {
|
|
let result = state
|
|
.spacetime_client()
|
|
.cancel_ai_task(AiTaskCancelInput {
|
|
task_id,
|
|
completed_at_micros: current_utc_micros(),
|
|
})
|
|
.await
|
|
.map_err(|error| {
|
|
ai_tasks_error_response(&request_context, map_ai_task_client_error(error))
|
|
})?;
|
|
|
|
Ok(json_success_body(
|
|
Some(&request_context),
|
|
build_ai_task_mutation_response(result),
|
|
))
|
|
}
|
|
|
|
fn build_stage_blueprints(
|
|
task_kind: AiTaskKind,
|
|
stage_kinds: Vec<String>,
|
|
request_context: &RequestContext,
|
|
) -> Result<Vec<AiTaskStageBlueprint>, Response> {
|
|
if stage_kinds.is_empty() {
|
|
return Ok(task_kind.default_stage_blueprints());
|
|
}
|
|
|
|
stage_kinds
|
|
.into_iter()
|
|
.enumerate()
|
|
.map(|(index, stage_kind_text)| {
|
|
let stage_kind =
|
|
parse_ai_task_stage_kind_strict(&stage_kind_text).ok_or_else(|| {
|
|
ai_tasks_error_response(
|
|
request_context,
|
|
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
|
|
"provider": "ai-task-stage",
|
|
"message": format!("stageKinds[{index}] 非法"),
|
|
})),
|
|
)
|
|
})?;
|
|
|
|
Ok(AiTaskStageBlueprint {
|
|
stage_kind,
|
|
label: stage_kind.default_label().to_string(),
|
|
detail: stage_kind.default_detail().to_string(),
|
|
order: index as u32,
|
|
})
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
fn build_ai_task_mutation_response(record: AiTaskMutationRecord) -> AiTaskMutationResponse {
|
|
AiTaskMutationResponse {
|
|
ai_task: build_ai_task_payload(record.task),
|
|
ai_text_chunk: record.text_chunk.map(build_ai_text_chunk_payload),
|
|
}
|
|
}
|
|
|
|
fn build_ai_task_payload(record: spacetime_client::AiTaskRecord) -> AiTaskPayload {
|
|
AiTaskPayload {
|
|
task_id: record.task_id,
|
|
task_kind: record.task_kind,
|
|
owner_user_id: record.owner_user_id,
|
|
request_label: record.request_label,
|
|
source_module: record.source_module,
|
|
source_entity_id: record.source_entity_id,
|
|
request_payload_json: record.request_payload_json,
|
|
status: record.status,
|
|
failure_message: record.failure_message,
|
|
stages: record
|
|
.stages
|
|
.into_iter()
|
|
.map(build_ai_task_stage_payload)
|
|
.collect(),
|
|
result_references: record
|
|
.result_references
|
|
.into_iter()
|
|
.map(build_ai_result_reference_payload)
|
|
.collect(),
|
|
latest_text_output: record.latest_text_output,
|
|
latest_structured_payload_json: record.latest_structured_payload_json,
|
|
version: record.version,
|
|
created_at: record.created_at,
|
|
started_at: record.started_at,
|
|
completed_at: record.completed_at,
|
|
updated_at: record.updated_at,
|
|
}
|
|
}
|
|
|
|
fn build_ai_task_stage_payload(record: spacetime_client::AiTaskStageRecord) -> AiTaskStagePayload {
|
|
AiTaskStagePayload {
|
|
stage_kind: record.stage_kind,
|
|
label: record.label,
|
|
detail: record.detail,
|
|
order: record.order,
|
|
status: record.status,
|
|
text_output: record.text_output,
|
|
structured_payload_json: record.structured_payload_json,
|
|
warning_messages: record.warning_messages,
|
|
started_at: record.started_at,
|
|
completed_at: record.completed_at,
|
|
}
|
|
}
|
|
|
|
fn build_ai_result_reference_payload(
|
|
record: spacetime_client::AiResultReferenceRecord,
|
|
) -> AiResultReferencePayload {
|
|
AiResultReferencePayload {
|
|
result_ref_id: record.result_ref_id,
|
|
task_id: record.task_id,
|
|
reference_kind: record.reference_kind,
|
|
reference_id: record.reference_id,
|
|
label: record.label,
|
|
created_at: record.created_at,
|
|
}
|
|
}
|
|
|
|
fn build_ai_text_chunk_payload(record: spacetime_client::AiTextChunkRecord) -> AiTextChunkPayload {
|
|
AiTextChunkPayload {
|
|
chunk_id: record.chunk_id,
|
|
task_id: record.task_id,
|
|
stage_kind: record.stage_kind,
|
|
sequence: record.sequence,
|
|
delta_text: record.delta_text,
|
|
created_at: record.created_at,
|
|
}
|
|
}
|
|
|
|
fn parse_ai_task_kind_strict(value: &str) -> Option<AiTaskKind> {
|
|
match value.trim() {
|
|
"story_generation" => Some(AiTaskKind::StoryGeneration),
|
|
"character_chat" => Some(AiTaskKind::CharacterChat),
|
|
"npc_chat" => Some(AiTaskKind::NpcChat),
|
|
"custom_world_generation" => Some(AiTaskKind::CustomWorldGeneration),
|
|
"quest_intent" => Some(AiTaskKind::QuestIntent),
|
|
"runtime_item_intent" => Some(AiTaskKind::RuntimeItemIntent),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn parse_ai_task_stage_kind_strict(value: &str) -> Option<AiTaskStageKind> {
|
|
match value.trim() {
|
|
"prepare_prompt" => Some(AiTaskStageKind::PreparePrompt),
|
|
"request_model" => Some(AiTaskStageKind::RequestModel),
|
|
"repair_response" => Some(AiTaskStageKind::RepairResponse),
|
|
"normalize_result" => Some(AiTaskStageKind::NormalizeResult),
|
|
"persist_result" => Some(AiTaskStageKind::PersistResult),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn parse_ai_result_reference_kind_strict(value: &str) -> Option<AiResultReferenceKind> {
|
|
match value.trim() {
|
|
"story_session" => Some(AiResultReferenceKind::StorySession),
|
|
"story_event" => Some(AiResultReferenceKind::StoryEvent),
|
|
"custom_world_profile" => Some(AiResultReferenceKind::CustomWorldProfile),
|
|
"quest_record" => Some(AiResultReferenceKind::QuestRecord),
|
|
"runtime_item_record" => Some(AiResultReferenceKind::RuntimeItemRecord),
|
|
"asset_object" => Some(AiResultReferenceKind::AssetObject),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn map_ai_task_client_error(error: SpacetimeClientError) -> AppError {
|
|
let status = match &error {
|
|
SpacetimeClientError::Runtime(_) => StatusCode::BAD_REQUEST,
|
|
_ => StatusCode::BAD_GATEWAY,
|
|
};
|
|
|
|
AppError::from_status(status).with_details(json!({
|
|
"provider": "spacetimedb",
|
|
"message": error.to_string(),
|
|
}))
|
|
}
|
|
|
|
/// Renders an [`AppError`] as an HTTP response, passing the current request
/// context along to the error's response builder.
fn ai_tasks_error_response(request_context: &RequestContext, error: AppError) -> Response {
    let context = Some(request_context);
    error.into_response_with_context(context)
}
|
|
|
|
fn ai_task_accepted_response(
|
|
request_context: &RequestContext,
|
|
payload: AiTaskAcceptedResponse,
|
|
) -> Response {
|
|
let mut response = json_success_body(Some(request_context), payload).into_response();
|
|
*response.status_mut() = StatusCode::ACCEPTED;
|
|
response
|
|
}
|
|
|
|
/// Returns the current wall-clock time as microseconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch, or if the
/// microsecond count does not fit in an `i64`.
fn current_utc_micros() -> i64 {
    use std::time::UNIX_EPOCH;

    // `elapsed` measures from the epoch up to "now" on the system clock.
    let since_epoch = UNIX_EPOCH
        .elapsed()
        .expect("system clock should be after unix epoch");
    let micros: u128 = since_epoch.as_micros();

    i64::try_from(micros).expect("current unix micros should fit in i64")
}
|
|
|
|
#[cfg(test)]
mod tests {
    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use platform_auth::{
        AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus, sign_access_token,
    };
    use serde_json::{Value, json};
    use time::OffsetDateTime;
    use tower::ServiceExt;

    use crate::{app::build_router, config::AppConfig, state::AppState};

    /// Creating a task without a bearer token is rejected with 401.
    #[tokio::test]
    async fn create_ai_task_requires_authentication() {
        let router = build_router(AppState::new(AppConfig::default()).expect("state should build"));

        let request_body = json!({
            "taskKind": "story_generation",
            "requestLabel": "营地开场",
            "sourceModule": "story"
        });
        let request = Request::builder()
            .method("POST")
            .uri("/api/ai/tasks")
            .header("content-type", "application/json")
            .body(Body::from(request_body.to_string()))
            .expect("request should build");

        let response = router
            .oneshot(request)
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// An authenticated create request against the default test state surfaces
    /// the spacetime client failure as 502 with the `spacetimedb` provider tag.
    #[tokio::test]
    async fn create_ai_task_returns_bad_gateway_when_spacetime_not_published() {
        let state = seed_authenticated_state().await;
        let token = issue_access_token(&state);
        let router = build_router(state);

        let request_body = json!({
            "taskKind": "npc_chat",
            "requestLabel": "试探问话",
            "sourceModule": "npc"
        });
        let request = Request::builder()
            .method("POST")
            .uri("/api/ai/tasks")
            .header("authorization", format!("Bearer {token}"))
            .header("content-type", "application/json")
            .header("x-genarrative-response-envelope", "v1")
            .body(Body::from(request_body.to_string()))
            .expect("request should build");

        let response = router
            .oneshot(request)
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::BAD_GATEWAY);

        let raw_body = collect_body(response).await;
        let payload: Value =
            serde_json::from_slice(&raw_body).expect("response body should be valid json");

        assert_eq!(payload["ok"], Value::Bool(false));
        assert_eq!(
            payload["error"]["details"]["provider"],
            Value::String("spacetimedb".to_string())
        );
    }

    /// Starting a task without a bearer token is rejected with 401.
    #[tokio::test]
    async fn start_ai_task_requires_authentication() {
        let router = build_router(AppState::new(AppConfig::default()).expect("state should build"));

        let request = Request::builder()
            .method("POST")
            .uri("/api/ai/tasks/aitask_001/start")
            .body(Body::empty())
            .expect("request should build");

        let response = router
            .oneshot(request)
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// An authenticated start request against the default test state surfaces
    /// the spacetime client failure as 502 with the `spacetimedb` provider tag.
    #[tokio::test]
    async fn start_ai_task_returns_bad_gateway_when_spacetime_not_published() {
        let state = seed_authenticated_state().await;
        let token = issue_access_token(&state);
        let router = build_router(state);

        let request = Request::builder()
            .method("POST")
            .uri("/api/ai/tasks/aitask_001/start")
            .header("authorization", format!("Bearer {token}"))
            .header("x-genarrative-response-envelope", "v1")
            .body(Body::empty())
            .expect("request should build");

        let response = router
            .oneshot(request)
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::BAD_GATEWAY);

        let raw_body = collect_body(response).await;
        let payload: Value =
            serde_json::from_slice(&raw_body).expect("response body should be valid json");

        assert_eq!(payload["ok"], Value::Bool(false));
        assert_eq!(
            payload["error"]["details"]["provider"],
            Value::String("spacetimedb".to_string())
        );
    }

    /// Every remaining mutation route rejects unauthenticated callers with 401.
    #[tokio::test]
    async fn ai_task_mutation_routes_require_authentication() {
        let router = build_router(AppState::new(AppConfig::default()).expect("state should build"));

        for AiTaskRouteCase { uri, body } in ai_task_mutation_route_cases() {
            let (status, _) = post_ai_task_route(router.clone(), uri, None, body).await;
            assert_eq!(status, StatusCode::UNAUTHORIZED, "{}", uri);
        }
    }

    /// Every remaining mutation route reports the spacetime client failure as
    /// 502 with the `spacetimedb` provider tag when called with a valid token.
    #[tokio::test]
    async fn ai_task_mutation_routes_return_bad_gateway_when_spacetime_not_published() {
        let state = seed_authenticated_state().await;
        let token = issue_access_token(&state);
        let router = build_router(state);

        for AiTaskRouteCase { uri, body } in ai_task_mutation_route_cases() {
            let (status, payload) =
                post_ai_task_route(router.clone(), uri, Some(&token), body).await;

            assert_eq!(status, StatusCode::BAD_GATEWAY, "{}", uri);
            assert_eq!(
                payload["error"]["details"]["provider"],
                Value::String("spacetimedb".to_string()),
                "{}",
                uri
            );
        }
    }

    /// One POST route under test together with its optional JSON body.
    struct AiTaskRouteCase {
        uri: &'static str,
        body: Option<Value>,
    }

    /// Lists the mutation routes exercised by the table-driven tests above
    /// (everything except task creation and task start, which have their own tests).
    fn ai_task_mutation_route_cases() -> Vec<AiTaskRouteCase> {
        let case = |uri: &'static str, body: Option<Value>| AiTaskRouteCase { uri, body };

        vec![
            case("/api/ai/tasks/aitask_001/stages/request_model/start", None),
            case(
                "/api/ai/tasks/aitask_001/chunks",
                Some(json!({
                    "stageKind": "request_model",
                    "sequence": 1,
                    "deltaText": "你听见远处的铃声。"
                })),
            ),
            case(
                "/api/ai/tasks/aitask_001/stages/request_model/complete",
                Some(json!({
                    "textOutput": "你听见远处的铃声。",
                    "structuredPayloadJson": "{\"scene\":\"camp\"}",
                    "warningMessages": []
                })),
            ),
            case(
                "/api/ai/tasks/aitask_001/references",
                Some(json!({
                    "referenceKind": "story_event",
                    "referenceId": "storyevt_001",
                    "label": "营地开场"
                })),
            ),
            case("/api/ai/tasks/aitask_001/complete", None),
            case(
                "/api/ai/tasks/aitask_001/fail",
                Some(json!({
                    "failureMessage": "模型返回内容为空"
                })),
            ),
            case("/api/ai/tasks/aitask_001/cancel", None),
        ]
    }

    /// Sends a POST to `uri` and returns the status plus the parsed JSON body.
    ///
    /// The envelope header is always sent; the `authorization` header is added
    /// only when a token is given and `content-type` only when a body is given.
    /// An empty response body is reported as `Value::Null`.
    async fn post_ai_task_route(
        app: Router,
        uri: &str,
        bearer_token: Option<&str>,
        body: Option<Value>,
    ) -> (StatusCode, Value) {
        let mut builder = Request::builder()
            .method("POST")
            .uri(uri)
            .header("x-genarrative-response-envelope", "v1");

        if let Some(token) = bearer_token {
            builder = builder.header("authorization", format!("Bearer {token}"));
        }

        let request_body = match body {
            Some(json_body) => {
                builder = builder.header("content-type", "application/json");
                Body::from(json_body.to_string())
            }
            None => Body::empty(),
        };

        let response = app
            .oneshot(builder.body(request_body).expect("request should build"))
            .await
            .expect("request should succeed");
        let status = response.status();

        let raw_body = collect_body(response).await;
        let payload = if raw_body.is_empty() {
            Value::Null
        } else {
            serde_json::from_slice(&raw_body).expect("response body should be valid json")
        };

        (status, payload)
    }

    /// Drains a response body into a byte vector.
    async fn collect_body(response: axum::response::Response) -> Vec<u8> {
        response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes()
            .to_vec()
    }

    /// Builds a default application state and seeds one phone/password user into it.
    async fn seed_authenticated_state() -> AppState {
        let state = AppState::new(AppConfig::default()).expect("state should build");

        // Only the seeding side effect matters; the returned id is not used.
        let _seeded_user_id = state
            .seed_test_phone_user_with_password("13800138100", "secret123")
            .await
            .id;

        state
    }

    /// Signs an access token with the state's JWT configuration.
    ///
    /// NOTE(review): the user id is hard-coded; presumably it matches the id of
    /// the user created by `seed_authenticated_state` — confirm.
    fn issue_access_token(state: &AppState) -> String {
        let claims_input = AccessTokenClaimsInput {
            user_id: "user_00000001".to_string(),
            session_id: "sess_ai_tasks".to_string(),
            provider: AuthProvider::Password,
            roles: vec!["user".to_string()],
            token_version: 2,
            phone_verified: true,
            binding_status: BindingStatus::Active,
            display_name: Some("AI 任务用户".to_string()),
        };
        let issued_at = OffsetDateTime::now_utc();

        let claims =
            AccessTokenClaims::from_input(claims_input, state.auth_jwt_config(), issued_at)
                .expect("claims should build");

        sign_access_token(&claims, state.auth_jwt_config()).expect("token should sign")
    }
}
|