Files
Genarrative/server-rs/crates/api-server/src/admin.rs
kdletters 29cf68a31a Merge branch 'codex/web-admin'
# Conflicts:
#	server-rs/crates/api-server/src/admin.rs
2026-05-01 00:58:42 +08:00

807 lines
26 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::{
collections::BTreeSet,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};
use axum::{
Json,
extract::{Extension, Request, State},
http::{
HeaderMap, HeaderName, HeaderValue, Method, StatusCode,
header::{AUTHORIZATION, CONTENT_TYPE},
},
middleware::Next,
response::Response,
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::Value;
use shared_contracts::admin::{
AdminDatabaseOverviewPayload, AdminDatabaseTableStatPayload, AdminDebugHeaderInput,
AdminDebugHttpRequest, AdminDebugHttpResponse, AdminLoginRequest, AdminLoginResponse,
AdminMeResponse, AdminOverviewResponse, AdminServiceOverviewPayload, AdminSessionPayload,
};
use time::{OffsetDateTime, format_description::well_known::Rfc3339};
use crate::{
api_response::json_success_body,
http_error::AppError,
request_context::RequestContext,
state::{AdminRuntime, AppState},
};
// The first version of the debug console only accepts request bodies of limited size, so the
// admin backend cannot be abused as a general-purpose proxy for large payloads.
const MAX_DEBUG_BODY_BYTES: usize = 128 * 1024;
// Request headers that the debug console drops instead of forwarding. Entries are lower-case
// because submitted header names are lower-cased before this list is consulted.
// NOTE(review): these look like framing/connection headers the HTTP client manages itself,
// presumably blocked to avoid conflicting with it — confirm.
const BLOCKED_DEBUG_HEADERS: &[&str] = &[
    "host",
    "content-length",
    "connection",
    "transfer-encoding",
    "expect",
];
// The first version of the database overview only counts rows in this controlled allow-list of
// tables; the admin page must never be able to submit arbitrary SQL. Each name is interpolated
// into a fixed `SELECT COUNT(*)` statement, and per-table stats are reported in this order.
const DATABASE_OVERVIEW_TABLES: &[&str] = &[
    "runtime_setting",
    "runtime_snapshot",
    "user_browse_history",
    "profile_dashboard_state",
    "profile_wallet_ledger",
    "profile_played_world",
    "profile_save_archive",
    "story_session",
    "story_event",
    "battle_state",
    "inventory_slot",
    "quest_record",
    "quest_log",
    "treasure_record",
    "npc_state",
    "custom_world_profile",
    "custom_world_gallery_entry",
    "custom_world_agent_session",
    "custom_world_agent_message",
    "custom_world_agent_operation",
    "custom_world_draft_card",
    "big_fish_creation_session",
    "big_fish_agent_message",
    "big_fish_asset_slot",
    "puzzle_work_profile",
    "puzzle_agent_session",
    "puzzle_agent_message",
    "puzzle_runtime_run",
    "ai_task",
    "ai_task_stage",
    "ai_text_chunk",
    "ai_result_reference",
    "asset_object",
    "asset_entity_binding",
];
// The SpacetimeDB 2.x schema HTTP API requires an explicit BSATN JSON version parameter.
// The admin overview only reads table names, so a fixed version compatible with the current
// CLI 2.1.0 is sufficient.
const SPACETIME_SCHEMA_VERSION_QUERY: &str = "version=9";
/// Administrator identity attached to a request after its bearer token was verified.
///
/// `require_admin_auth` inserts this into the request extensions; admin handlers extract it
/// with `Extension<AuthenticatedAdmin>`.
#[derive(Debug, Clone)]
pub struct AuthenticatedAdmin {
    session: AdminSessionPayload,
}
/// Subset of the SpacetimeDB `GET /v1/database/{name}` response shown on the admin overview.
///
/// Every field is optional, so a payload that lacks a key still deserializes (the field becomes
/// `None`). Both camelCase and snake_case keys are accepted. SpacetimeDB's HTTP API documents
/// snake_case keys (`database_identity`, `owner_identity`, `host_type`); with
/// `rename_all = "camelCase"` alone those keys were ignored and every field silently stayed
/// `None`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SpacetimeDatabaseInfoResponse {
    #[serde(alias = "database_identity")]
    database_identity: Option<String>,
    #[serde(alias = "owner_identity")]
    owner_identity: Option<String>,
    #[serde(alias = "host_type")]
    host_type: Option<String>,
}
/// Subset of the SpacetimeDB schema endpoint response: only the table list is read.
#[derive(Deserialize, Debug)]
struct SpacetimeSchemaResponse {
    /// Declared tables; `None` when the payload has no `tables` key.
    tables: Option<Vec<SpacetimeSchemaTable>>,
}
/// One table entry of the SpacetimeDB schema response; only its name is used.
#[derive(Deserialize, Debug)]
struct SpacetimeSchemaTable {
    /// Table name; `None` when the entry carries no `name` key.
    name: Option<String>,
}
impl AuthenticatedAdmin {
pub fn new(session: AdminSessionPayload) -> Self {
Self { session }
}
pub fn session(&self) -> &AdminSessionPayload {
&self.session
}
}
/// Exchanges the configured admin username/password for a signed admin token.
///
/// Both the configured and the submitted credentials are compared after trimming surrounding
/// whitespace.
///
/// # Errors
/// - `503` when admin management is not configured or a configured credential is blank.
/// - `401` when the submitted username or password does not match.
/// - `500` when building the claims or signing the token fails.
pub async fn admin_login(
    State(state): State<AppState>,
    Extension(request_context): Extension<RequestContext>,
    Json(payload): Json<AdminLoginRequest>,
) -> Result<Json<Value>, AppError> {
    let runtime = state.admin_runtime().ok_or_else(|| {
        AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_message("后台管理未启用")
    })?;
    let expected_username = runtime.username().trim();
    let expected_password = runtime.password().trim();
    let submitted_username = payload.username.trim();
    let submitted_password = payload.password.trim();
    // A blank configured credential means the admin console is effectively disabled.
    if expected_username.is_empty() || expected_password.is_empty() {
        return Err(
            AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_message("后台管理未启用"),
        );
    }
    // Compare both fields unconditionally (`&`, not `||`/`&&`) and without early exit, so the
    // response time does not reveal which field or how long a prefix matched.
    let username_matches =
        constant_time_eq(submitted_username.as_bytes(), expected_username.as_bytes());
    let password_matches =
        constant_time_eq(submitted_password.as_bytes(), expected_password.as_bytes());
    if !(username_matches & password_matches) {
        return Err(
            AppError::from_status(StatusCode::UNAUTHORIZED).with_message("管理员用户名或密码错误"),
        );
    }
    let now = OffsetDateTime::now_utc();
    let claims = runtime.build_claims(now).map_err(|error| {
        AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_message(error)
    })?;
    let token = runtime.sign_token(&claims).map_err(|error| {
        AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_message(error)
    })?;
    Ok(json_success_body(
        Some(&request_context),
        AdminLoginResponse {
            token,
            admin: build_admin_session_payload(runtime.build_session(&claims)),
        },
    ))
}

