481 lines
16 KiB
Rust
481 lines
16 KiB
Rust
use std::{error::Error, fmt};
|
||
|
||
#[cfg(test)]
|
||
use std::{
|
||
collections::HashMap,
|
||
sync::{Arc, Mutex},
|
||
};
|
||
|
||
use module_ai::{AiTaskService, InMemoryAiTaskStore};
|
||
use module_auth::{
|
||
AuthUserService, InMemoryAuthStore, PasswordEntryService, PhoneAuthService,
|
||
RefreshSessionService, WechatAuthService, WechatAuthStateService,
|
||
};
|
||
use module_runtime::RuntimeSnapshotRecord;
|
||
#[cfg(test)]
|
||
use module_runtime::{SAVE_SNAPSHOT_VERSION, format_utc_micros};
|
||
use platform_auth::{
|
||
JwtConfig, JwtError, RefreshCookieConfig, RefreshCookieError, RefreshCookieSameSite,
|
||
SmsAuthConfig, SmsAuthProvider, SmsAuthProviderKind, SmsProviderError,
|
||
};
|
||
use platform_llm::{LlmClient, LlmConfig, LlmError};
|
||
use platform_oss::{OssClient, OssConfig, OssError};
|
||
use serde_json::Value;
|
||
use spacetime_client::{SpacetimeClient, SpacetimeClientConfig, SpacetimeClientError};
|
||
|
||
use crate::config::AppConfig;
|
||
use crate::wechat_provider::{WechatProvider, build_wechat_provider};
|
||
|
||
/// Shared application state: the raw configuration plus the auth, AI-task, object-storage,
/// SpacetimeDB and LLM services that `AppState::new` builds once and the accessor methods
/// hand out by reference.
///
/// Originally introduced as a minimal shared-state shell, with configuration, clients and
/// platform adapters to be wired in step by step.
#[derive(Clone, Debug)]
pub struct AppState {
    /// Raw application configuration. It is meant to be consumed gradually as middleware,
    /// routes and platform adapters are wired in, hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub config: AppConfig,
    /// JWT settings built from `config.jwt_issuer` / `jwt_secret` / `jwt_access_token_ttl_seconds`.
    auth_jwt_config: JwtConfig,
    /// Refresh-cookie settings (name, path, secure flag, SameSite, session TTL in days).
    refresh_cookie_config: RefreshCookieConfig,
    /// OSS client; `None` when no OSS field is configured at all.
    oss_client: Option<OssClient>,
    /// Password login / registration service backed by the shared in-memory auth store.
    password_entry_service: PasswordEntryService,
    /// Refresh-session service backed by the shared in-memory auth store.
    refresh_session_service: RefreshSessionService,
    /// User lookup service backed by the shared in-memory auth store.
    auth_user_service: AuthUserService,
    /// Phone (SMS code) auth service; owns the configured SMS provider.
    phone_auth_service: PhoneAuthService,
    /// WeChat OAuth state service, with TTL taken from `config.wechat_state_ttl_minutes`.
    wechat_auth_state_service: WechatAuthStateService,
    /// WeChat auth service backed by the shared in-memory auth store.
    wechat_auth_service: WechatAuthService,
    /// WeChat provider produced by `build_wechat_provider(&config)`.
    wechat_provider: WechatProvider,
    /// AI task orchestration service (currently on an in-memory task store). The attribute
    /// suppresses dead-code warnings outside test builds.
    #[cfg_attr(not(test), allow(dead_code))]
    ai_task_service: AiTaskService,
    /// Client for the SpacetimeDB server / database named in the config.
    spacetime_client: SpacetimeClient,
    /// LLM client; `None` when no non-blank API key is configured.
    llm_client: Option<LlmClient>,
    /// Test builds only: in-memory runtime snapshots, used as a fallback so the runtime story
    /// regression flow still works when SpacetimeDB is not running. Wrapped in
    /// `Arc<Mutex<..>>`, so every clone of `AppState` shares the same store.
    #[cfg(test)]
    test_runtime_snapshot_store: Arc<Mutex<HashMap<String, RuntimeSnapshotRecord>>>,
}
|
||
|
||
/// Errors that can abort [`AppState::new`]; each variant wraps the error of one
/// configuration step.
#[derive(Debug)]
pub enum AppStateInitError {
    /// `JwtConfig::new` rejected the JWT settings.
    Jwt(JwtError),
    /// Refresh-cookie settings were rejected, or the `SameSite` value did not parse.
    RefreshCookie(RefreshCookieError),
    /// SMS provider settings were rejected, or the provider kind did not parse.
    SmsProvider(SmsProviderError),
    /// `OssConfig::new` rejected the OSS settings.
    Oss(OssError),
    /// LLM configuration or client construction failed.
    Llm(LlmError),
}
|
||
|
||
impl AppState {
|
||
pub fn new(config: AppConfig) -> Result<Self, AppStateInitError> {
|
||
let auth_jwt_config = JwtConfig::new(
|
||
config.jwt_issuer.clone(),
|
||
config.jwt_secret.clone(),
|
||
config.jwt_access_token_ttl_seconds,
|
||
)?;
|
||
let refresh_cookie_same_site =
|
||
RefreshCookieSameSite::parse(&config.refresh_cookie_same_site).ok_or(
|
||
RefreshCookieError::InvalidConfig("refresh cookie SameSite 取值非法"),
|
||
)?;
|
||
let refresh_cookie_config = RefreshCookieConfig::new(
|
||
config.refresh_cookie_name.clone(),
|
||
config.refresh_cookie_path.clone(),
|
||
config.refresh_cookie_secure,
|
||
refresh_cookie_same_site,
|
||
config.refresh_session_ttl_days,
|
||
)?;
|
||
let oss_client = build_oss_client(&config)?;
|
||
let auth_store = InMemoryAuthStore::default();
|
||
let sms_provider = SmsAuthProvider::new(SmsAuthConfig::new(
|
||
SmsAuthProviderKind::parse(&config.sms_auth_provider).ok_or_else(|| {
|
||
SmsProviderError::InvalidConfig("短信 provider 配置非法".to_string())
|
||
})?,
|
||
config.sms_endpoint.clone(),
|
||
config.sms_access_key_id.clone(),
|
||
config.sms_access_key_secret.clone(),
|
||
config.sms_sign_name.clone(),
|
||
config.sms_template_code.clone(),
|
||
config.sms_template_param_key.clone(),
|
||
config.sms_country_code.clone(),
|
||
config.sms_scheme_name.clone(),
|
||
config.sms_code_length,
|
||
config.sms_code_type,
|
||
config.sms_valid_time_seconds,
|
||
config.sms_interval_seconds,
|
||
config.sms_duplicate_policy,
|
||
config.sms_case_auth_policy,
|
||
config.sms_return_verify_code,
|
||
config.sms_mock_verify_code.clone(),
|
||
)?)?;
|
||
let password_entry_service = PasswordEntryService::new(auth_store.clone());
|
||
let auth_user_service = AuthUserService::new(auth_store.clone());
|
||
let phone_auth_service = PhoneAuthService::new(auth_store.clone(), sms_provider);
|
||
let wechat_auth_state_service =
|
||
WechatAuthStateService::new(auth_store.clone(), config.wechat_state_ttl_minutes);
|
||
let wechat_auth_service = WechatAuthService::new(auth_store.clone());
|
||
let wechat_provider = build_wechat_provider(&config);
|
||
let refresh_session_service =
|
||
RefreshSessionService::new(auth_store, config.refresh_session_ttl_days);
|
||
// AI 编排服务当前先挂接内存态 store,后续再按 task table / procedure 接到 SpacetimeDB 真相源。
|
||
let ai_task_service = AiTaskService::new(InMemoryAiTaskStore::default());
|
||
let spacetime_client = SpacetimeClient::new(SpacetimeClientConfig {
|
||
server_url: config.spacetime_server_url.clone(),
|
||
database: config.spacetime_database.clone(),
|
||
token: config.spacetime_token.clone(),
|
||
});
|
||
let llm_client = build_llm_client(&config)?;
|
||
|
||
Ok(Self {
|
||
config,
|
||
auth_jwt_config,
|
||
refresh_cookie_config,
|
||
oss_client,
|
||
password_entry_service,
|
||
refresh_session_service,
|
||
auth_user_service,
|
||
phone_auth_service,
|
||
wechat_auth_state_service,
|
||
wechat_auth_service,
|
||
wechat_provider,
|
||
ai_task_service,
|
||
spacetime_client,
|
||
llm_client,
|
||
#[cfg(test)]
|
||
test_runtime_snapshot_store: Arc::new(Mutex::new(HashMap::new())),
|
||
})
|
||
}
|
||
|
||
pub fn auth_jwt_config(&self) -> &JwtConfig {
|
||
&self.auth_jwt_config
|
||
}
|
||
|
||
pub fn refresh_cookie_config(&self) -> &RefreshCookieConfig {
|
||
&self.refresh_cookie_config
|
||
}
|
||
|
||
pub fn oss_client(&self) -> Option<&OssClient> {
|
||
self.oss_client.as_ref()
|
||
}
|
||
|
||
pub fn password_entry_service(&self) -> &PasswordEntryService {
|
||
&self.password_entry_service
|
||
}
|
||
|
||
pub fn refresh_session_service(&self) -> &RefreshSessionService {
|
||
&self.refresh_session_service
|
||
}
|
||
|
||
pub fn auth_user_service(&self) -> &AuthUserService {
|
||
&self.auth_user_service
|
||
}
|
||
|
||
pub fn phone_auth_service(&self) -> &PhoneAuthService {
|
||
&self.phone_auth_service
|
||
}
|
||
|
||
pub fn wechat_auth_state_service(&self) -> &WechatAuthStateService {
|
||
&self.wechat_auth_state_service
|
||
}
|
||
|
||
pub fn wechat_auth_service(&self) -> &WechatAuthService {
|
||
&self.wechat_auth_service
|
||
}
|
||
|
||
pub fn wechat_provider(&self) -> &WechatProvider {
|
||
&self.wechat_provider
|
||
}
|
||
|
||
#[cfg_attr(not(test), allow(dead_code))]
|
||
pub fn ai_task_service(&self) -> &AiTaskService {
|
||
&self.ai_task_service
|
||
}
|
||
|
||
pub fn spacetime_client(&self) -> &SpacetimeClient {
|
||
&self.spacetime_client
|
||
}
|
||
|
||
pub fn llm_client(&self) -> Option<&LlmClient> {
|
||
self.llm_client.as_ref()
|
||
}
|
||
|
||
pub async fn get_runtime_snapshot_record(
|
||
&self,
|
||
user_id: String,
|
||
) -> Result<Option<RuntimeSnapshotRecord>, SpacetimeClientError> {
|
||
match self
|
||
.spacetime_client
|
||
.get_runtime_snapshot(user_id.clone())
|
||
.await
|
||
{
|
||
Ok(record) => {
|
||
#[cfg(test)]
|
||
if let Some(snapshot) = record.as_ref() {
|
||
self.cache_test_runtime_snapshot(snapshot.clone());
|
||
}
|
||
Ok(record)
|
||
}
|
||
#[cfg(test)]
|
||
Err(_) => Ok(self.read_test_runtime_snapshot(user_id.as_str())),
|
||
#[cfg(not(test))]
|
||
Err(error) => Err(error),
|
||
}
|
||
}
|
||
|
||
pub async fn put_runtime_snapshot_record(
|
||
&self,
|
||
user_id: String,
|
||
saved_at_micros: i64,
|
||
bottom_tab: String,
|
||
game_state: Value,
|
||
current_story: Option<Value>,
|
||
updated_at_micros: i64,
|
||
) -> Result<RuntimeSnapshotRecord, SpacetimeClientError> {
|
||
match self
|
||
.spacetime_client
|
||
.put_runtime_snapshot(
|
||
user_id.clone(),
|
||
saved_at_micros,
|
||
bottom_tab.clone(),
|
||
game_state.clone(),
|
||
current_story.clone(),
|
||
updated_at_micros,
|
||
)
|
||
.await
|
||
{
|
||
Ok(record) => {
|
||
#[cfg(test)]
|
||
self.cache_test_runtime_snapshot(record.clone());
|
||
Ok(record)
|
||
}
|
||
#[cfg(test)]
|
||
Err(_) => {
|
||
let snapshot = self.build_test_runtime_snapshot_record(
|
||
user_id,
|
||
saved_at_micros,
|
||
bottom_tab,
|
||
game_state,
|
||
current_story,
|
||
updated_at_micros,
|
||
)?;
|
||
self.cache_test_runtime_snapshot(snapshot.clone());
|
||
Ok(snapshot)
|
||
}
|
||
#[cfg(not(test))]
|
||
Err(error) => Err(error),
|
||
}
|
||
}
|
||
|
||
pub async fn delete_runtime_snapshot_record(
|
||
&self,
|
||
user_id: String,
|
||
) -> Result<bool, SpacetimeClientError> {
|
||
match self
|
||
.spacetime_client
|
||
.delete_runtime_snapshot(user_id.clone())
|
||
.await
|
||
{
|
||
Ok(deleted) => {
|
||
#[cfg(test)]
|
||
if deleted {
|
||
self.remove_test_runtime_snapshot(user_id.as_str());
|
||
}
|
||
Ok(deleted)
|
||
}
|
||
#[cfg(test)]
|
||
Err(_) => Ok(self
|
||
.remove_test_runtime_snapshot(user_id.as_str())
|
||
.is_some()),
|
||
#[cfg(not(test))]
|
||
Err(error) => Err(error),
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
impl AppState {
    /// Locks the test-only in-memory snapshot store.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned (a previous holder panicked).
    fn test_snapshot_store(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<String, RuntimeSnapshotRecord>> {
        self.test_runtime_snapshot_store
            .lock()
            .expect("test runtime snapshot store should lock")
    }

    /// Inserts `record` into the test store, replacing any entry with the same `user_id`.
    fn cache_test_runtime_snapshot(&self, record: RuntimeSnapshotRecord) {
        let key = record.user_id.clone();
        self.test_snapshot_store().insert(key, record);
    }

    /// Returns a copy of the cached snapshot for `user_id`, if any.
    fn read_test_runtime_snapshot(&self, user_id: &str) -> Option<RuntimeSnapshotRecord> {
        self.test_snapshot_store().get(user_id).cloned()
    }

    /// Removes and returns the cached snapshot for `user_id`, if any.
    fn remove_test_runtime_snapshot(&self, user_id: &str) -> Option<RuntimeSnapshotRecord> {
        self.test_snapshot_store().remove(user_id)
    }

    /// Builds a snapshot record locally, as the test stand-in for the SpacetimeDB write.
    ///
    /// `created_at_micros` is carried over from an existing cached entry for the same user;
    /// a brand-new entry starts at `updated_at_micros`. The record is not cached here.
    ///
    /// # Errors
    ///
    /// Returns [`SpacetimeClientError::Runtime`] when `game_state` or `current_story` cannot
    /// be serialized to JSON.
    fn build_test_runtime_snapshot_record(
        &self,
        user_id: String,
        saved_at_micros: i64,
        bottom_tab: String,
        game_state: Value,
        current_story: Option<Value>,
        updated_at_micros: i64,
    ) -> Result<RuntimeSnapshotRecord, SpacetimeClientError> {
        // Read only the timestamp under the lock rather than cloning the whole previous record.
        let created_at_micros = self
            .test_snapshot_store()
            .get(user_id.as_str())
            .map_or(updated_at_micros, |previous| previous.created_at_micros);

        let game_state_json = serde_json::to_string(&game_state).map_err(|error| {
            SpacetimeClientError::Runtime(format!("测试快照 game_state 序列化失败: {error}"))
        })?;

        let current_story_json = match current_story.as_ref() {
            Some(story) => Some(serde_json::to_string(story).map_err(|error| {
                SpacetimeClientError::Runtime(format!("测试快照 current_story 序列化失败: {error}"))
            })?),
            None => None,
        };

        Ok(RuntimeSnapshotRecord {
            user_id,
            version: SAVE_SNAPSHOT_VERSION,
            saved_at: format_utc_micros(saved_at_micros),
            saved_at_micros,
            bottom_tab,
            game_state,
            current_story,
            game_state_json,
            current_story_json,
            created_at_micros,
            updated_at_micros,
        })
    }
}
|
||
|
||
impl fmt::Display for AppStateInitError {
|
||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||
match self {
|
||
Self::Jwt(error) => write!(f, "{error}"),
|
||
Self::RefreshCookie(error) => write!(f, "{error}"),
|
||
Self::SmsProvider(error) => write!(f, "{error}"),
|
||
Self::Oss(error) => write!(f, "{error}"),
|
||
Self::Llm(error) => write!(f, "{error}"),
|
||
}
|
||
}
|
||
}
|
||
|
||
// Marker impl built on the `Debug` derive and the `Display` impl above. `source()` keeps the
// default (`None`), so the wrapped error is surfaced only through `Display`.
impl Error for AppStateInitError {}
|
||
|
||
impl From<JwtError> for AppStateInitError {
|
||
fn from(value: JwtError) -> Self {
|
||
Self::Jwt(value)
|
||
}
|
||
}
|
||
|
||
impl From<RefreshCookieError> for AppStateInitError {
|
||
fn from(value: RefreshCookieError) -> Self {
|
||
Self::RefreshCookie(value)
|
||
}
|
||
}
|
||
|
||
impl From<SmsProviderError> for AppStateInitError {
|
||
fn from(value: SmsProviderError) -> Self {
|
||
Self::SmsProvider(value)
|
||
}
|
||
}
|
||
|
||
impl From<OssError> for AppStateInitError {
|
||
fn from(value: OssError) -> Self {
|
||
Self::Oss(value)
|
||
}
|
||
}
|
||
|
||
impl From<LlmError> for AppStateInitError {
|
||
fn from(value: LlmError) -> Self {
|
||
Self::Llm(value)
|
||
}
|
||
}
|
||
|
||
fn build_oss_client(config: &AppConfig) -> Result<Option<OssClient>, AppStateInitError> {
|
||
let has_any_oss_field = config.oss_bucket.is_some()
|
||
|| config.oss_endpoint.is_some()
|
||
|| config.oss_access_key_id.is_some()
|
||
|| config.oss_access_key_secret.is_some();
|
||
|
||
if !has_any_oss_field {
|
||
return Ok(None);
|
||
}
|
||
|
||
let oss_config = OssConfig::new(
|
||
config.oss_bucket.clone().unwrap_or_default(),
|
||
config.oss_endpoint.clone().unwrap_or_default(),
|
||
config.oss_access_key_id.clone().unwrap_or_default(),
|
||
config.oss_access_key_secret.clone().unwrap_or_default(),
|
||
config.oss_read_expire_seconds,
|
||
config.oss_post_expire_seconds,
|
||
config.oss_post_max_size_bytes,
|
||
config.oss_success_action_status,
|
||
)?;
|
||
|
||
Ok(Some(OssClient::new(oss_config)))
|
||
}
|
||
|
||
fn build_llm_client(config: &AppConfig) -> Result<Option<LlmClient>, AppStateInitError> {
|
||
let Some(api_key) = config
|
||
.llm_api_key
|
||
.as_ref()
|
||
.map(|value| value.trim())
|
||
.filter(|value| !value.is_empty())
|
||
else {
|
||
return Ok(None);
|
||
};
|
||
|
||
let llm_config = LlmConfig::new(
|
||
config.llm_provider,
|
||
config.llm_base_url.clone(),
|
||
api_key.to_string(),
|
||
config.llm_model.clone(),
|
||
config.llm_request_timeout_ms,
|
||
config.llm_max_retries,
|
||
config.llm_retry_backoff_ms,
|
||
)?;
|
||
|
||
Ok(Some(LlmClient::new(llm_config)?))
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use module_ai::{AiTaskKind, generate_ai_task_id};

    use super::*;

    /// Builds an `AppState` from the default configuration, failing the test if that errors.
    fn build_default_state() -> AppState {
        AppState::new(AppConfig::default()).expect("state should build")
    }

    #[test]
    fn app_state_exposes_usable_ai_task_service() {
        let state = build_default_state();
        let task_id = generate_ai_task_id(1_713_680_000_000_000);

        let input = module_ai::AiTaskCreateInput {
            task_id: task_id.clone(),
            task_kind: AiTaskKind::StoryGeneration,
            owner_user_id: "user_001".to_string(),
            request_label: "营地开场".to_string(),
            source_module: "story".to_string(),
            source_entity_id: Some("storysess_001".to_string()),
            request_payload_json: Some("{\"scene\":\"camp\"}".to_string()),
            stages: AiTaskKind::StoryGeneration.default_stage_blueprints(),
            created_at_micros: 1_713_680_000_000_000,
        };
        let created = state
            .ai_task_service()
            .create_task(input)
            .expect("ai task should create");

        assert_eq!(created.task_id, task_id);
        assert_eq!(created.task_kind, AiTaskKind::StoryGeneration);
        assert_eq!(created.stages.len(), 4);
    }

    #[test]
    fn app_state_skips_llm_client_when_api_key_missing() {
        let state = build_default_state();
        assert!(state.llm_client().is_none());
    }
}
|