Files
Genarrative/server-rs/crates/api-server/src/state.rs
kdletters cbc27bad4a
Some checks failed
CI / verify (push) Has been cancelled
init with react+axum+spacetimedb
2026-04-26 18:06:23 +08:00

770 lines
26 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::{error::Error, fmt, sync::Arc};
#[cfg(test)]
use std::{collections::HashMap, sync::Mutex};
use module_ai::{AiTaskService, InMemoryAiTaskStore};
use module_auth::{
AuthUserService, InMemoryAuthStore, PasswordEntryService, PhoneAuthService,
RefreshSessionService, WechatAuthService, WechatAuthStateService,
};
use module_runtime::RuntimeSnapshotRecord;
#[cfg(test)]
use module_runtime::{SAVE_SNAPSHOT_VERSION, format_utc_micros};
use platform_auth::{
AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus, JwtConfig, JwtError,
RefreshCookieConfig, RefreshCookieError, RefreshCookieSameSite, SmsAuthConfig, SmsAuthProvider,
SmsAuthProviderKind, SmsProviderError, sign_access_token, verify_access_token,
};
use platform_llm::{LlmClient, LlmConfig, LlmError};
use platform_oss::{OssClient, OssConfig, OssError};
use serde_json::Value;
use spacetime_client::{SpacetimeClient, SpacetimeClientConfig, SpacetimeClientError};
use time::OffsetDateTime;
use tracing::{info, warn};
use crate::config::AppConfig;
use crate::wechat_provider::{WechatProvider, build_wechat_provider};
// Role name that `AdminRuntime::sign_token` writes into admin access tokens and that
// `AdminRuntime::validate_claims` requires to be present before accepting a token.
const ADMIN_ROLE: &str = "admin";
// Shared application state: the configuration plus the auth, storage, AI-task, SpacetimeDB
// and LLM handles assembled by `AppState::new_with_auth_store`.
// For now this is a minimal shared-state shell; configuration, clients and platform adapters
// are wired in step by step.
#[derive(Clone, Debug)]
pub struct AppState {
// The config will be consumed progressively as middleware, routes and platform adapters
// are hooked up.
#[allow(dead_code)]
pub config: AppConfig,
// JWT settings for regular access tokens (issuer, secret and access-token TTL from config).
auth_jwt_config: JwtConfig,
// `None` unless both an admin username and an admin password are configured.
admin_runtime: Option<AdminRuntime>,
refresh_cookie_config: RefreshCookieConfig,
// `None` when none of the OSS bucket/endpoint/key fields are configured.
oss_client: Option<OssClient>,
// Backing store shared (via clones) by all auth services below.
#[cfg_attr(test, allow(dead_code))]
auth_store: InMemoryAuthStore,
password_entry_service: PasswordEntryService,
refresh_session_service: RefreshSessionService,
auth_user_service: AuthUserService,
phone_auth_service: PhoneAuthService,
wechat_auth_state_service: WechatAuthStateService,
wechat_auth_service: WechatAuthService,
wechat_provider: WechatProvider,
#[cfg_attr(not(test), allow(dead_code))]
ai_task_service: AiTaskService,
spacetime_client: SpacetimeClient,
// `None` when the LLM API key is missing or blank.
llm_client: Option<LlmClient>,
#[cfg(test)]
// In tests, when SpacetimeDB is not running, this in-memory snapshot map is the fallback
// that keeps the current runtime story regression chain working.
test_runtime_snapshot_store: Arc<Mutex<HashMap<String, RuntimeSnapshotRecord>>>,
}
// 后台管理员运行态独立于普通玩家登录体系,只从环境变量构造。
#[derive(Clone, Debug)]
pub struct AdminRuntime {
username: Arc<str>,
password: Arc<str>,
subject: Arc<str>,
display_name: Arc<str>,
token_ttl_seconds: u64,
jwt_config: JwtConfig,
}
// Claims for an admin token that is about to be signed; produced by
// `AdminRuntime::build_claims` and consumed by `sign_token` / `build_session`.
#[derive(Clone, Debug)]
pub struct AdminClaims {
// Token subject; `build_claims` copies it from the runtime's `subject`.
pub subject: String,
pub username: String,
pub issued_at: OffsetDateTime,
// `issued_at` plus the configured admin token TTL.
pub expires_at: OffsetDateTime,
}
// Admin session view returned by `AdminRuntime::validate_claims` (from a verified token)
// and `AdminRuntime::build_session` (from freshly built claims).
#[derive(Clone, Debug)]
pub struct AdminSession {
pub subject: String,
pub username: String,
pub display_name: String,
// Roles carried by the token; `build_session` always sets exactly the admin role.
pub roles: Vec<String>,
pub issued_at: OffsetDateTime,
pub expires_at: OffsetDateTime,
}
// Everything that can fail while assembling `AppState`. Each variant wraps the error of one
// construction step; `Display` forwards to the wrapped value.
#[derive(Debug)]
pub enum AppStateInitError {
Jwt(JwtError),
RefreshCookie(RefreshCookieError),
// Message returned when the auth store cannot be loaded from its persistence path or
// parsed from a snapshot.
AuthStore(String),
SmsProvider(SmsProviderError),
Oss(OssError),
Llm(LlmError),
}
impl AppState {
/// Builds the application state from `config`.
///
/// Test builds start from an empty default auth store; all other builds load the store from
/// `config.auth_store_path`.
///
/// # Errors
/// Returns `AppStateInitError::AuthStore` when the persisted store cannot be loaded, or any
/// error raised by `new_with_auth_store`.
pub fn new(config: AppConfig) -> Result<Self, AppStateInitError> {
    #[cfg(not(test))]
    let store = match InMemoryAuthStore::from_persistence_path(config.auth_store_path.clone()) {
        Ok(loaded) => loaded,
        Err(reason) => return Err(AppStateInitError::AuthStore(reason)),
    };
    #[cfg(test)]
    let store = InMemoryAuthStore::default();

    Self::new_with_auth_store(config, store)
}
/// Assembles the full application state around an already constructed auth store.
///
/// Steps run in this order and the first failure is returned: JWT config, admin runtime,
/// refresh cookie config, OSS client, SMS provider, auth services, AI task service,
/// SpacetimeDB client, LLM client.
fn new_with_auth_store(
    config: AppConfig,
    auth_store: InMemoryAuthStore,
) -> Result<Self, AppStateInitError> {
    let auth_jwt_config = JwtConfig::new(
        config.jwt_issuer.clone(),
        config.jwt_secret.clone(),
        config.jwt_access_token_ttl_seconds,
    )?;
    let admin_runtime = build_admin_runtime(&config, &auth_jwt_config)?;

    // Refresh cookie: reject an unknown SameSite value before building the cookie config.
    let Some(refresh_cookie_same_site) =
        RefreshCookieSameSite::parse(&config.refresh_cookie_same_site)
    else {
        return Err(RefreshCookieError::InvalidConfig("refresh cookie SameSite 取值非法").into());
    };
    let refresh_cookie_config = RefreshCookieConfig::new(
        config.refresh_cookie_name.clone(),
        config.refresh_cookie_path.clone(),
        config.refresh_cookie_secure,
        refresh_cookie_same_site,
        config.refresh_session_ttl_days,
    )?;

    let oss_client = build_oss_client(&config)?;

    // SMS provider: resolve the provider kind, then the config, then the provider itself.
    let Some(sms_provider_kind) = SmsAuthProviderKind::parse(&config.sms_auth_provider) else {
        return Err(SmsProviderError::InvalidConfig("短信 provider 配置非法".to_string()).into());
    };
    let sms_config = SmsAuthConfig::new(
        sms_provider_kind,
        config.sms_endpoint.clone(),
        config.sms_access_key_id.clone(),
        config.sms_access_key_secret.clone(),
        config.sms_sign_name.clone(),
        config.sms_template_code.clone(),
        config.sms_template_param_key.clone(),
        config.sms_country_code.clone(),
        config.sms_scheme_name.clone(),
        config.sms_code_length,
        config.sms_code_type,
        config.sms_valid_time_seconds,
        config.sms_interval_seconds,
        config.sms_duplicate_policy,
        config.sms_case_auth_policy,
        config.sms_return_verify_code,
        config.sms_mock_verify_code.clone(),
    )?;
    let sms_provider = SmsAuthProvider::new(sms_config)?;

    // Every auth service works on its own clone of the same auth store.
    let password_entry_service = PasswordEntryService::new(auth_store.clone());
    let auth_user_service = AuthUserService::new(auth_store.clone());
    let phone_auth_service = PhoneAuthService::new(auth_store.clone(), sms_provider);
    let wechat_auth_state_service =
        WechatAuthStateService::new(auth_store.clone(), config.wechat_state_ttl_minutes);
    let wechat_auth_service = WechatAuthService::new(auth_store.clone());
    let wechat_provider = build_wechat_provider(&config);
    let refresh_session_service =
        RefreshSessionService::new(auth_store.clone(), config.refresh_session_ttl_days);

    // The AI orchestration service is backed by an in-memory store for now; it is meant to
    // move to SpacetimeDB (task table / procedure) as the source of truth later.
    let ai_task_service = AiTaskService::new(InMemoryAiTaskStore::default());

    let spacetime_config = SpacetimeClientConfig {
        server_url: config.spacetime_server_url.clone(),
        database: config.spacetime_database.clone(),
        token: config.spacetime_token.clone(),
        pool_size: config.spacetime_pool_size,
    };
    let spacetime_client = SpacetimeClient::new(spacetime_config);

    let llm_client = build_llm_client(&config)?;

    let state = Self {
        config,
        auth_jwt_config,
        admin_runtime,
        refresh_cookie_config,
        oss_client,
        auth_store,
        password_entry_service,
        refresh_session_service,
        auth_user_service,
        phone_auth_service,
        wechat_auth_state_service,
        wechat_auth_service,
        wechat_provider,
        ai_task_service,
        spacetime_client,
        llm_client,
        #[cfg(test)]
        test_runtime_snapshot_store: Arc::new(Mutex::new(HashMap::new())),
    };
    Ok(state)
}
// JWT settings used for regular access tokens.
pub fn auth_jwt_config(&self) -> &JwtConfig {
&self.auth_jwt_config
}
// Admin runtime, or `None` when no admin username/password pair is configured.
pub fn admin_runtime(&self) -> Option<&AdminRuntime> {
self.admin_runtime.as_ref()
}
// Refresh cookie settings built from the `refresh_cookie_*` config fields.
pub fn refresh_cookie_config(&self) -> &RefreshCookieConfig {
&self.refresh_cookie_config
}
// OSS client, or `None` when no OSS field is configured.
pub fn oss_client(&self) -> Option<&OssClient> {
self.oss_client.as_ref()
}
// Password entry service backed by the shared auth store.
pub fn password_entry_service(&self) -> &PasswordEntryService {
&self.password_entry_service
}
// Exports the auth store as a JSON snapshot and pushes it to SpacetimeDB.
//
// In test builds this is a no-op that returns `Ok(())`; every statement after the early
// return is compiled only for non-test builds.
//
// Errors: a snapshot export failure or an out-of-range timestamp is reported as
// `SpacetimeClientError::Runtime`; errors from the two SpacetimeDB calls are propagated.
pub async fn sync_auth_store_snapshot_to_spacetime(&self) -> Result<(), SpacetimeClientError> {
#[cfg(test)]
return Ok(());
#[cfg(not(test))]
let snapshot_json = self
.auth_store
.export_snapshot_json()
.map_err(SpacetimeClientError::Runtime)?;
// Current UTC time in microseconds (nanoseconds / 1000), checked to fit into i64.
#[cfg(not(test))]
let updated_at_micros = i64::try_from(
OffsetDateTime::now_utc().unix_timestamp_nanos() / 1_000,
)
.map_err(|_| SpacetimeClientError::Runtime("认证快照更新时间超出 i64 范围".to_string()))?;
#[cfg(not(test))]
self.spacetime_client
.upsert_auth_store_snapshot(snapshot_json, updated_at_micros)
.await?;
// NOTE(review): the original comment on this step was lost to an encoding error.
// Presumably this asks SpacetimeDB to import the snapshot that was just upserted into its
// tables — confirm against `SpacetimeClient::import_auth_store_snapshot`.
#[cfg(not(test))]
self.spacetime_client.import_auth_store_snapshot().await?;
#[cfg(not(test))]
Ok(())
}
pub async fn try_restore_auth_store_from_spacetime(
config: AppConfig,
) -> Result<Self, AppStateInitError> {
let spacetime_client = SpacetimeClient::new(SpacetimeClientConfig {
server_url: config.spacetime_server_url.clone(),
database: config.spacetime_database.clone(),
token: config.spacetime_token.clone(),
pool_size: config.spacetime_pool_size,
});
match spacetime_client
.export_auth_store_snapshot_from_tables()
.await
{
Ok(snapshot) => {
if let Some(snapshot_json) = snapshot.snapshot_json {
if !snapshot_json.trim().is_empty() {
let auth_store = InMemoryAuthStore::from_snapshot_json(&snapshot_json)
.map_err(AppStateInitError::AuthStore)?;
info!("?? SpacetimeDB ???????????");
return Self::new_with_auth_store(config, auth_store);
}
}
}
Err(error) => {
warn!(error = %error, "? SpacetimeDB ????????????????");
}
}
match spacetime_client.get_auth_store_snapshot().await {
Ok(snapshot) => {
if let Some(snapshot_json) = snapshot.snapshot_json {
if !snapshot_json.trim().is_empty() {
let auth_store = InMemoryAuthStore::from_snapshot_json(&snapshot_json)
.map_err(AppStateInitError::AuthStore)?;
info!("?? SpacetimeDB ???????????");
return Self::new_with_auth_store(config, auth_store);
}
}
}
Err(error) => {
warn!(error = %error, "? SpacetimeDB ?????????????????");
}
}
Self::new(config)
}
// Refresh session service backed by the shared auth store.
pub fn refresh_session_service(&self) -> &RefreshSessionService {
&self.refresh_session_service
}
// Auth user service backed by the shared auth store.
pub fn auth_user_service(&self) -> &AuthUserService {
&self.auth_user_service
}
// Phone auth service (shared auth store plus the configured SMS provider).
pub fn phone_auth_service(&self) -> &PhoneAuthService {
&self.phone_auth_service
}
// WeChat auth-state service backed by the shared auth store.
pub fn wechat_auth_state_service(&self) -> &WechatAuthStateService {
&self.wechat_auth_state_service
}
// WeChat auth service backed by the shared auth store.
pub fn wechat_auth_service(&self) -> &WechatAuthService {
&self.wechat_auth_service
}
// WeChat provider built from the configuration.
pub fn wechat_provider(&self) -> &WechatProvider {
&self.wechat_provider
}
// AI task service; currently backed by an in-memory task store.
#[cfg_attr(not(test), allow(dead_code))]
pub fn ai_task_service(&self) -> &AiTaskService {
&self.ai_task_service
}
// SpacetimeDB client built from the `spacetime_*` config fields.
pub fn spacetime_client(&self) -> &SpacetimeClient {
&self.spacetime_client
}
// LLM client, or `None` when the LLM API key is missing or blank.
pub fn llm_client(&self) -> Option<&LlmClient> {
self.llm_client.as_ref()
}
/// Loads the runtime snapshot stored for `user_id`.
///
/// Non-test builds simply forward the SpacetimeDB result. Test builds also mirror every
/// snapshot that was found into the in-memory test store and, when the SpacetimeDB call
/// fails, answer from that store instead of returning the error.
pub async fn get_runtime_snapshot_record(
    &self,
    user_id: String,
) -> Result<Option<RuntimeSnapshotRecord>, SpacetimeClientError> {
    let lookup = self
        .spacetime_client
        .get_runtime_snapshot(user_id.clone())
        .await;

    match lookup {
        Ok(found) => {
            #[cfg(test)]
            if let Some(snapshot) = &found {
                self.cache_test_runtime_snapshot(snapshot.clone());
            }
            Ok(found)
        }
        #[cfg(test)]
        Err(_) => Ok(self.read_test_runtime_snapshot(&user_id)),
        #[cfg(not(test))]
        Err(failure) => Err(failure),
    }
}
/// Stores a runtime snapshot for `user_id` and returns the stored record.
///
/// Non-test builds forward the SpacetimeDB result. Test builds mirror a successful write
/// into the in-memory test store and, when the SpacetimeDB call fails, build the record
/// locally, cache it and return it.
pub async fn put_runtime_snapshot_record(
    &self,
    user_id: String,
    saved_at_micros: i64,
    bottom_tab: String,
    game_state: Value,
    current_story: Option<Value>,
    updated_at_micros: i64,
) -> Result<RuntimeSnapshotRecord, SpacetimeClientError> {
    // The arguments are cloned because the test-only fallback arm still needs the originals.
    let write_outcome = self
        .spacetime_client
        .put_runtime_snapshot(
            user_id.clone(),
            saved_at_micros,
            bottom_tab.clone(),
            game_state.clone(),
            current_story.clone(),
            updated_at_micros,
        )
        .await;

    match write_outcome {
        Ok(stored) => {
            #[cfg(test)]
            self.cache_test_runtime_snapshot(stored.clone());
            Ok(stored)
        }
        #[cfg(test)]
        Err(_) => {
            let fallback = self.build_test_runtime_snapshot_record(
                user_id,
                saved_at_micros,
                bottom_tab,
                game_state,
                current_story,
                updated_at_micros,
            )?;
            self.cache_test_runtime_snapshot(fallback.clone());
            Ok(fallback)
        }
        #[cfg(not(test))]
        Err(failure) => Err(failure),
    }
}
/// Deletes the runtime snapshot of `user_id`; the boolean is the SpacetimeDB delete result.
///
/// Test builds keep the in-memory test store in sync: a successful delete also removes the
/// cached entry, and a failed SpacetimeDB call is answered by removing the cached entry and
/// reporting whether one existed.
pub async fn delete_runtime_snapshot_record(
    &self,
    user_id: String,
) -> Result<bool, SpacetimeClientError> {
    let removal = self
        .spacetime_client
        .delete_runtime_snapshot(user_id.clone())
        .await;

    match removal {
        Ok(was_deleted) => {
            #[cfg(test)]
            if was_deleted {
                self.remove_test_runtime_snapshot(&user_id);
            }
            Ok(was_deleted)
        }
        #[cfg(test)]
        Err(_) => {
            let evicted = self.remove_test_runtime_snapshot(&user_id);
            Ok(evicted.is_some())
        }
        #[cfg(not(test))]
        Err(failure) => Err(failure),
    }
}
}
#[cfg(test)]
impl AppState {
/// Test helper: creates a user through the phone login flow and then sets a password on it.
///
/// Sends a login code for `phone_number`, logs in one second later with the code `123456`,
/// and finally sets `password` without supplying a current password. Panics if any step
/// fails. Returns the user as reported by the password change.
pub(crate) async fn seed_test_phone_user_with_password(
    &self,
    phone_number: &str,
    password: &str,
) -> module_auth::AuthUser {
    let sent_at = OffsetDateTime::now_utc();
    let login_at = sent_at + time::Duration::seconds(1);

    let send_request = module_auth::SendPhoneCodeInput {
        phone_number: phone_number.to_string(),
        scene: module_auth::PhoneAuthScene::Login,
    };
    self.phone_auth_service()
        .send_code(send_request, sent_at)
        .await
        .expect("test phone code should send");

    let login_request = module_auth::PhoneLoginInput {
        phone_number: phone_number.to_string(),
        verify_code: "123456".to_string(),
    };
    let seeded_user = self
        .phone_auth_service()
        .login(login_request, login_at)
        .await
        .expect("test phone login should create user")
        .user;

    let password_request = module_auth::ChangePasswordInput {
        user_id: seeded_user.id.clone(),
        current_password: None,
        new_password: password.to_string(),
    };
    let password_change = self
        .password_entry_service()
        .change_password(password_request)
        .await
        .expect("test password should set");
    password_change.user
}
/// Inserts `record` into the test snapshot store under its `user_id`, replacing any entry
/// already stored for that user. Panics if the store mutex is poisoned.
fn cache_test_runtime_snapshot(&self, record: RuntimeSnapshotRecord) {
    let mut store = self
        .test_runtime_snapshot_store
        .lock()
        .expect("test runtime snapshot store should lock");
    let key = record.user_id.clone();
    store.insert(key, record);
}
/// Returns a clone of the snapshot cached for `user_id` in the test store, if any.
/// Panics if the store mutex is poisoned.
fn read_test_runtime_snapshot(&self, user_id: &str) -> Option<RuntimeSnapshotRecord> {
    let store = self
        .test_runtime_snapshot_store
        .lock()
        .expect("test runtime snapshot store should lock");
    let cached = store.get(user_id).cloned();
    cached
}
/// Removes the snapshot cached for `user_id` from the test store and returns it, if present.
/// Panics if the store mutex is poisoned.
fn remove_test_runtime_snapshot(&self, user_id: &str) -> Option<RuntimeSnapshotRecord> {
    let mut store = self
        .test_runtime_snapshot_store
        .lock()
        .expect("test runtime snapshot store should lock");
    let evicted = store.remove(user_id);
    evicted
}
/// Builds the snapshot record used by the test fallback, without touching SpacetimeDB.
///
/// `game_state` and `current_story` are additionally serialized to JSON strings. The
/// creation timestamp is taken from the record already cached for the user, or from
/// `updated_at_micros` when nothing is cached yet.
///
/// # Errors
/// Returns `SpacetimeClientError::Runtime` when either JSON serialization fails.
fn build_test_runtime_snapshot_record(
    &self,
    user_id: String,
    saved_at_micros: i64,
    bottom_tab: String,
    game_state: Value,
    current_story: Option<Value>,
    updated_at_micros: i64,
) -> Result<RuntimeSnapshotRecord, SpacetimeClientError> {
    let created_at_micros = match self.read_test_runtime_snapshot(&user_id) {
        Some(existing) => existing.created_at_micros,
        None => updated_at_micros,
    };

    let game_state_json = match serde_json::to_string(&game_state) {
        Ok(json) => json,
        Err(error) => {
            return Err(SpacetimeClientError::Runtime(format!(
                "测试快照 game_state 序列化失败: {error}"
            )));
        }
    };

    let current_story_json = match &current_story {
        None => None,
        Some(story) => match serde_json::to_string(story) {
            Ok(json) => Some(json),
            Err(error) => {
                return Err(SpacetimeClientError::Runtime(format!(
                    "测试快照 current_story 序列化失败: {error}"
                )));
            }
        },
    };

    Ok(RuntimeSnapshotRecord {
        user_id,
        version: SAVE_SNAPSHOT_VERSION,
        saved_at: format_utc_micros(saved_at_micros),
        saved_at_micros,
        bottom_tab,
        game_state,
        current_story,
        game_state_json,
        current_story_json,
        created_at_micros,
        updated_at_micros,
    })
}
}
/// Displays an initialization error exactly as its wrapped cause displays itself.
impl fmt::Display for AppStateInitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the wrapped value once, then format it in a single place.
        let cause: &dyn fmt::Display = match self {
            Self::Jwt(inner) => inner,
            Self::RefreshCookie(inner) => inner,
            Self::AuthStore(inner) => inner,
            Self::SmsProvider(inner) => inner,
            Self::Oss(inner) => inner,
            Self::Llm(inner) => inner,
        };
        write!(f, "{cause}")
    }
}
impl Error for AppStateInitError {}
impl From<JwtError> for AppStateInitError {
fn from(value: JwtError) -> Self {
Self::Jwt(value)
}
}
impl From<RefreshCookieError> for AppStateInitError {
fn from(value: RefreshCookieError) -> Self {
Self::RefreshCookie(value)
}
}
impl From<SmsProviderError> for AppStateInitError {
fn from(value: SmsProviderError) -> Self {
Self::SmsProvider(value)
}
}
impl From<OssError> for AppStateInitError {
fn from(value: OssError) -> Self {
Self::Oss(value)
}
}
impl From<LlmError> for AppStateInitError {
fn from(value: LlmError) -> Self {
Self::Llm(value)
}
}
impl AdminRuntime {
/// Returns `true` when neither the username nor the password is blank after trimming.
pub fn is_enabled(&self) -> bool {
    let username_blank = self.username.trim().is_empty();
    let password_blank = self.password.trim().is_empty();
    !(username_blank || password_blank)
}

/// Configured administrator username.
pub fn username(&self) -> &str {
    self.username.as_ref()
}

/// Configured administrator password, in plaintext.
pub fn password(&self) -> &str {
    self.password.as_ref()
}
/// Builds the claims for an admin token issued at `now`.
///
/// The expiry is `now` plus the configured token TTL.
///
/// # Errors
/// Returns a message when the TTL does not fit into `i64` or when adding it to `now`
/// overflows.
pub fn build_claims(&self, now: OffsetDateTime) -> Result<AdminClaims, String> {
    let ttl_seconds = match i64::try_from(self.token_ttl_seconds) {
        Ok(seconds) => seconds,
        Err(_) => return Err("后台 token TTL 超出 i64 上限".to_string()),
    };
    let Some(expires_at) = now.checked_add(time::Duration::seconds(ttl_seconds)) else {
        return Err("后台 token 过期时间计算溢出".to_string());
    };

    Ok(AdminClaims {
        subject: self.subject.to_string(),
        username: self.username.to_string(),
        issued_at: now,
        expires_at,
    })
}
/// Signs an admin access token for `claims` with the admin JWT configuration.
///
/// The token carries the admin role only, uses the password provider, and has a session id
/// derived from the username.
///
/// # Errors
/// Returns the stringified error when building the JWT claims or signing fails.
pub fn sign_token(&self, claims: &AdminClaims) -> Result<String, String> {
    let claims_input = AccessTokenClaimsInput {
        user_id: claims.subject.clone(),
        session_id: format!("admin-session-{}", claims.username),
        provider: AuthProvider::Password,
        roles: vec![ADMIN_ROLE.to_string()],
        token_version: 1,
        phone_verified: false,
        binding_status: BindingStatus::Active,
        display_name: Some(self.display_name.to_string()),
    };

    let jwt_claims =
        match AccessTokenClaims::from_input(claims_input, &self.jwt_config, claims.issued_at) {
            Ok(built) => built,
            Err(error) => return Err(error.to_string()),
        };

    match sign_access_token(&jwt_claims, &self.jwt_config) {
        Ok(token) => Ok(token),
        Err(error) => Err(error.to_string()),
    }
}
/// Verifies `token` against the admin JWT configuration and returns its claims.
///
/// # Errors
/// Returns the stringified verification error.
pub fn verify_token(&self, token: &str) -> Result<AccessTokenClaims, String> {
    match verify_access_token(token, &self.jwt_config) {
        Ok(verified) => Ok(verified),
        Err(error) => Err(error.to_string()),
    }
}
/// Checks that verified token claims belong to this administrator and turns them into a
/// session.
///
/// The claims must name this runtime's subject and contain the admin role. The `iat` and
/// `exp` values are converted with a checked integer conversion, so an out-of-range value is
/// rejected instead of wrapping into a different, valid-looking timestamp.
///
/// # Errors
/// Returns a message when the subject does not match, the admin role is missing, or either
/// timestamp is not representable.
pub fn validate_claims(&self, claims: &AccessTokenClaims) -> Result<AdminSession, String> {
    if claims.user_id() != self.subject.as_ref() {
        return Err("后台管理员主体不匹配".to_string());
    }
    if !claims.roles.iter().any(|role| role == ADMIN_ROLE) {
        return Err("当前令牌不是管理员令牌".to_string());
    }

    let issued_at = i64::try_from(claims.iat)
        .ok()
        .and_then(|seconds| OffsetDateTime::from_unix_timestamp(seconds).ok())
        .ok_or_else(|| "后台令牌签发时间无效".to_string())?;
    let expires_at = i64::try_from(claims.exp)
        .ok()
        .and_then(|seconds| OffsetDateTime::from_unix_timestamp(seconds).ok())
        .ok_or_else(|| "后台令牌过期时间无效".to_string())?;

    Ok(AdminSession {
        subject: claims.user_id().to_string(),
        username: self.username.to_string(),
        display_name: self.display_name.to_string(),
        roles: claims.roles.clone(),
        issued_at,
        expires_at,
    })
}
pub fn build_session(&self, claims: &AdminClaims) -> AdminSession {
AdminSession {
subject: claims.subject.clone(),
username: claims.username.clone(),
display_name: self.display_name.to_string(),
roles: vec![ADMIN_ROLE.to_string()],
issued_at: claims.issued_at,
expires_at: claims.expires_at,
}
}
}
/// Builds the OSS client from the configuration.
///
/// Returns `Ok(None)` when none of bucket, endpoint, access key id and access key secret is
/// set. As soon as one of them is set, the missing ones are passed as empty defaults and
/// `OssConfig::new` decides whether the combination is acceptable.
fn build_oss_client(config: &AppConfig) -> Result<Option<OssClient>, AppStateInitError> {
    let oss_unconfigured = config.oss_bucket.is_none()
        && config.oss_endpoint.is_none()
        && config.oss_access_key_id.is_none()
        && config.oss_access_key_secret.is_none();
    if oss_unconfigured {
        return Ok(None);
    }

    let bucket = config.oss_bucket.clone().unwrap_or_default();
    let endpoint = config.oss_endpoint.clone().unwrap_or_default();
    let access_key_id = config.oss_access_key_id.clone().unwrap_or_default();
    let access_key_secret = config.oss_access_key_secret.clone().unwrap_or_default();

    let oss_config = OssConfig::new(
        bucket,
        endpoint,
        access_key_id,
        access_key_secret,
        config.oss_read_expire_seconds,
        config.oss_post_expire_seconds,
        config.oss_post_max_size_bytes,
        config.oss_success_action_status,
    )?;
    let client = OssClient::new(oss_config);
    Ok(Some(client))
}
/// Builds the LLM client from the configuration.
///
/// Returns `Ok(None)` when the API key is missing or blank; otherwise the trimmed key is used
/// and any error from `LlmConfig::new` or `LlmClient::new` is propagated.
fn build_llm_client(config: &AppConfig) -> Result<Option<LlmClient>, AppStateInitError> {
    let api_key = match config.llm_api_key.as_deref().map(str::trim) {
        Some(key) if !key.is_empty() => key,
        _ => return Ok(None),
    };

    let llm_config = LlmConfig::new(
        config.llm_provider,
        config.llm_base_url.clone(),
        api_key.to_string(),
        config.llm_model.clone(),
        config.llm_request_timeout_ms,
        config.llm_max_retries,
        config.llm_retry_backoff_ms,
    )?;
    let client = LlmClient::new(llm_config)?;
    Ok(Some(client))
}
/// Builds the administrator runtime from the configuration.
///
/// The back office is enabled only when both the username and the password are configured
/// (non-blank after trimming), so a half-configured setup never exposes a pseudo entry point;
/// in that case `Ok(None)` is returned. The admin JWT configuration reuses the issuer of
/// `base_jwt_config` and the configured JWT secret, with the admin token TTL.
fn build_admin_runtime(
    config: &AppConfig,
    base_jwt_config: &JwtConfig,
) -> Result<Option<AdminRuntime>, AppStateInitError> {
    let configured_username = config
        .admin_username
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty());
    let configured_password = config
        .admin_password
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty());
    let (Some(username), Some(password)) = (configured_username, configured_password) else {
        return Ok(None);
    };

    let jwt_config = JwtConfig::new(
        base_jwt_config.issuer().to_string(),
        config.jwt_secret.clone(),
        config.admin_token_ttl_seconds,
    )?;

    let runtime = AdminRuntime {
        username: Arc::from(username),
        password: Arc::from(password),
        subject: Arc::from(format!("admin:{username}")),
        display_name: Arc::from(format!("管理员 {username}")),
        token_ttl_seconds: config.admin_token_ttl_seconds,
        jwt_config,
    };
    Ok(Some(runtime))
}
#[cfg(test)]
mod tests {
    use module_ai::{AiTaskKind, generate_ai_task_id};