/// Compares two byte strings without returning early at the first differing byte.
///
/// Lengths are compared up front, so a length difference is observable. For equal-length
/// inputs, every byte pair is XOR-ed into one accumulator. This is a best-effort mitigation;
/// the compiler gives no formal constant-time guarantee.
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
    if left.len() != right.len() {
        return false;
    }
    let difference = left
        .iter()
        .zip(right)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    std::hint::black_box(difference) == 0
}
/// Returns the session of the administrator authenticated by `require_admin_auth`.
pub async fn admin_me(
    Extension(request_context): Extension<RequestContext>,
    Extension(admin): Extension<AuthenticatedAdmin>,
) -> Json<Value> {
    // `admin` is owned here, so its session can be moved out instead of cloned.
    let response = AdminMeResponse {
        admin: admin.session,
    };
    json_success_body(Some(&request_context), response)
}
/// Returns the service configuration summary together with a live database overview.
///
/// # Errors
/// `503` when admin management is not configured.
pub async fn admin_overview(
    State(state): State<AppState>,
    Extension(request_context): Extension<RequestContext>,
    Extension(_admin): Extension<AuthenticatedAdmin>,
) -> Result<Json<Value>, AppError> {
    let Some(runtime) = state.admin_runtime() else {
        return Err(
            AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_message("后台管理未启用"),
        );
    };
    let overview = build_admin_overview(&state, runtime).await?;
    Ok(json_success_body(Some(&request_context), overview))
}
/// Replays an HTTP request against this api-server on behalf of the admin debug console.
///
/// # Errors
/// Propagates validation and transport errors from `execute_admin_debug_http`.
pub async fn admin_debug_http(
    State(state): State<AppState>,
    Extension(request_context): Extension<RequestContext>,
    Extension(_admin): Extension<AuthenticatedAdmin>,
    Json(payload): Json<AdminDebugHttpRequest>,
) -> Result<Json<Value>, AppError> {
    execute_admin_debug_http(&state, payload)
        .await
        .map(|response| json_success_body(Some(&request_context), response))
}
pub async fn require_admin_auth(
State(state): State<AppState>,
mut request: Request,
next: Next,
) -> Result<Response, AppError> {
// 后台鉴权必须同时满足令牌验签通过、主体匹配当前管理员、roles 含 admin。
let runtime = state.admin_runtime().ok_or_else(|| {
AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_message("后台管理未启用")
})?;
let bearer_token = extract_bearer_token(request.headers())?;
let claims = runtime
.verify_token(&bearer_token)
.map_err(|error| AppError::from_status(StatusCode::UNAUTHORIZED).with_message(error))?;
let admin_session = runtime
.validate_claims(&claims)
.map_err(|error| AppError::from_status(StatusCode::FORBIDDEN).with_message(error))?;
request
.extensions_mut()
.insert(AuthenticatedAdmin::new(build_admin_session_payload(
admin_session,
)));
Ok(next.run(request).await)
}
/// Extracts the token from an `Authorization: Bearer <token>` header.
///
/// The auth scheme is matched case-insensitively, as HTTP requires. The scheme may be separated
/// from the token by any ASCII whitespace, and surrounding whitespace on the token is trimmed.
///
/// # Errors
/// `401` when the header is missing or not valid visible ASCII, uses a scheme other than
/// Bearer, or carries an empty token.
fn extract_bearer_token(headers: &HeaderMap) -> Result<String, AppError> {
    let authorization = headers
        .get(AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .ok_or_else(|| AppError::from_status(StatusCode::UNAUTHORIZED))?;
    let (scheme, credentials) = authorization
        .split_once(|c: char| c.is_ascii_whitespace())
        .ok_or_else(|| AppError::from_status(StatusCode::UNAUTHORIZED))?;
    // Previously only `Bearer `/`bearer ` were recognised, rejecting e.g. `BEARER` or a tab.
    if !scheme.eq_ignore_ascii_case("bearer") {
        return Err(AppError::from_status(StatusCode::UNAUTHORIZED));
    }
    let token = credentials.trim();
    if token.is_empty() {
        return Err(AppError::from_status(StatusCode::UNAUTHORIZED));
    }
    Ok(token.to_string())
}
/// Assembles the admin overview from the static service configuration and a live database
/// overview.
///
/// It currently never fails; the `Result` keeps room for fallible sections.
async fn build_admin_overview(
    state: &AppState,
    runtime: &AdminRuntime,
) -> Result<AdminOverviewResponse, AppError> {
    let config = &state.config;
    let service = AdminServiceOverviewPayload {
        admin_enabled: runtime.is_enabled(),
        bind_host: config.bind_host.clone(),
        bind_port: config.bind_port,
        jwt_issuer: config.jwt_issuer.clone(),
        spacetime_database: config.spacetime_database.clone(),
        spacetime_server_url: config.spacetime_server_url.clone(),
    };
    Ok(AdminOverviewResponse {
        service,
        database: fetch_database_overview(state).await,
    })
}
async fn fetch_database_overview(state: &AppState) -> AdminDatabaseOverviewPayload {
// 概览直接读取 SpacetimeDB HTTP API保证后台看到的是真实数据库元信息而不是本地缓存。
let client = Client::new();
let server_root = state.config.spacetime_server_url.trim_end_matches('/');
let database = state.config.spacetime_database.trim();
let token = state
.config
.spacetime_token
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty());
let mut fetch_errors = Vec::new();
let database_info = fetch_spacetime_json::<SpacetimeDatabaseInfoResponse>(
&client,
&format!("{server_root}/v1/database/{database}"),
token,
)
.await
.map_err(|error| fetch_errors.push(format!("数据库信息读取失败:{error}")))
.ok()
.flatten();
let schema = fetch_spacetime_json::<SpacetimeSchemaResponse>(
&client,
&build_spacetime_schema_url(server_root, database),
token,
)
.await
.map_err(|error| fetch_errors.push(format!("数据库 schema 读取失败:{error}")))
.ok()
.flatten();
let mut schema_table_names = schema
.as_ref()
.and_then(|value| value.tables.as_ref())
.map(|tables| {
tables
.iter()
.filter_map(|table| table.name.as_deref())
.map(str::trim)
.filter(|name| !name.is_empty())
.map(ToOwned::to_owned)
.collect::<BTreeSet<_>>()
.into_iter()
.collect::<Vec<_>>()
})
.unwrap_or_default();
let mut table_stats = Vec::new();
for table_name in DATABASE_OVERVIEW_TABLES {
let sql = format!("SELECT COUNT(*) AS row_count FROM {table_name}");
match fetch_spacetime_sql_count(&client, server_root, database, token, &sql).await {
Ok(row_count) => table_stats.push(AdminDatabaseTableStatPayload {
table_name: (*table_name).to_string(),
row_count: Some(row_count),
error_message: None,
}),
Err(error) => {
table_stats.push(AdminDatabaseTableStatPayload {
table_name: (*table_name).to_string(),
row_count: None,
error_message: Some(error),
});
}
}
}
for table_name in DATABASE_OVERVIEW_TABLES {
if !schema_table_names.iter().any(|name| name == table_name) {
schema_table_names.push((*table_name).to_string());
}
}
schema_table_names.sort();
AdminDatabaseOverviewPayload {
database_identity: database_info
.as_ref()
.and_then(|value| value.database_identity.clone()),
owner_identity: database_info
.as_ref()
.and_then(|value| value.owner_identity.clone()),
host_type: database_info
.as_ref()
.and_then(|value| value.host_type.clone()),
schema_table_names,
table_stats,
fetch_errors,
}
}
/// Builds the schema endpoint URL for `database`, including the mandatory version query.
///
/// `server_root` is expected to carry no trailing slash.
fn build_spacetime_schema_url(server_root: &str, database: &str) -> String {
    let version_query = SPACETIME_SCHEMA_VERSION_QUERY;
    [server_root, "/v1/database/", database, "/schema?", version_query].concat()
}
/// Sends a GET request to `url` and deserializes the JSON body into `T`.
///
/// The request carries `token` as a bearer token when one is provided. On success the value is
/// always `Some`; the `Option` in the return type is kept for caller compatibility.
///
/// # Errors
/// Returns a human-readable message when:
/// - the request cannot be sent;
/// - the server answers with a non-2xx status (the message holds the status code and a trimmed
///   body preview);
/// - the body does not parse as `T`.
async fn fetch_spacetime_json<T>(
    client: &Client,
    url: &str,
    token: Option<&str>,
) -> Result<Option<T>, String>
where
    T: for<'de> Deserialize<'de>,
{
    let mut request = client.get(url);
    if let Some(token) = token {
        request = request.bearer_auth(token);
    }
    let response = request
        .send()
        .await
        .map_err(|error| format!("请求失败:{error}"))?;
    let status = response.status();
    if !status.is_success() {
        let body = response.text().await.unwrap_or_default();
        let preview = trim_preview(&body);
        // Separate the status code from the body preview; `"HTTP {}{}"` produced messages such
        // as `HTTP 404not found`.
        return Err(if preview.is_empty() {
            format!("HTTP {}", status.as_u16())
        } else {
            format!("HTTP {}: {preview}", status.as_u16())
        });
    }
    response
        .json::<T>()
        .await
        .map(Some)
        .map_err(|error| format!("响应解析失败:{error}"))
}
/// Executes `sql` through the SpacetimeDB `/sql` endpoint and extracts a single row count.
///
/// The SQL text is sent as a `text/plain` body, with `token` as a bearer token when present.
///
/// # Errors
/// Returns a human-readable message when:
/// - the request cannot be sent;
/// - the server answers with a non-2xx status (the message holds the status code and a trimmed
///   body preview);
/// - the response is not JSON or contains no parsable count.
async fn fetch_spacetime_sql_count(
    client: &Client,
    server_root: &str,
    database: &str,
    token: Option<&str>,
    sql: &str,
) -> Result<u64, String> {
    let mut request = client
        .post(format!("{server_root}/v1/database/{database}/sql"))
        .header(CONTENT_TYPE, "text/plain; charset=utf-8")
        .body(sql.to_string());
    if let Some(token) = token {
        request = request.bearer_auth(token);
    }
    let response = request
        .send()
        .await
        .map_err(|error| format!("SQL 请求失败:{error}"))?;
    let status = response.status();
    if !status.is_success() {
        let body = response.text().await.unwrap_or_default();
        let preview = trim_preview(&body);
        // Separate the status code from the body preview; `"HTTP {}{}"` produced messages such
        // as `HTTP 400no such table`.
        return Err(if preview.is_empty() {
            format!("HTTP {}", status.as_u16())
        } else {
            format!("HTTP {}: {preview}", status.as_u16())
        });
    }
    let payload = response
        .json::<Value>()
        .await
        .map_err(|error| format!("SQL 响应解析失败:{error}"))?;
    parse_spacetime_sql_count_response(payload)
}
/// Extracts the row count from a SpacetimeDB `/sql` JSON response.
///
/// SpacetimeDB 2.x returns an array of statement results, each holding `schema` and `rows`;
/// only the first statement is read. A bare statement object is also accepted, so that small
/// API differences between local and remote servers still yield a count.
fn parse_spacetime_sql_count_response(payload: Value) -> Result<u64, String> {
    let statement = match payload {
        Value::Array(statements) => statements
            .into_iter()
            .next()
            .ok_or_else(|| "SQL 结果为空".to_string())?,
        object @ Value::Object(_) => object,
        _ => return Err("SQL 响应格式非法".to_string()),
    };
    extract_sql_count_from_statement(statement)
}
/// Pulls `rows` (required) and `schema` (optional) out of one statement result and reads the
/// count from its rows.
fn extract_sql_count_from_statement(statement: Value) -> Result<u64, String> {
    match statement {
        Value::Object(mut fields) => {
            let schema = fields.remove("schema");
            let rows = fields
                .remove("rows")
                .ok_or_else(|| "SQL 响应缺少 rows 字段".to_string())?;
            extract_sql_count_from_rows(rows, schema.as_ref())
        }
        _ => Err("SQL statement 结果格式非法".to_string()),
    }
}
/// Reads the count from the first row of a `rows` array; any further rows are ignored.
fn extract_sql_count_from_rows(rows: Value, schema: Option<&Value>) -> Result<u64, String> {
    match rows {
        Value::Array(rows) => match rows.first() {
            Some(first_row) => extract_sql_count_from_row(first_row, schema),
            None => Err("SQL 结果为空".to_string()),
        },
        _ => Err("SQL rows 字段格式非法".to_string()),
    }
}
/// Reads the count from a single row.
///
/// - Object rows: looked up by column name.
/// - Array rows: indexed by the count column located via `schema`, or column 0 when the schema
///   does not identify one.
/// - Scalar rows: parsed directly.
fn extract_sql_count_from_row(row: &Value, schema: Option<&Value>) -> Result<u64, String> {
    if let Value::Object(columns) = row {
        return extract_sql_count(columns);
    }
    let Value::Array(values) = row else {
        return parse_count_value(row);
    };
    let count_index = schema.and_then(find_sql_count_column_index).unwrap_or(0);
    match values.get(count_index) {
        Some(value) => parse_count_value(value),
        None => Err("SQL 结果缺少 count 字段".to_string()),
    }
}
/// Reads the count from an object row.
///
/// The well-known count column names are tried first; otherwise the map's first value is used.
/// The first matching key decides the result, even if its value then fails to parse.
fn extract_sql_count(columns: &serde_json::Map<String, Value>) -> Result<u64, String> {
    const PREFERRED_KEYS: [&str; 3] = ["row_count", "count", "COUNT(*)"];
    let value = PREFERRED_KEYS
        .iter()
        .find_map(|key| columns.get(*key))
        .or_else(|| columns.values().next())
        .ok_or_else(|| "SQL 结果缺少 count 字段".to_string())?;
    parse_count_value(value)
}
fn find_sql_count_column_index(schema: &Value) -> Option<usize> {
let elements = schema.get("elements")?.as_array()?;
elements.iter().position(|element| {
element
.get("name")
.and_then(extract_sql_schema_name)
.map(|name| matches!(name, "row_count" | "count" | "COUNT(*)"))
.unwrap_or(false)
})
}
fn extract_sql_schema_name(value: &Value) -> Option<&str> {
match value {
Value::String(text) => Some(text.as_str()),
Value::Object(object) => object.get("some").and_then(Value::as_str),
_ => None,
}
}
/// Converts a JSON count into `u64`.
///
/// Accepts non-negative integer numbers, and strings holding a non-negative integer (after
/// trimming). Every other JSON type is rejected.
fn parse_count_value(value: &Value) -> Result<u64, String> {
    if let Value::Number(number) = value {
        return number
            .as_u64()
            .ok_or_else(|| "count 字段不是无符号整数".to_string());
    }
    if let Value::String(text) = value {
        return text
            .trim()
            .parse::<u64>()
            .map_err(|error| format!("count 字段解析失败:{error}"));
    }
    Err("count 字段类型非法".to_string())
}
/// Replays a debug request against this api-server and captures the response.
///
/// Debug requests always loop back to the current api-server (same origin, controlled). They
/// must never be usable as a proxy to external hosts: the target is always the local bind
/// address plus a validated relative path.
///
/// Submitted headers listed in `BLOCKED_DEBUG_HEADERS` are dropped. The response body is
/// returned as a trimmed text preview, and additionally as JSON when it parses as JSON.
///
/// # Errors
/// - `400` for an invalid method, path, header name or header value, or an oversized body.
/// - `502` when sending the request or reading the response fails.
async fn execute_admin_debug_http(
    state: &AppState,
    payload: AdminDebugHttpRequest,
) -> Result<AdminDebugHttpResponse, AppError> {
    let method = Method::from_bytes(payload.method.trim().as_bytes()).map_err(|_| {
        AppError::from_status(StatusCode::BAD_REQUEST).with_message("HTTP 方法不合法")
    })?;
    let path = normalize_debug_path(&payload.path)?;
    let base_url = build_debug_base_url(&state.config.bind_host, state.config.bind_port);
    let target_url = format!("{base_url}{path}");
    let body_text = payload.body.unwrap_or_default();
    if body_text.len() > MAX_DEBUG_BODY_BYTES {
        return Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_message("调试请求体超过长度限制"),
        );
    }
    let client = Client::new();
    let mut request = client.request(method, &target_url);
    if !body_text.is_empty() {
        // `body_text` is not used afterwards, so hand it over instead of cloning it.
        request = request.body(body_text);
    }
    for header in payload.headers.unwrap_or_default() {
        let header_name = header.name.trim().to_ascii_lowercase();
        if BLOCKED_DEBUG_HEADERS.contains(&header_name.as_str()) {
            continue;
        }
        let name = HeaderName::from_bytes(header_name.as_bytes()).map_err(|_| {
            AppError::from_status(StatusCode::BAD_REQUEST).with_message("调试请求头名称不合法")
        })?;
        let value = HeaderValue::from_str(header.value.trim()).map_err(|_| {
            AppError::from_status(StatusCode::BAD_REQUEST).with_message("调试请求头值不合法")
        })?;
        request = request.header(name, value);
    }
    let response = request.send().await.map_err(|error| {
        AppError::from_status(StatusCode::BAD_GATEWAY)
            .with_message(format!("调试请求失败:{error}"))
    })?;
    let status = response.status();
    let headers = response
        .headers()
        .iter()
        .map(|(name, value)| AdminDebugHeaderInput {
            name: name.to_string(),
            // `to_str` rejects any byte outside visible ASCII (e.g. UTF-8 file names in
            // Content-Disposition), which used to blank the value; decode lossily instead.
            value: String::from_utf8_lossy(value.as_bytes()).into_owned(),
        })
        .collect::<Vec<_>>();
    let response_body = response.bytes().await.map_err(|error| {
        AppError::from_status(StatusCode::BAD_GATEWAY)
            .with_message(format!("调试响应读取失败:{error}"))
    })?;
    let body_preview = build_body_preview(&response_body);
    let body_json = serde_json::from_slice::<Value>(&response_body).ok();
    Ok(AdminDebugHttpResponse {
        status: status.as_u16(),
        status_text: status.canonical_reason().unwrap_or("Unknown").to_string(),
        headers,
        body_text: body_preview,
        body_json,
    })
}
/// Builds the `http://host:port` origin the debug console uses to reach this server.
///
/// Wildcard and empty bind hosts are mapped to loopback, and IPv6 literals are bracketed.
fn build_debug_base_url(bind_host: &str, bind_port: u16) -> String {
    format!(
        "http://{}:{bind_port}",
        format_http_authority_host(&resolve_debug_host(bind_host))
    )
}
/// Picks the host the debug console dials to reach this api-server.
///
/// - An empty host falls back to IPv4 loopback.
/// - Wildcard bind addresses (`0.0.0.0`, `::`) are rewritten to the matching loopback address.
/// - A bracketed IPv6 literal such as `[::]` is unwrapped before parsing, so it gets the same
///   treatment as `::`; previously it failed to parse and the wildcard was kept.
/// - Non-IP hosts (e.g. `localhost`) are returned trimmed but otherwise unchanged.
fn resolve_debug_host(bind_host: &str) -> String {
    let trimmed = bind_host.trim();
    if trimmed.is_empty() {
        return Ipv4Addr::LOCALHOST.to_string();
    }
    let unbracketed = trimmed
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(trimmed);
    match unbracketed.parse::<IpAddr>() {
        Ok(IpAddr::V4(ip)) if ip.is_unspecified() => Ipv4Addr::LOCALHOST.to_string(),
        Ok(IpAddr::V6(ip)) if ip.is_unspecified() => Ipv6Addr::LOCALHOST.to_string(),
        Ok(ip) => ip.to_string(),
        Err(_) => trimmed.to_string(),
    }
}
/// Formats a host for the authority part of an HTTP URL.
///
/// Bare IPv6 literals are wrapped in brackets. Already-bracketed hosts and every other host are
/// returned unchanged.
fn format_http_authority_host(host: &str) -> String {
    let already_bracketed = host.starts_with('[') && host.ends_with(']');
    if !already_bracketed && host.parse::<Ipv6Addr>().is_ok() {
        format!("[{host}]")
    } else {
        host.to_string()
    }
}
/// Validates a debug-console target path.
///
/// Only same-origin relative paths of the form `/xxx` are allowed. Absolute URLs and the admin
/// login route are explicitly rejected. The trimmed path, including any query string, is
/// returned unchanged.
///
/// # Errors
/// `400` when the path is empty, is an absolute `http(s)` URL, does not start with `/`, or
/// targets the admin login route.
fn normalize_debug_path(path: &str) -> Result<String, AppError> {
    let trimmed = path.trim();
    if trimmed.is_empty() {
        return Err(AppError::from_status(StatusCode::BAD_REQUEST).with_message("调试路径不能为空"));
    }
    if trimmed.starts_with("http://") || trimmed.starts_with("https://") {
        return Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_message("只允许调试同源相对路径"),
        );
    }
    if !trimmed.starts_with('/') {
        return Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_message("调试路径必须以 / 开头"),
        );
    }
    // Routing ignores the query string and fragment, so compare only the path portion and
    // ignore trailing slashes. Otherwise `/admin/api/login?x=1` would bypass this check.
    let route_path = trimmed
        .split(|c: char| c == '?' || c == '#')
        .next()
        .unwrap_or(trimmed)
        .trim_end_matches('/');
    if route_path == "/admin/api/login" {
        return Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_message("禁止调试后台登录接口"),
        );
    }
    Ok(trimmed.to_string())
}
/// Renders a response body as trimmed, length-limited text.
///
/// Invalid UTF-8 sequences are replaced with U+FFFD.
fn build_body_preview(bytes: &[u8]) -> String {
    match bytes {
        [] => String::new(),
        _ => trim_preview(&String::from_utf8_lossy(bytes)),
    }
}
/// Trims surrounding whitespace and keeps at most the first 4000 characters.
///
/// The limit counts Unicode scalar values, not bytes.
fn trim_preview(text: &str) -> String {
    const MAX_PREVIEW_CHARS: usize = 4000;
    let trimmed = text.trim();
    // The byte offset of character #MAX_PREVIEW_CHARS (if any) is exactly where to cut.
    match trimmed.char_indices().nth(MAX_PREVIEW_CHARS) {
        Some((cut, _)) => trimmed[..cut].to_string(),
        None => trimmed.to_string(),
    }
}
/// Converts an internal admin session into its API payload.
///
/// Timestamps are rendered as RFC 3339. A timestamp that cannot be formatted falls back to the
/// Unix epoch.
fn build_admin_session_payload(session: crate::state::AdminSession) -> AdminSessionPayload {
    const EPOCH_FALLBACK: &str = "1970-01-01T00:00:00Z";
    let issued_at = session
        .issued_at
        .format(&Rfc3339)
        .unwrap_or_else(|_| EPOCH_FALLBACK.to_string());
    let expires_at = session
        .expires_at
        .format(&Rfc3339)
        .unwrap_or_else(|_| EPOCH_FALLBACK.to_string());
    AdminSessionPayload {
        subject: session.subject,
        username: session.username,
        display_name: session.display_name,
        roles: session.roles,
        issued_at,
        expires_at,
    }
}
// Unit tests for the pure helpers in this module: debug-path validation, debug base-URL
// construction, preview truncation, schema URL construction and SQL count parsing.
#[cfg(test)]
mod tests {
    use super::{
        build_body_preview, build_debug_base_url, build_spacetime_schema_url, normalize_debug_path,
        parse_spacetime_sql_count_response, trim_preview,
    };
    use axum::{http::StatusCode, response::IntoResponse};
    use serde_json::json;

