use crate::*; use serde::{Deserialize, Serialize}; use serde_json::Value; const AUTH_STORE_SNAPSHOT_ID: &str = "default"; #[derive(Clone, Debug, PartialEq, Eq, SpacetimeType)] pub struct AuthStoreSnapshotRecord { pub snapshot_json: Option, pub updated_at_micros: Option, } #[derive(Clone, Debug, PartialEq, Eq, SpacetimeType)] pub struct AuthStoreSnapshotUpsertInput { pub snapshot_json: String, pub updated_at_micros: i64, } #[derive(Clone, Debug, PartialEq, Eq, SpacetimeType)] pub struct AuthStoreSnapshotProcedureResult { pub ok: bool, pub record: Option, pub error_message: Option, } #[derive(Clone, Debug, PartialEq, Eq, SpacetimeType)] pub struct AuthStoreSnapshotImportRecord { pub imported_user_count: u32, pub imported_identity_count: u32, pub imported_refresh_session_count: u32, } #[derive(Clone, Debug, PartialEq, Eq, SpacetimeType)] pub struct AuthStoreSnapshotImportProcedureResult { pub ok: bool, pub record: Option, pub error_message: Option, } #[spacetimedb::table(accessor = auth_store_snapshot)] pub struct AuthStoreSnapshot { #[primary_key] pub(crate) snapshot_id: String, pub(crate) snapshot_json: String, pub(crate) updated_at: Timestamp, } #[spacetimedb::table( accessor = user_account, index(accessor = by_user_account_username, btree(columns = [username])), index(accessor = by_user_account_public_code, btree(columns = [public_user_code])) )] pub struct UserAccount { #[primary_key] pub(crate) user_id: String, pub(crate) public_user_code: String, pub(crate) username: String, pub(crate) display_name: String, pub(crate) phone_number_masked: Option, pub(crate) phone_number_e164: Option, pub(crate) login_method: String, pub(crate) binding_status: String, pub(crate) wechat_bound: bool, pub(crate) password_hash: String, pub(crate) password_login_enabled: bool, pub(crate) token_version: u64, } #[spacetimedb::table( accessor = auth_identity, index(accessor = by_auth_identity_user_id, btree(columns = [user_id])), index(accessor = by_auth_identity_provider_uid, btree(columns = [provider, provider_uid])) )] pub struct AuthIdentity { #[primary_key] pub(crate) identity_id: String, pub(crate) user_id: String, pub(crate) provider: String, pub(crate) provider_uid: String, pub(crate) provider_union_id: Option, pub(crate) phone_e164: Option, pub(crate) display_name: Option, pub(crate) avatar_url: Option, } #[spacetimedb::table( accessor = refresh_session, index(accessor = by_refresh_session_user_id, btree(columns = [user_id])), index(accessor = by_refresh_session_token_hash, btree(columns = [refresh_token_hash])) )] pub struct RefreshSession { #[primary_key] pub(crate) session_id: String, pub(crate) user_id: String, pub(crate) refresh_token_hash: String, pub(crate) issued_by_provider: String, pub(crate) client_info_json: String, pub(crate) expires_at: String, pub(crate) revoked_at: Option, pub(crate) created_at: String, pub(crate) updated_at: String, pub(crate) last_seen_at: String, } // Axum 启动恢复认证状态时读取当前快照;记录不存在代表尚未产生登录态。 #[spacetimedb::procedure] pub fn get_auth_store_snapshot(ctx: &mut ProcedureContext) -> AuthStoreSnapshotProcedureResult { match ctx.try_with_tx(|tx| get_auth_store_snapshot_tx(tx)) { Ok(record) => AuthStoreSnapshotProcedureResult { ok: true, record: Some(record), error_message: None, }, Err(message) => AuthStoreSnapshotProcedureResult { ok: false, record: None, error_message: Some(message), }, } } // Axum 每次鉴权仓储变更后覆盖写入整份快照,后续拆表阶段再替换为细粒度 reducer。 #[spacetimedb::procedure] pub fn upsert_auth_store_snapshot( ctx: &mut ProcedureContext, input: AuthStoreSnapshotUpsertInput, ) -> AuthStoreSnapshotProcedureResult { match ctx.try_with_tx(|tx| upsert_auth_store_snapshot_tx(tx, input.clone())) { Ok(record) => AuthStoreSnapshotProcedureResult { ok: true, record: Some(record), error_message: None, }, Err(message) => AuthStoreSnapshotProcedureResult { ok: false, record: None, error_message: Some(message), }, } } #[spacetimedb::procedure] pub fn import_auth_store_snapshot( ctx: &mut ProcedureContext, ) -> AuthStoreSnapshotImportProcedureResult { match ctx.try_with_tx(|tx| import_auth_store_snapshot_tx(tx)) { Ok(record) => AuthStoreSnapshotImportProcedureResult { ok: true, record: Some(record), error_message: None, }, Err(message) => AuthStoreSnapshotImportProcedureResult { ok: false, record: None, error_message: Some(message), }, } } // Axum ??????????????? module-auth ???????????????? #[spacetimedb::procedure] pub fn export_auth_store_snapshot_from_tables( ctx: &mut ProcedureContext, ) -> AuthStoreSnapshotProcedureResult { match ctx.try_with_tx(|tx| export_auth_store_snapshot_from_tables_tx(tx)) { Ok(record) => AuthStoreSnapshotProcedureResult { ok: true, record: Some(record), error_message: None, }, Err(message) => AuthStoreSnapshotProcedureResult { ok: false, record: None, error_message: Some(message), }, } } fn get_auth_store_snapshot_tx(ctx: &ReducerContext) -> Result { Ok( match ctx .db .auth_store_snapshot() .snapshot_id() .find(&AUTH_STORE_SNAPSHOT_ID.to_string()) { Some(row) => AuthStoreSnapshotRecord { snapshot_json: Some(row.snapshot_json), updated_at_micros: Some(row.updated_at.to_micros_since_unix_epoch()), }, None => AuthStoreSnapshotRecord { snapshot_json: None, updated_at_micros: None, }, }, ) } fn upsert_auth_store_snapshot_tx( ctx: &ReducerContext, input: AuthStoreSnapshotUpsertInput, ) -> Result { let snapshot_json = input.snapshot_json.trim().to_string(); if snapshot_json.is_empty() { return Err("认证快照 JSON 不能为空".to_string()); } let updated_at = Timestamp::from_micros_since_unix_epoch(input.updated_at_micros); if ctx .db .auth_store_snapshot() .snapshot_id() .find(&AUTH_STORE_SNAPSHOT_ID.to_string()) .is_some() { ctx.db .auth_store_snapshot() .snapshot_id() .delete(&AUTH_STORE_SNAPSHOT_ID.to_string()); } ctx.db.auth_store_snapshot().insert(AuthStoreSnapshot { snapshot_id: AUTH_STORE_SNAPSHOT_ID.to_string(), snapshot_json: snapshot_json.clone(), updated_at, }); Ok(AuthStoreSnapshotRecord { snapshot_json: Some(snapshot_json), updated_at_micros: Some(input.updated_at_micros), }) } fn import_auth_store_snapshot_tx( ctx: &ReducerContext, ) -> Result { let snapshot = ctx .db .auth_store_snapshot() .snapshot_id() .find(&AUTH_STORE_SNAPSHOT_ID.to_string()) .ok_or_else(|| "认证快照不存在,无法导入正式表".to_string())?; let parsed = serde_json::from_str::(&snapshot.snapshot_json) .map_err(|error| format!("认证快照 JSON 解析失败:{error}"))?; clear_auth_target_tables(ctx); let mut imported_user_count = 0_u32; let mut imported_identity_count = 0_u32; let mut imported_refresh_session_count = 0_u32; for stored_user in parsed.users_by_username.into_values() { let user = stored_user.user; ctx.db.user_account().insert(UserAccount { user_id: user.id.clone(), public_user_code: user.public_user_code, username: user.username, display_name: user.display_name, phone_number_masked: user.phone_number_masked, phone_number_e164: stored_user.phone_number.clone(), login_method: user.login_method, binding_status: user.binding_status, wechat_bound: user.wechat_bound, password_hash: stored_user.password_hash, password_login_enabled: stored_user.password_login_enabled, token_version: user.token_version, }); imported_user_count += 1; if let Some(phone_number) = stored_user.phone_number { ctx.db.auth_identity().insert(AuthIdentity { identity_id: format!("authi_phone_{}", sanitize_identity_component(&phone_number)), user_id: user.id, provider: "phone".to_string(), provider_uid: phone_number.clone(), provider_union_id: None, phone_e164: Some(phone_number), display_name: None, avatar_url: None, }); imported_identity_count += 1; } } for identity in parsed.wechat_identity_by_provider_uid.into_values() { ctx.db.auth_identity().insert(AuthIdentity { identity_id: format!( "authi_wechat_{}", sanitize_identity_component(&identity.provider_uid) ), user_id: identity.user_id, provider: "wechat".to_string(), provider_uid: identity.provider_uid, provider_union_id: identity.provider_union_id, phone_e164: None, display_name: identity.display_name, avatar_url: identity.avatar_url, }); imported_identity_count += 1; } for stored_session in parsed.sessions_by_id.into_values() { let session = stored_session.session; let client_info_json = serde_json::to_string(&session.client_info) .map_err(|error| format!("客户端身份序列化失败:{error}"))?; ctx.db.refresh_session().insert(RefreshSession { session_id: session.session_id, user_id: session.user_id, refresh_token_hash: session.refresh_token_hash, issued_by_provider: session.issued_by_provider, client_info_json, expires_at: session.expires_at, revoked_at: session.revoked_at, created_at: session.created_at, updated_at: session.updated_at, last_seen_at: session.last_seen_at, }); imported_refresh_session_count += 1; } Ok(AuthStoreSnapshotImportRecord { imported_user_count, imported_identity_count, imported_refresh_session_count, }) } fn export_auth_store_snapshot_from_tables_tx( ctx: &ReducerContext, ) -> Result { let users = ctx.db.user_account().iter().collect::>(); let identities = ctx.db.auth_identity().iter().collect::>(); let sessions = ctx.db.refresh_session().iter().collect::>(); if users.is_empty() && identities.is_empty() && sessions.is_empty() { return Ok(AuthStoreSnapshotRecord { snapshot_json: None, updated_at_micros: None, }); } let mut phone_identity_by_user_id = std::collections::HashMap::new(); let mut phone_to_user_id = std::collections::HashMap::new(); let mut wechat_identity_by_provider_uid = std::collections::HashMap::new(); let mut user_id_by_provider_union_id = std::collections::HashMap::new(); for identity in identities { match identity.provider.as_str() { "phone" => { let phone_number = identity .phone_e164 .clone() .unwrap_or_else(|| identity.provider_uid.clone()); phone_to_user_id.insert(phone_number.clone(), identity.user_id.clone()); phone_identity_by_user_id.insert(identity.user_id, phone_number); } "wechat" => { if let Some(union_id) = identity.provider_union_id.clone() { user_id_by_provider_union_id.insert(union_id, identity.user_id.clone()); } wechat_identity_by_provider_uid.insert( identity.provider_uid.clone(), StoredWechatIdentitySnapshot { user_id: identity.user_id, provider_uid: identity.provider_uid, provider_union_id: identity.provider_union_id, display_name: identity.display_name, avatar_url: identity.avatar_url, }, ); } _ => {} } } let mut next_user_id = 1_u64; let mut users_by_username = std::collections::HashMap::new(); for user in users { if let Some(numeric_id) = user .user_id .strip_prefix("user_") .and_then(|value| value.parse::().ok()) { next_user_id = next_user_id.max(numeric_id.saturating_add(1)); } let auth_user = AuthUserSnapshot { id: user.user_id.clone(), public_user_code: user.public_user_code, username: user.username.clone(), display_name: user.display_name, phone_number_masked: user.phone_number_masked, login_method: user.login_method, binding_status: user.binding_status, wechat_bound: user.wechat_bound, token_version: user.token_version, }; users_by_username.insert( user.username, StoredPasswordUserSnapshot { user: auth_user, password_hash: user.password_hash, password_login_enabled: user.password_login_enabled, phone_number: user .phone_number_e164 .or_else(|| phone_identity_by_user_id.remove(&user.user_id)), }, ); } let mut sessions_by_id = std::collections::HashMap::new(); let mut session_id_by_refresh_token_hash = std::collections::HashMap::new(); for session in sessions { let client_info = serde_json::from_str::(&session.client_info_json) .map_err(|error| format!("refresh session ????? JSON ?????{error}"))?; session_id_by_refresh_token_hash.insert( session.refresh_token_hash.clone(), session.session_id.clone(), ); sessions_by_id.insert( session.session_id.clone(), StoredRefreshSessionSnapshot { session: RefreshSessionSnapshot { session_id: session.session_id, user_id: session.user_id, refresh_token_hash: session.refresh_token_hash, issued_by_provider: session.issued_by_provider, client_info, expires_at: session.expires_at, revoked_at: session.revoked_at, created_at: session.created_at, updated_at: session.updated_at, last_seen_at: session.last_seen_at, }, }, ); } let snapshot = PersistentAuthStoreSnapshot { next_user_id, users_by_username, phone_to_user_id, sessions_by_id, session_id_by_refresh_token_hash, wechat_identity_by_provider_uid, user_id_by_provider_union_id, }; let snapshot_json = serde_json::to_string_pretty(&snapshot).map_err(|error| format!("?????????????{error}"))?; Ok(AuthStoreSnapshotRecord { snapshot_json: Some(snapshot_json), updated_at_micros: None, }) } fn clear_auth_target_tables(ctx: &ReducerContext) { for row in ctx.db.refresh_session().iter().collect::>() { ctx.db .refresh_session() .session_id() .delete(&row.session_id); } for row in ctx.db.auth_identity().iter().collect::>() { ctx.db .auth_identity() .identity_id() .delete(&row.identity_id); } for row in ctx.db.user_account().iter().collect::>() { ctx.db.user_account().user_id().delete(&row.user_id); } } fn sanitize_identity_component(value: &str) -> String { let sanitized = value .chars() .map(|character| { if character.is_ascii_alphanumeric() { character } else { '_' } }) .collect::(); sanitized.trim_matches('_').to_string() } #[derive(Deserialize, Serialize)] struct PersistentAuthStoreSnapshot { #[serde(default = "default_next_user_id")] next_user_id: u64, users_by_username: std::collections::HashMap, #[serde(default)] phone_to_user_id: std::collections::HashMap, sessions_by_id: std::collections::HashMap, #[serde(default)] session_id_by_refresh_token_hash: std::collections::HashMap, wechat_identity_by_provider_uid: std::collections::HashMap, #[serde(default)] user_id_by_provider_union_id: std::collections::HashMap, } fn default_next_user_id() -> u64 { 1 } #[derive(Deserialize, Serialize)] struct StoredPasswordUserSnapshot { user: AuthUserSnapshot, password_hash: String, #[serde(default)] password_login_enabled: bool, phone_number: Option, } #[derive(Deserialize, Serialize)] struct AuthUserSnapshot { id: String, public_user_code: String, username: String, display_name: String, phone_number_masked: Option, login_method: String, binding_status: String, wechat_bound: bool, token_version: u64, } #[derive(Deserialize, Serialize)] struct StoredWechatIdentitySnapshot { user_id: String, provider_uid: String, provider_union_id: Option, display_name: Option, avatar_url: Option, } #[derive(Deserialize, Serialize)] struct StoredRefreshSessionSnapshot { session: RefreshSessionSnapshot, } #[derive(Deserialize, Serialize)] struct RefreshSessionSnapshot { session_id: String, user_id: String, refresh_token_hash: String, issued_by_provider: String, client_info: Value, expires_at: String, revoked_at: Option, created_at: String, updated_at: String, last_seen_at: String, }