    use super::*;

    // Both tests start from a state built out of the default configuration.
    fn default_state() -> AppState {
        AppState::new(AppConfig::default()).expect("state should build")
    }

    // The AI task service exposed by the state must accept a new story-generation task.
    #[test]
    fn app_state_exposes_usable_ai_task_service() {
        let state = default_state();
        let task_id = generate_ai_task_id(1_713_680_000_000_000);
        let create_input = module_ai::AiTaskCreateInput {
            task_id: task_id.clone(),
            task_kind: AiTaskKind::StoryGeneration,
            owner_user_id: "user_001".to_string(),
            request_label: "营地开场".to_string(),
            source_module: "story".to_string(),
            source_entity_id: Some("storysess_001".to_string()),
            request_payload_json: Some("{\"scene\":\"camp\"}".to_string()),
            stages: AiTaskKind::StoryGeneration.default_stage_blueprints(),
            created_at_micros: 1_713_680_000_000_000,
        };

        let task = state
            .ai_task_service()
            .create_task(create_input)
            .expect("ai task should create");

        assert_eq!(task.task_id, task_id);
        assert_eq!(task.task_kind, AiTaskKind::StoryGeneration);
        assert_eq!(task.stages.len(), 4);
    }

    // Without an API key in the configuration no LLM client is built.
    #[test]
    fn app_state_skips_llm_client_when_api_key_missing() {
        let state = default_state();
        assert!(state.llm_client().is_none());
    }
}