    // Absolute URLs are rejected with 400 so the debug console stays same-origin.
    #[test]
    fn normalize_debug_path_rejects_absolute_url() {
        let error =
            normalize_debug_path("https://example.com/api").expect_err("absolute url should fail");
        assert_eq!(error.into_response().status(), StatusCode::BAD_REQUEST);
    }

    // The admin login route cannot be replayed through the debug console.
    #[test]
    fn normalize_debug_path_rejects_admin_login_route() {
        let error =
            normalize_debug_path("/admin/api/login").expect_err("admin login route should fail");
        assert_eq!(error.into_response().status(), StatusCode::BAD_REQUEST);
    }

    // A plain relative path is returned unchanged.
    #[test]
    fn normalize_debug_path_accepts_healthz() {
        let path = normalize_debug_path("/healthz").expect("healthz path should pass validation");
        assert_eq!(path, "/healthz");
    }

    // The IPv4 wildcard bind address is dialled via IPv4 loopback.
    #[test]
    fn build_debug_base_url_rewrites_wildcard_ipv4_to_loopback() {
        let url = build_debug_base_url("0.0.0.0", 3200);
        assert_eq!(url, "http://127.0.0.1:3200");
    }

    // IPv6 literals are bracketed in the URL authority.
    #[test]
    fn build_debug_base_url_wraps_ipv6_host() {
        let url = build_debug_base_url("::1", 3200);
        assert_eq!(url, "http://[::1]:3200");
    }

    // Previews are capped at 4000 characters.
    #[test]
    fn trim_preview_limits_length() {
        let text = "a".repeat(5000);
        assert_eq!(trim_preview(&text).chars().count(), 4000);
    }

    // The schema URL must carry the explicit `version=9` query parameter.
    #[test]
    fn build_spacetime_schema_url_includes_required_version_query() {
        let url = build_spacetime_schema_url("http://127.0.0.1:3101", "xushi-p4wfr");
        assert_eq!(
            url,
            "http://127.0.0.1:3101/v1/database/xushi-p4wfr/schema?version=9"
        );
    }

    // Statement-array response shape: array rows whose count column is named via
    // `{"some": ...}`.
    #[test]
    fn parse_spacetime_sql_count_response_accepts_statement_array_rows() {
        let payload = json!([
            {
                "schema": {
                    "elements": [
                        {
                            "name": {
                                "some": "row_count"
                            },
                            "algebraic_type": {
                                "U64": []
                            }
                        }
                    ]
                },
                "rows": [[7]],
                "total_duration_micros": 116,
                "stats": {
                    "rows_inserted": 0,
                    "rows_deleted": 0,
                    "rows_updated": 0
                }
            }
        ]);
        let count =
            parse_spacetime_sql_count_response(payload).expect("statement array should parse");
        assert_eq!(count, 7);
    }

    // The count column is located through the schema rather than assumed to be column 0;
    // string counts are parsed as integers.
    #[test]
    fn parse_spacetime_sql_count_response_uses_schema_column_index() {
        let payload = json!([
            {
                "schema": {
                    "elements": [
                        {
                            "name": {
                                "some": "table_name"
                            }
                        },
                        {
                            "name": {
                                "some": "row_count"
                            }
                        }
                    ]
                },
                "rows": [["runtime_setting", "12"]]
            }
        ]);
        let count =
            parse_spacetime_sql_count_response(payload).expect("schema column index should parse");
        assert_eq!(count, 12);
    }

    // Legacy shape: a bare statement object whose rows are keyed objects.
    #[test]
    fn parse_spacetime_sql_count_response_keeps_object_row_compatibility() {
        let payload = json!({
            "rows": [
                {
                    "row_count": "3"
                }
            ]
        });
        let count = parse_spacetime_sql_count_response(payload).expect("object row should parse");
        assert_eq!(count, 3);
    }

    // Multi-byte UTF-8 bodies survive the preview unchanged.
    #[test]
    fn build_body_preview_handles_utf8() {
        let preview = build_body_preview("后台测试".as_bytes());
        assert_eq!(preview, "后台测试");
    }
}