use std::{ collections::{BTreeMap, HashSet}, error::Error, fmt, }; use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString}; use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD}; use hmac::{Hmac, Mac}; use jsonwebtoken::{ Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::ErrorKind, }; use rand_core::OsRng; use reqwest::{Client, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::Value; use sha1::Sha1; use sha2::{Digest, Sha256}; use shared_kernel::{new_uuid_simple_string, normalize_optional_string, normalize_required_string}; use time::{Duration, OffsetDateTime}; use tracing::{info, warn}; use url::Url; pub const ACCESS_TOKEN_ALGORITHM: Algorithm = Algorithm::HS256; pub const DEFAULT_ACCESS_TOKEN_TTL_SECONDS: u64 = 2 * 60 * 60; pub const DEFAULT_REFRESH_COOKIE_NAME: &str = "genarrative_refresh_session"; pub const DEFAULT_REFRESH_COOKIE_PATH: &str = "/api/auth"; pub const DEFAULT_REFRESH_SESSION_TTL_DAYS: u32 = 30; pub const DEFAULT_SMS_ENDPOINT: &str = "dypnsapi.aliyuncs.com"; pub const DEFAULT_SMS_COUNTRY_CODE: &str = "86"; pub const DEFAULT_SMS_TEMPLATE_PARAM_KEY: &str = "code"; pub const DEFAULT_SMS_MOCK_VERIFY_CODE: &str = "123456"; pub const DEFAULT_SMS_CODE_LENGTH: u8 = 6; pub const DEFAULT_SMS_CODE_TYPE: u8 = 1; pub const DEFAULT_SMS_VALID_TIME_SECONDS: u64 = 300; pub const DEFAULT_SMS_INTERVAL_SECONDS: u64 = 60; pub const DEFAULT_SMS_DUPLICATE_POLICY: u8 = 1; pub const DEFAULT_SMS_CASE_AUTH_POLICY: u8 = 1; pub const DEFAULT_WECHAT_AUTHORIZE_ENDPOINT: &str = "https://open.weixin.qq.com/connect/qrconnect"; pub const DEFAULT_WECHAT_IN_APP_AUTHORIZE_ENDPOINT: &str = "https://open.weixin.qq.com/connect/oauth2/authorize"; pub const DEFAULT_WECHAT_ACCESS_TOKEN_ENDPOINT: &str = "https://api.weixin.qq.com/sns/oauth2/access_token"; pub const DEFAULT_WECHAT_USER_INFO_ENDPOINT: &str = "https://api.weixin.qq.com/sns/userinfo"; pub const 
DEFAULT_WECHAT_JS_CODE_SESSION_ENDPOINT: &str = "https://api.weixin.qq.com/sns/jscode2session"; type HmacSha1 = Hmac; // 鉴权 provider 直接冻结成文档中约定的枚举,避免后续在多个 crate 内重复发明字符串字面量。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AuthProvider { Password, Phone, Wechat, } // 绑定状态只保留当前 JWT 需要透传的最小快照,不把完整账号状态枚举直接泄漏到 token 中。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum BindingStatus { Active, PendingBindPhone, } // 用于签发 access token 的领域输入,和最终 JWT claims 解耦,避免业务层手动拼 iat/exp/iss。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct AccessTokenClaimsInput { pub user_id: String, pub session_id: String, pub provider: AuthProvider, pub roles: Vec, pub token_version: u64, pub phone_verified: bool, pub binding_status: BindingStatus, pub display_name: Option, } // 直接映射最终 JWT payload,字段名与文档冻结口径保持一致。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AccessTokenClaims { pub iss: String, pub sub: String, pub sid: String, pub provider: AuthProvider, pub roles: Vec, pub ver: u64, pub phone_verified: bool, pub binding_status: BindingStatus, #[serde(skip_serializing_if = "Option::is_none")] pub display_name: Option, pub iat: u64, pub exp: u64, } // 统一承载 JWT 配置,避免 secret、issuer、ttl 在 api-server 与后续模块里散落。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct JwtConfig { issuer: String, secret: String, access_token_ttl_seconds: u64, } // refresh cookie 的 SameSite 固定约束成枚举,避免各层直接使用大小写不一致的字符串。 #[derive(Clone, Debug, PartialEq, Eq)] pub enum RefreshCookieSameSite { Lax, Strict, None, } // refresh cookie 的平台配置统一收口到 platform-auth,避免 api-server 直接散落 cookie 细节。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct RefreshCookieConfig { cookie_name: String, cookie_path: String, cookie_secure: bool, cookie_same_site: RefreshCookieSameSite, refresh_session_ttl_days: u32, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum SmsAuthProviderKind { Mock, Aliyun, } 
#[derive(Clone, Debug, PartialEq, Eq)] pub struct SmsAuthConfig { pub provider: SmsAuthProviderKind, pub endpoint: String, pub access_key_id: Option, pub access_key_secret: Option, pub sign_name: String, pub template_code: String, pub template_param_key: String, pub country_code: String, pub scheme_name: Option, pub code_length: u8, pub code_type: u8, pub valid_time_seconds: u64, pub interval_seconds: u64, pub duplicate_policy: u8, pub case_auth_policy: u8, pub return_verify_code: bool, pub mock_verify_code: String, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmsSendCodeRequest { pub national_phone_number: String, pub scene: String, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmsSendCodeResult { pub cooldown_seconds: u64, pub expires_in_seconds: u64, pub provider_request_id: Option, pub provider_out_id: Option, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmsVerifyCodeRequest { pub national_phone_number: String, pub verify_code: String, pub provider_out_id: Option, } #[derive(Clone, Debug, PartialEq, Eq)] pub enum WechatAuthScene { Desktop, WechatInApp, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct WechatAuthConfig { pub enabled: bool, pub provider: String, pub app_id: Option, pub app_secret: Option, pub mini_program_app_id: Option, pub mini_program_app_secret: Option, pub authorize_endpoint: String, pub access_token_endpoint: String, pub user_info_endpoint: String, pub js_code_session_endpoint: String, pub mock_user_id: String, pub mock_union_id: Option, pub mock_display_name: String, pub mock_avatar_url: Option, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct WechatIdentityProfile { pub provider_uid: String, pub provider_union_id: Option, pub display_name: Option, pub avatar_url: Option, } #[derive(Clone, Debug)] pub enum WechatProvider { Disabled, Mock(MockWechatProvider), Real(RealWechatProvider), } #[derive(Clone, Debug)] pub struct MockWechatProvider { mock_user_id: String, mock_union_id: Option, mock_display_name: String, 
mock_avatar_url: Option, } #[derive(Clone, Debug)] pub struct RealWechatProvider { client: Client, app_id: Option, app_secret: Option, mini_program_app_id: Option, mini_program_app_secret: Option, authorize_endpoint: String, access_token_endpoint: String, user_info_endpoint: String, js_code_session_endpoint: String, } #[derive(Clone, Debug)] pub enum SmsAuthProvider { Mock(MockSmsAuthProvider), Aliyun(AliyunSmsAuthProvider), } #[derive(Clone, Debug)] pub struct MockSmsAuthProvider { config: SmsAuthConfig, } #[derive(Clone, Debug)] pub struct AliyunSmsAuthProvider { client: Client, config: SmsAuthConfig, } #[derive(Debug, PartialEq, Eq)] pub enum JwtError { InvalidConfig(&'static str), InvalidClaims(&'static str), SignFailed(String), VerifyFailed(String), } #[derive(Debug, PartialEq, Eq)] pub enum RefreshCookieError { InvalidConfig(&'static str), } #[derive(Debug, PartialEq, Eq)] pub enum PasswordHashError { HashFailed(String), VerifyFailed(String), } #[derive(Debug, PartialEq, Eq)] pub enum SmsProviderError { InvalidConfig(String), InvalidVerifyCode, Upstream(String), } #[derive(Debug, PartialEq, Eq)] pub enum WechatProviderError { Disabled, MissingCode, InvalidConfig(String), InvalidCallback(String), RequestFailed(String), DeserializeFailed(String), Upstream(String), MissingProfile(String), } // 鉴权平台错误统一先归类,api-server 再决定 HTTP status 和错误 envelope。 #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum AuthPlatformErrorKind { InvalidConfig, InvalidClaims, SignFailed, VerifyFailed, CookieConfig, HashFailed, InvalidVerifyCode, Disabled, MissingCode, InvalidCallback, RequestFailed, DeserializeFailed, MissingProfile, Upstream, } #[derive(Debug, Deserialize)] struct WechatAccessTokenResponse { access_token: Option, openid: Option, unionid: Option, errmsg: Option, } #[derive(Debug, Deserialize)] struct WechatUserInfoResponse { openid: Option, unionid: Option, nickname: Option, headimgurl: Option, errmsg: Option, } #[derive(Debug, Deserialize)] struct 
WechatJsCodeSessionResponse { openid: Option, unionid: Option, errcode: Option, errmsg: Option, } #[derive(Debug, Deserialize)] struct AliyunSendSmsVerifyCodeResponse { // 阿里云 RPC 原始 JSON 使用首字母大写字段名,这里必须显式映射,避免把成功响应误判成空值。 #[serde(default, rename = "Code")] code: Option, #[serde(default, rename = "Message")] message: Option, #[serde(default, rename = "RequestId")] request_id: Option, #[serde(default, rename = "Success")] success: Option, #[serde(default, rename = "Model")] model: Option, } #[derive(Debug, Deserialize)] struct AliyunSendSmsVerifyCodeModel { #[serde(default, rename = "BizId")] _biz_id: Option, #[serde(default, rename = "OutId")] out_id: Option, #[serde(default, rename = "RequestId")] request_id: Option, } #[derive(Debug, Deserialize)] struct AliyunCheckSmsVerifyCodeResponse { // 校验接口同样返回首字母大写字段名,保持和发送接口一致的显式映射。 #[serde(default, rename = "Code")] code: Option, #[serde(default, rename = "Message")] message: Option, #[serde(default, rename = "Success")] success: Option, #[serde(default, rename = "Model")] model: Option, } #[derive(Debug, Deserialize)] struct AliyunCheckSmsVerifyCodeModel { #[serde(default, rename = "OutId")] _out_id: Option, #[serde(default, rename = "VerifyResult")] verify_result: Option, } impl JwtConfig { pub fn new( issuer: String, secret: String, access_token_ttl_seconds: u64, ) -> Result { let issuer = normalize_required_string(&issuer) .ok_or(JwtError::InvalidConfig("JWT issuer 不能为空"))?; let secret = normalize_required_string(&secret) .ok_or(JwtError::InvalidConfig("JWT secret 不能为空"))?; if access_token_ttl_seconds == 0 { return Err(JwtError::InvalidConfig( "JWT access token 过期时间必须大于 0", )); } Ok(Self { issuer, secret, access_token_ttl_seconds, }) } pub fn issuer(&self) -> &str { &self.issuer } pub fn access_token_ttl_seconds(&self) -> u64 { self.access_token_ttl_seconds } } impl RefreshCookieSameSite { pub fn parse(raw: &str) -> Option { match raw.trim().to_ascii_lowercase().as_str() { "lax" => Some(Self::Lax), "strict" => 
Some(Self::Strict), "none" => Some(Self::None), _ => None, } } pub fn as_str(&self) -> &'static str { match self { Self::Lax => "Lax", Self::Strict => "Strict", Self::None => "None", } } } impl RefreshCookieConfig { pub fn new( cookie_name: String, cookie_path: String, cookie_secure: bool, cookie_same_site: RefreshCookieSameSite, refresh_session_ttl_days: u32, ) -> Result { let cookie_name = normalize_required_string(&cookie_name).ok_or( RefreshCookieError::InvalidConfig("refresh cookie 名称不能为空"), )?; let cookie_path = normalize_required_string(&cookie_path).ok_or( RefreshCookieError::InvalidConfig("refresh cookie path 不能为空"), )?; if refresh_session_ttl_days == 0 { return Err(RefreshCookieError::InvalidConfig( "refresh session TTL 天数必须大于 0", )); } Ok(Self { cookie_name, cookie_path, cookie_secure, cookie_same_site, refresh_session_ttl_days, }) } pub fn cookie_name(&self) -> &str { &self.cookie_name } pub fn cookie_path(&self) -> &str { &self.cookie_path } pub fn cookie_secure(&self) -> bool { self.cookie_secure } pub fn cookie_same_site(&self) -> &RefreshCookieSameSite { &self.cookie_same_site } pub fn refresh_session_ttl_days(&self) -> u32 { self.refresh_session_ttl_days } } impl SmsAuthProviderKind { pub fn parse(raw: &str) -> Option { match raw.trim().to_ascii_lowercase().as_str() { "mock" => Some(Self::Mock), "aliyun" => Some(Self::Aliyun), _ => None, } } pub fn as_str(&self) -> &'static str { match self { Self::Mock => "mock", Self::Aliyun => "aliyun", } } } impl SmsAuthConfig { pub fn new( provider: SmsAuthProviderKind, endpoint: String, access_key_id: Option, access_key_secret: Option, sign_name: String, template_code: String, template_param_key: String, country_code: String, scheme_name: Option, code_length: u8, code_type: u8, valid_time_seconds: u64, interval_seconds: u64, duplicate_policy: u8, case_auth_policy: u8, return_verify_code: bool, mock_verify_code: String, ) -> Result { let endpoint = normalize_required_string(&endpoint) .unwrap_or_else(|| 
DEFAULT_SMS_ENDPOINT.to_string()); let template_param_key = normalize_required_string(&template_param_key) .unwrap_or_else(|| DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string()); let country_code = normalize_required_string(&country_code) .unwrap_or_else(|| DEFAULT_SMS_COUNTRY_CODE.to_string()); let scheme_name = normalize_optional_string(scheme_name); let mock_verify_code = normalize_required_string(&mock_verify_code) .unwrap_or_else(|| DEFAULT_SMS_MOCK_VERIFY_CODE.to_string()); if !(4..=8).contains(&code_length) { return Err(SmsProviderError::InvalidConfig( "短信验证码长度必须在 4 到 8 之间".to_string(), )); } if !(1..=7).contains(&code_type) { return Err(SmsProviderError::InvalidConfig( "短信验证码类型取值非法".to_string(), )); } if interval_seconds == 0 || valid_time_seconds == 0 { return Err(SmsProviderError::InvalidConfig( "短信验证码有效期和发送间隔必须大于 0".to_string(), )); } if !(1..=2).contains(&duplicate_policy) { return Err(SmsProviderError::InvalidConfig( "短信验证码重复策略取值非法".to_string(), )); } if !(1..=2).contains(&case_auth_policy) { return Err(SmsProviderError::InvalidConfig( "短信验证码大小写校验策略取值非法".to_string(), )); } match provider { SmsAuthProviderKind::Mock => {} SmsAuthProviderKind::Aliyun => { if normalize_required_string(&sign_name).is_none() { return Err(SmsProviderError::InvalidConfig( "阿里云短信签名不能为空".to_string(), )); } if normalize_required_string(&template_code).is_none() { return Err(SmsProviderError::InvalidConfig( "阿里云短信模板编码不能为空".to_string(), )); } if access_key_id .as_deref() .and_then(normalize_required_string) .is_none() || access_key_secret .as_deref() .and_then(normalize_required_string) .is_none() { return Err(SmsProviderError::InvalidConfig( "阿里云短信 AccessKey 未配置".to_string(), )); } } } Ok(Self { provider, endpoint, access_key_id: access_key_id.and_then(|value| normalize_required_string(&value)), access_key_secret: access_key_secret .and_then(|value| normalize_required_string(&value)), sign_name: sign_name.trim().to_string(), template_code: template_code.trim().to_string(), 
template_param_key, country_code, scheme_name, code_length, code_type, valid_time_seconds, interval_seconds, duplicate_policy, case_auth_policy, return_verify_code, mock_verify_code, }) } } impl SmsAuthProvider { pub fn new(config: SmsAuthConfig) -> Result { match config.provider { SmsAuthProviderKind::Mock => Ok(Self::Mock(MockSmsAuthProvider { config })), SmsAuthProviderKind::Aliyun => Ok(Self::Aliyun(AliyunSmsAuthProvider { client: Client::new(), config, })), } } pub fn kind(&self) -> SmsAuthProviderKind { match self { Self::Mock(_) => SmsAuthProviderKind::Mock, Self::Aliyun(_) => SmsAuthProviderKind::Aliyun, } } pub async fn send_code( &self, request: SmsSendCodeRequest, ) -> Result { match self { Self::Mock(provider) => provider.send_code(request).await, Self::Aliyun(provider) => provider.send_code(request).await, } } pub async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> { match self { Self::Mock(provider) => provider.verify_code(request).await, Self::Aliyun(provider) => provider.verify_code(request).await, } } } impl WechatAuthConfig { #[allow(clippy::too_many_arguments)] pub fn new( enabled: bool, provider: String, app_id: Option, app_secret: Option, mini_program_app_id: Option, mini_program_app_secret: Option, authorize_endpoint: String, access_token_endpoint: String, user_info_endpoint: String, js_code_session_endpoint: String, mock_user_id: String, mock_union_id: Option, mock_display_name: String, mock_avatar_url: Option, ) -> Self { Self { enabled, provider, app_id, app_secret, mini_program_app_id, mini_program_app_secret, authorize_endpoint, access_token_endpoint, user_info_endpoint, js_code_session_endpoint, mock_user_id, mock_union_id, mock_display_name, mock_avatar_url, } } } impl WechatProvider { pub fn new(config: WechatAuthConfig) -> Self { if !config.enabled { return Self::Disabled; } if config.provider.trim().eq_ignore_ascii_case("mock") { return Self::Mock(MockWechatProvider { mock_user_id: 
config.mock_user_id, mock_union_id: config.mock_union_id, mock_display_name: config.mock_display_name, mock_avatar_url: config.mock_avatar_url, }); } let has_web_oauth_config = config .app_id .as_ref() .is_some_and(|value| !value.is_empty()) && config .app_secret .as_ref() .is_some_and(|value| !value.is_empty()); let has_mini_program_config = config .mini_program_app_id .as_ref() .is_some_and(|value| !value.is_empty()) && config .mini_program_app_secret .as_ref() .is_some_and(|value| !value.is_empty()); if !has_web_oauth_config && !has_mini_program_config { return Self::Disabled; }; Self::Real(RealWechatProvider { client: Client::new(), app_id: config.app_id, app_secret: config.app_secret, mini_program_app_id: config.mini_program_app_id, mini_program_app_secret: config.mini_program_app_secret, authorize_endpoint: config.authorize_endpoint, access_token_endpoint: config.access_token_endpoint, user_info_endpoint: config.user_info_endpoint, js_code_session_endpoint: config.js_code_session_endpoint, }) } pub fn build_authorization_url( &self, callback_url: &str, state: &str, scene: &WechatAuthScene, ) -> Result { match self { Self::Disabled => Err(WechatProviderError::Disabled), Self::Mock(_) => build_mock_wechat_authorization_url(callback_url, state), Self::Real(provider) => provider.build_authorization_url(callback_url, state, scene), } } pub async fn resolve_callback_profile( &self, code: Option<&str>, mock_code: Option<&str>, ) -> Result { match self { Self::Disabled => Err(WechatProviderError::Disabled), Self::Mock(provider) => Ok(provider.resolve_callback_profile(mock_code)), Self::Real(provider) => provider.resolve_callback_profile(code).await, } } pub async fn resolve_mini_program_login_profile( &self, code: Option<&str>, ) -> Result { match self { Self::Disabled => Err(WechatProviderError::Disabled), Self::Mock(provider) => Ok(provider.resolve_callback_profile(code)), Self::Real(provider) => provider.resolve_mini_program_login_profile(code).await, } } } impl 
MockWechatProvider { fn resolve_callback_profile(&self, mock_code: Option<&str>) -> WechatIdentityProfile { let provider_uid = mock_code .map(str::trim) .filter(|value| !value.is_empty()) .unwrap_or(self.mock_user_id.as_str()) .to_string(); WechatIdentityProfile { provider_uid, provider_union_id: self.mock_union_id.clone(), display_name: Some(self.mock_display_name.clone()), avatar_url: self.mock_avatar_url.clone(), } } } impl RealWechatProvider { fn build_authorization_url( &self, callback_url: &str, state: &str, scene: &WechatAuthScene, ) -> Result { let endpoint = match scene { WechatAuthScene::Desktop => &self.authorize_endpoint, WechatAuthScene::WechatInApp => DEFAULT_WECHAT_IN_APP_AUTHORIZE_ENDPOINT, }; let mut url = Url::parse(endpoint).map_err(|error| { WechatProviderError::InvalidConfig(format!("微信授权地址非法:{error}")) })?; let app_id = self.app_id.as_ref().ok_or_else(|| { WechatProviderError::InvalidConfig("微信开放平台 AppID 未配置".to_string()) })?; url.query_pairs_mut() .append_pair("appid", app_id) .append_pair("redirect_uri", callback_url) .append_pair("response_type", "code") .append_pair( "scope", match scene { WechatAuthScene::Desktop => "snsapi_login", WechatAuthScene::WechatInApp => "snsapi_userinfo", }, ) .append_pair("state", state); Ok(format!("{url}#wechat_redirect")) } async fn resolve_callback_profile( &self, code: Option<&str>, ) -> Result { let code = code .map(str::trim) .filter(|value| !value.is_empty()) .ok_or(WechatProviderError::MissingCode)?; let app_id = self.app_id.as_ref().ok_or_else(|| { WechatProviderError::InvalidConfig("微信开放平台 AppID 未配置".to_string()) })?; let app_secret = self.app_secret.as_ref().ok_or_else(|| { WechatProviderError::InvalidConfig("微信开放平台 AppSecret 未配置".to_string()) })?; let mut access_token_url = Url::parse(&self.access_token_endpoint).map_err(|error| { WechatProviderError::InvalidConfig(format!("微信 access_token 地址非法:{error}")) })?; access_token_url .query_pairs_mut() .append_pair("appid", app_id) .append_pair("secret", 
app_secret) .append_pair("code", code) .append_pair("grant_type", "authorization_code"); let access_token_payload = self .client .get(access_token_url.as_str()) .send() .await .map_err(|error| { warn!(error = %error, "微信 access_token 请求失败"); WechatProviderError::RequestFailed( "微信登录失败:access_token 请求失败".to_string(), ) })? .json::() .await .map_err(|error| { warn!(error = %error, "微信 access_token 响应解析失败"); WechatProviderError::DeserializeFailed( "微信登录失败:access_token 响应非法".to_string(), ) })?; let access_token = access_token_payload .access_token .filter(|value| !value.trim().is_empty()) .ok_or_else(|| { WechatProviderError::Upstream(format!( "微信登录失败:{}", access_token_payload .errmsg .unwrap_or_else(|| "缺少 access_token".to_string()) )) })?; let openid = access_token_payload .openid .filter(|value| !value.trim().is_empty()) .ok_or_else(|| { WechatProviderError::MissingProfile("微信登录失败:缺少 openid".to_string()) })?; let mut user_info_url = Url::parse(&self.user_info_endpoint).map_err(|error| { WechatProviderError::InvalidConfig(format!("微信用户信息地址非法:{error}")) })?; user_info_url .query_pairs_mut() .append_pair("access_token", &access_token) .append_pair("openid", &openid) .append_pair("lang", "zh_CN"); let user_info_payload = self .client .get(user_info_url.as_str()) .send() .await .map_err(|error| { warn!(error = %error, "微信用户信息请求失败"); WechatProviderError::RequestFailed("微信登录失败:用户信息请求失败".to_string()) })? 
.json::() .await .map_err(|error| { warn!(error = %error, "微信用户信息响应解析失败"); WechatProviderError::DeserializeFailed("微信登录失败:用户信息响应非法".to_string()) })?; let provider_uid = user_info_payload .openid .filter(|value| !value.trim().is_empty()) .ok_or_else(|| { WechatProviderError::Upstream(format!( "微信登录失败:{}", user_info_payload .errmsg .unwrap_or_else(|| "缺少 openid".to_string()) )) })?; Ok(WechatIdentityProfile { provider_uid, provider_union_id: user_info_payload.unionid.or(access_token_payload.unionid), display_name: user_info_payload.nickname, avatar_url: user_info_payload.headimgurl, }) } async fn resolve_mini_program_login_profile( &self, code: Option<&str>, ) -> Result { let code = code .map(str::trim) .filter(|value| !value.is_empty()) .ok_or(WechatProviderError::MissingCode)?; let app_id = self .mini_program_app_id .as_ref() .or(self.app_id.as_ref()) .ok_or_else(|| { WechatProviderError::InvalidConfig("微信小程序 AppID 未配置".to_string()) })?; let app_secret = self .mini_program_app_secret .as_ref() .or(self.app_secret.as_ref()) .ok_or_else(|| { WechatProviderError::InvalidConfig("微信小程序 AppSecret 未配置".to_string()) })?; let mut js_code_session_url = Url::parse(&self.js_code_session_endpoint).map_err(|error| { WechatProviderError::InvalidConfig(format!("微信 jscode2session 地址非法:{error}")) })?; js_code_session_url .query_pairs_mut() .append_pair("appid", app_id) .append_pair("secret", app_secret) .append_pair("js_code", code) .append_pair("grant_type", "authorization_code"); let payload = self .client .get(js_code_session_url.as_str()) .send() .await .map_err(|error| { warn!(error = %error, "微信小程序 jscode2session 请求失败"); WechatProviderError::RequestFailed( "微信小程序登录失败:jscode2session 请求失败".to_string(), ) })? 
.json::() .await .map_err(|error| { warn!(error = %error, "微信小程序 jscode2session 响应解析失败"); WechatProviderError::DeserializeFailed( "微信小程序登录失败:jscode2session 响应非法".to_string(), ) })?; if let Some(errcode) = payload.errcode.filter(|value| *value != 0) { return Err(WechatProviderError::Upstream(format!( "微信小程序登录失败:{}", payload .errmsg .unwrap_or_else(|| format!("jscode2session 返回错误 {errcode}")) ))); } let provider_uid = payload .openid .filter(|value| !value.trim().is_empty()) .ok_or_else(|| { WechatProviderError::MissingProfile("微信小程序登录失败:缺少 openid".to_string()) })?; Ok(WechatIdentityProfile { provider_uid, provider_union_id: payload.unionid, display_name: None, avatar_url: None, }) } } fn build_mock_wechat_authorization_url( callback_url: &str, state: &str, ) -> Result { let mut callback = Url::parse(callback_url).map_err(|error| { WechatProviderError::InvalidCallback(format!("微信回调地址非法:{error}")) })?; callback .query_pairs_mut() .append_pair("mock_code", "wx-mock-code") .append_pair("state", state); Ok(callback.to_string()) } impl MockSmsAuthProvider { async fn send_code( &self, request: SmsSendCodeRequest, ) -> Result { let provider_out_id = build_sms_provider_out_id(&request.scene, &request.national_phone_number); Ok(SmsSendCodeResult { cooldown_seconds: self.config.interval_seconds, expires_in_seconds: self.config.valid_time_seconds, provider_request_id: Some("mock-request-id".to_string()), provider_out_id: Some(provider_out_id), }) } async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> { if request.verify_code.trim() != self.config.mock_verify_code { return Err(SmsProviderError::InvalidVerifyCode); } Ok(()) } } impl AliyunSmsAuthProvider { async fn send_code( &self, request: SmsSendCodeRequest, ) -> Result { let provider_out_id = build_sms_provider_out_id(&request.scene, &request.national_phone_number); let phone_masked = mask_phone_number(&request.national_phone_number); let template_param = serde_json::json!({ 
self.config.template_param_key.clone(): "##code##", "min": self.config.valid_time_seconds, }) .to_string(); info!( provider = "aliyun", scene = request.scene.as_str(), phone_masked = phone_masked.as_str(), endpoint = self.config.endpoint.as_str(), sign_name = self.config.sign_name.as_str(), template_code = self.config.template_code.as_str(), code_length = self.config.code_length, valid_time_seconds = self.config.valid_time_seconds, interval_seconds = self.config.interval_seconds, provider_out_id = provider_out_id.as_str(), "准备调用阿里云短信发送接口" ); let mut query = BTreeMap::new(); query.insert("Action".to_string(), "SendSmsVerifyCode".to_string()); query.insert("Format".to_string(), "json".to_string()); query.insert("Version".to_string(), "2017-05-25".to_string()); query.insert("Timestamp".to_string(), current_aliyun_timestamp()); query.insert("SignatureNonce".to_string(), new_uuid_simple_string()); query.insert("SignatureMethod".to_string(), "HMAC-SHA1".to_string()); query.insert("SignatureVersion".to_string(), "1.0".to_string()); query.insert( "AccessKeyId".to_string(), self.config.access_key_id.clone().unwrap_or_default(), ); query.insert( "PhoneNumber".to_string(), request.national_phone_number.trim().to_string(), ); query.insert("CountryCode".to_string(), self.config.country_code.clone()); query.insert("SignName".to_string(), self.config.sign_name.clone()); query.insert( "TemplateCode".to_string(), self.config.template_code.clone(), ); query.insert("TemplateParam".to_string(), template_param); query.insert( "CodeLength".to_string(), self.config.code_length.to_string(), ); query.insert("CodeType".to_string(), self.config.code_type.to_string()); query.insert( "ValidTime".to_string(), self.config.valid_time_seconds.to_string(), ); query.insert( "Interval".to_string(), self.config.interval_seconds.to_string(), ); query.insert( "DuplicatePolicy".to_string(), self.config.duplicate_policy.to_string(), ); query.insert( "ReturnVerifyCode".to_string(), 
self.config.return_verify_code.to_string(), ); query.insert("OutId".to_string(), provider_out_id.clone()); if let Some(scheme_name) = self.config.scheme_name.clone() { query.insert("SchemeName".to_string(), scheme_name); } self.sign_query(&mut query)?; let payload = self .client .post(build_aliyun_sms_url(&self.config.endpoint)?) .form(&query) .send() .await .map_err(|error| SmsProviderError::Upstream(format!("短信验证码发送失败:{error}")))?; let http_status = payload.status(); let body = parse_aliyun_json_response(payload, "短信验证码发送失败").await?; info!( provider = "aliyun", scene = request.scene.as_str(), phone_masked = phone_masked.as_str(), http_status = http_status.as_u16(), provider_code = body.code.as_deref().unwrap_or("unknown"), provider_message = body.message.as_deref().unwrap_or("unknown"), provider_request_id = body .request_id .as_deref() .or_else(|| body .model .as_ref() .and_then(|model| model.request_id.as_deref())) .unwrap_or("unknown"), provider_out_id = body .model .as_ref() .and_then(|model| model.out_id.as_deref()) .unwrap_or("unknown"), success = body.success.unwrap_or(false), "阿里云短信发送接口返回响应" ); if !body.success.unwrap_or(false) || body.code.as_deref() != Some("OK") { warn!( provider = "aliyun", scene = request.scene.as_str(), phone_masked = phone_masked.as_str(), http_status = http_status.as_u16(), provider_code = body.code.as_deref().unwrap_or("unknown"), provider_message = body.message.as_deref().unwrap_or("unknown"), provider_request_id = body .request_id .as_deref() .or_else(|| body .model .as_ref() .and_then(|model| model.request_id.as_deref())) .unwrap_or("unknown"), provider_out_id = body .model .as_ref() .and_then(|model| model.out_id.as_deref()) .unwrap_or("unknown"), "阿里云短信发送接口返回业务失败" ); return Err(map_aliyun_provider_error( "短信验证码发送失败", body.message, body.code, )); } Ok(SmsSendCodeResult { cooldown_seconds: self.config.interval_seconds, expires_in_seconds: self.config.valid_time_seconds, provider_request_id: body.request_id.or_else(|| { 
// NOTE(review): the lines up to the first `async fn` below are the tail of the preceding method
// of this `impl` block (its beginning lies outside this chunk); they are reproduced unchanged.
                body.model
                    .as_ref()
                    .and_then(|model| model.request_id.clone())
            }),
            provider_out_id: body.model.and_then(|model| model.out_id),
        })
    }

    /// Checks a user-supplied SMS code with Aliyun's `CheckSmsVerifyCode` RPC action.
    ///
    /// The phone number and code are trimmed, the RPC query is signed, and the query is POSTed as
    /// a form to the configured endpoint. Success requires `Success == true`, `Code == "OK"` and
    /// `Model.VerifyResult == "PASS"`.
    ///
    /// # Errors
    /// Propagates errors from signing, endpoint normalisation, the HTTP call and response
    /// classification. Returns `SmsProviderError::InvalidVerifyCode` when the provider call
    /// succeeds but the verdict is not `PASS`.
    async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> {
        let mut query = BTreeMap::new();
        query.insert("Action".to_string(), "CheckSmsVerifyCode".to_string());
        query.insert("Format".to_string(), "json".to_string());
        query.insert("Version".to_string(), "2017-05-25".to_string());
        query.insert("Timestamp".to_string(), current_aliyun_timestamp());
        query.insert("SignatureNonce".to_string(), new_uuid_simple_string());
        query.insert("SignatureMethod".to_string(), "HMAC-SHA1".to_string());
        query.insert("SignatureVersion".to_string(), "1.0".to_string());
        query.insert(
            "AccessKeyId".to_string(),
            self.config.access_key_id.clone().unwrap_or_default(),
        );
        query.insert(
            "PhoneNumber".to_string(),
            request.national_phone_number.trim().to_string(),
        );
        query.insert("CountryCode".to_string(), self.config.country_code.clone());
        query.insert(
            "VerifyCode".to_string(),
            request.verify_code.trim().to_string(),
        );
        query.insert(
            "CaseAuthPolicy".to_string(),
            self.config.case_auth_policy.to_string(),
        );
        // Optional parameters are only added when present, so they are only signed when sent.
        if let Some(scheme_name) = self.config.scheme_name.clone() {
            query.insert("SchemeName".to_string(), scheme_name);
        }
        if let Some(provider_out_id) = request.provider_out_id {
            query.insert("OutId".to_string(), provider_out_id);
        }
        // Signing must come last: the signature covers every parameter inserted above.
        self.sign_query(&mut query)?;

        let payload = self
            .client
            .post(build_aliyun_sms_url(&self.config.endpoint)?)
            .form(&query)
            .send()
            .await
            .map_err(|error| SmsProviderError::Upstream(format!("验证码校验失败:{error}")))?;
        let body = parse_aliyun_json_response_for_verify(payload).await?;
        // A business-level failure is classified from the provider's own Code/Message.
        if !body.success.unwrap_or(false) || body.code.as_deref() != Some("OK") {
            return Err(map_aliyun_provider_error(
                "验证码校验失败",
                body.message,
                body.code,
            ));
        }
        // The call itself succeeded; only an explicit "PASS" verdict accepts the code.
        if body.model.and_then(|model| model.verify_result).as_deref() != Some("PASS") {
            return Err(SmsProviderError::InvalidVerifyCode);
        }
        Ok(())
    }

    /// Adds the Aliyun RPC `Signature` parameter (signature version 1.0, HMAC-SHA1) to `query`.
    ///
    /// The string to sign is `POST&` + encoded `/` + `&` + the percent-encoded canonical query.
    /// It is signed with the key `"<AccessKeySecret>&"` and Base64-encoded.
    ///
    /// # Errors
    /// `SmsProviderError::InvalidConfig` when the AccessKeySecret is not configured or the HMAC
    /// signer cannot be initialised.
    fn sign_query(&self, query: &mut BTreeMap) -> Result<(), SmsProviderError> {
        let access_key_secret = self.config.access_key_secret.as_deref().ok_or_else(|| {
            SmsProviderError::InvalidConfig("阿里云短信 AccessKeySecret 未配置".to_string())
        })?;
        let canonicalized = canonicalize_aliyun_rpc_params(query);
        let string_to_sign = format!(
            "POST&{}&{}",
            aliyun_percent_encode("/"),
            aliyun_percent_encode(&canonicalized)
        );
        // Aliyun's v1 signing key is the secret followed by a literal '&'.
        let mut signer = HmacSha1::new_from_slice(format!("{access_key_secret}&").as_bytes())
            .map_err(|error| {
                SmsProviderError::InvalidConfig(format!("初始化短信签名器失败:{error}"))
            })?;
        signer.update(string_to_sign.as_bytes());
        let signature = BASE64_STANDARD.encode(signer.finalize().into_bytes());
        query.insert("Signature".to_string(), signature);
        Ok(())
    }
}

impl AccessTokenClaims {
    /// Builds validated JWT claims from the domain input.
    ///
    /// `sub`/`sid` must be non-blank and `roles` must contain at least one non-blank role.
    /// `iat` is `issued_at`, and `exp` is `issued_at + config.access_token_ttl_seconds()`.
    ///
    /// # Errors
    /// - `JwtError::InvalidClaims` for blank identifiers, empty roles, a pre-epoch `iat`, or a
    ///   lifetime that does not place `exp` after `iat`.
    /// - `JwtError::InvalidConfig` when the TTL does not fit in `i64` or the expiry overflows.
    pub fn from_input(
        input: AccessTokenClaimsInput,
        config: &JwtConfig,
        issued_at: OffsetDateTime,
    ) -> Result {
        let user_id = normalize_required_field(input.user_id, "JWT sub 不能为空")?;
        let session_id = normalize_required_field(input.session_id, "JWT sid 不能为空")?;
        let roles = normalize_roles(input.roles)?;
        let display_name = normalize_optional_field(input.display_name);

        let issued_at_unix = issued_at.unix_timestamp();
        // Negative timestamps are rejected so the `as u64` casts below cannot wrap.
        if issued_at_unix < 0 {
            return Err(JwtError::InvalidClaims("JWT iat 不能早于 Unix epoch"));
        }
        let expires_at = issued_at
            .checked_add(Duration::seconds(
                i64::try_from(config.access_token_ttl_seconds()).map_err(|_| {
                    JwtError::InvalidConfig("JWT access token 过期时间超出 i64 上限")
                })?,
            ))
            .ok_or(JwtError::InvalidConfig("JWT 过期时间计算溢出"))?;
        let expires_at_unix = expires_at.unix_timestamp();
        // A zero TTL would produce exp == iat, i.e. a token that is already expired.
        if expires_at_unix <= issued_at_unix {
            return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"));
        }

        let claims = Self {
            iss: config.issuer().to_string(),
            sub: user_id,
            sid: session_id,
            provider: input.provider,
            roles,
            ver: input.token_version,
            phone_verified: input.phone_verified,
            binding_status: input.binding_status,
            display_name,
            // Both values were checked to be non-negative above.
            iat: issued_at_unix as u64,
            exp: expires_at_unix as u64,
        };
        claims.validate_for_config(config)?;
        Ok(claims)
    }

    /// The user id carried in `sub`.
    pub fn user_id(&self) -> &str {
        &self.sub
    }

    /// The session id carried in `sid`.
    pub fn session_id(&self) -> &str {
        &self.sid
    }

    /// The token version carried in `ver`.
    pub fn token_version(&self) -> u64 {
        self.ver
    }

    /// Re-checks the claim invariants against `config`.
    ///
    /// The checks are: the trimmed issuer matches, `sub`/`sid` are non-blank, at least one
    /// non-blank role is present, and `exp > iat`.
    ///
    /// # Errors
    /// `JwtError::InvalidClaims` describing the first violated invariant.
    pub fn validate_for_config(&self, config: &JwtConfig) -> Result<(), JwtError> {
        if self.iss.trim() != config.issuer() {
            return Err(JwtError::InvalidClaims("JWT iss 与当前配置不一致"));
        }
        normalize_required_field(self.sub.clone(), "JWT sub 不能为空")?;
        normalize_required_field(self.sid.clone(), "JWT sid 不能为空")?;
        normalize_roles(self.roles.clone())?;
        if self.exp <= self.iat {
            return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"));
        }
        Ok(())
    }
}

/// Validates `claims` against `config` and signs them as an HS256 JWT (`typ: "JWT"`) with the
/// configured secret.
///
/// # Errors
/// - `JwtError::InvalidClaims` from validation.
/// - `JwtError::SignFailed` when encoding fails.
pub fn sign_access_token(
    claims: &AccessTokenClaims,
    config: &JwtConfig,
) -> Result {
    claims.validate_for_config(config)?;
    let header = Header {
        alg: ACCESS_TOKEN_ALGORITHM,
        typ: Some("JWT".to_string()),
        ..Header::default()
    };
    encode(
        &header,
        claims,
        &EncodingKey::from_secret(config.secret.as_bytes()),
    )
    .map_err(|error| JwtError::SignFailed(format!("JWT 签发失败:{error}")))
}

/// Verifies a (whitespace-trimmed) access token and returns its claims.
///
/// The token must use `ACCESS_TOKEN_ALGORITHM`, carry `exp`/`iat`/`iss`/`sub`, have the
/// configured issuer, and pass `AccessTokenClaims::validate_for_config`.
///
/// # Errors
/// - `JwtError::VerifyFailed` for blank tokens and library verification failures, with
///   messages from `map_verify_error`.
/// - `JwtError::InvalidClaims` when the decoded claims break an invariant.
pub fn verify_access_token(token: &str, config: &JwtConfig) -> Result {
    let token = token.trim();
    if token.is_empty() {
        return Err(JwtError::VerifyFailed("JWT 不能为空".to_string()));
    }

    let mut validation = Validation::new(ACCESS_TOKEN_ALGORITHM);
    validation.required_spec_claims = HashSet::from([
        "exp".to_string(),
        "iat".to_string(),
        "iss".to_string(),
        "sub".to_string(),
    ]);
    validation.set_issuer(&[config.issuer()]);

    let decoded = decode::(
        token,
        &DecodingKey::from_secret(config.secret.as_bytes()),
        &validation,
    )
    .map_err(map_verify_error)?;
    // The library checks are followed by the crate's own invariants (roles, sid, exp > iat).
    decoded.claims.validate_for_config(config)?;
    Ok(decoded.claims)
}

/// Extracts the refresh-session token from a raw `Cookie` request header.
///
/// Entries are `;`-separated `name=value` pairs. The first entry whose trimmed name equals
/// `config.cookie_name()` wins, and its value is percent-decoded. Entries without `=` are
/// ignored.
///
/// Returns `None` when:
/// - the header is blank,
/// - the cookie is absent, or
/// - the first matching cookie's value is empty or does not percent-decode to UTF-8.
pub fn read_refresh_session_token(
    cookie_header: &str,
    config: &RefreshCookieConfig,
) -> Option {
    let cookie_header = cookie_header.trim();
    if cookie_header.is_empty() {
        return None;
    }

    for entry in cookie_header.split(';') {
        let entry = entry.trim();
        if entry.is_empty() {
            continue;
        }
        // An entry without `=` (e.g. a nameless cookie) cannot be ours; skip it instead of
        // aborting the whole lookup, so a later refresh cookie is still found.
        let (raw_name, raw_value) = match entry.split_once('=') {
            Some(pair) => pair,
            None => continue,
        };
        if raw_name.trim() != config.cookie_name() {
            continue;
        }
        let raw_value = raw_value.trim();
        if raw_value.is_empty() {
            return None;
        }
        return urlencoding::decode(raw_value)
            .ok()
            .map(|decoded| decoded.into_owned());
    }

    None
}

/// Hashes `password` with Argon2 (default parameters) and a fresh random salt, returning the
/// encoded hash string.
///
/// NOTE(review): Argon2 runs synchronously inside this `async fn` (there is no `.await`), so it
/// occupies the calling task's thread for the full hashing time; consider offloading to a
/// blocking pool if runtime latency matters.
///
/// # Errors
/// `PasswordHashError::HashFailed` when hashing fails.
pub async fn hash_password(password: &str) -> Result {
    let salt = SaltString::generate(&mut OsRng);
    Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .map(|hash| hash.to_string())
        .map_err(|error| PasswordHashError::HashFailed(format!("密码哈希失败:{error}")))
}

/// Checks `password` against an encoded Argon2 `password_hash`.
///
/// Returns `Ok(false)` for any verification failure reported by `verify_password`, not only for
/// a wrong password.
///
/// # Errors
/// `PasswordHashError::VerifyFailed` when `password_hash` cannot be parsed.
pub async fn verify_password(
    password_hash: &str,
    password: &str,
) -> Result {
    let parsed_hash = PasswordHash::new(password_hash)
        .map_err(|error| PasswordHashError::VerifyFailed(format!("密码哈希格式非法:{error}")))?;
    Ok(Argon2::default()
        .verify_password(password.as_bytes(), &parsed_hash)
        .is_ok())
}

/// Creates a new opaque refresh-session token (a simple-format UUID string).
pub fn create_refresh_session_token() -> String {
    new_uuid_simple_string()
}

/// Hashes a refresh-session token with SHA-256 and returns lowercase hex.
pub fn hash_refresh_session_token(token: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(token.as_bytes());
    format!("{:x}", hasher.finalize())
}

/// Builds the `Set-Cookie` value that stores `token` (percent-encoded) as the refresh cookie.
///
/// The cookie has `HttpOnly`, the configured `Path`/`SameSite`, and `Max-Age` equal to the
/// configured number of days. `Secure` is added only when enabled in `config`.
pub fn build_refresh_session_set_cookie(token: &str, config: &RefreshCookieConfig) -> String {
    let mut parts = vec![
        format!(
            "{}={}",
            config.cookie_name(),
            urlencoding::encode(token).into_owned()
        ),
        format!("Path={}", config.cookie_path()),
        "HttpOnly".to_string(),
        format!("SameSite={}", config.cookie_same_site().as_str()),
        format!(
            "Max-Age={}",
            u64::from(config.refresh_session_ttl_days()) * 24 * 60 * 60
        ),
    ];
    if config.cookie_secure() {
        parts.push("Secure".to_string());
    }
    parts.join("; ")
}

/// Builds the `Set-Cookie` value that clears the refresh cookie.
///
/// The value is empty and `Max-Age=0`. `Path`, `SameSite`, `HttpOnly` and `Secure` mirror the
/// set-cookie variant so the same cookie is targeted.
pub fn build_refresh_session_clear_cookie(config: &RefreshCookieConfig) -> String {
    let mut parts = vec![
        format!("{}=", config.cookie_name()),
        format!("Path={}", config.cookie_path()),
        "HttpOnly".to_string(),
        format!("SameSite={}", config.cookie_same_site().as_str()),
        "Max-Age=0".to_string(),
    ];
    if config.cookie_secure() {
        parts.push("Secure".to_string());
    }
    parts.join("; ")
}

// Every JwtError variant displays its stored message verbatim.
impl fmt::Display for JwtError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidConfig(message) | Self::InvalidClaims(message) => f.write_str(message),
            Self::SignFailed(message) | Self::VerifyFailed(message) => f.write_str(message),
        }
    }
}

impl Error for JwtError {}

impl fmt::Display for RefreshCookieError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidConfig(message) => f.write_str(message),
        }
    }
}

impl Error for RefreshCookieError {}

impl fmt::Display for PasswordHashError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::HashFailed(message) | Self::VerifyFailed(message) => f.write_str(message),
        }
    }
}

impl Error for PasswordHashError {}

// Normalises a required claim string via `normalize_required_string`; a `None` result becomes
// `JwtError::InvalidClaims(error_message)`.
fn normalize_required_field(
    value: String,
    error_message: &'static str,
) -> Result {
    normalize_required_string(&value).ok_or(JwtError::InvalidClaims(error_message))
}

// Thin wrapper over `normalize_optional_string` for optional claim strings.
fn normalize_optional_field(value: Option) -> Option {
    normalize_optional_string(value)
}

// Trims every role, drops blank ones, and requires at least one to remain. Order is preserved
// and duplicates are not removed.
fn normalize_roles(roles: Vec) -> Result, JwtError> {
    let roles = roles
        .into_iter()
        .map(|role| role.trim().to_string())
        .filter(|role| !role.is_empty())
        .collect::>();
    if roles.is_empty() {
        return Err(JwtError::InvalidClaims("JWT roles 至少包含一个角色"));
    }
    Ok(roles)
}

// Translates jsonwebtoken verification errors into user-facing `VerifyFailed` messages;
// unlisted kinds keep the library's error text.
fn map_verify_error(error: jsonwebtoken::errors::Error) -> JwtError {
    let message = match error.kind() {
        ErrorKind::ExpiredSignature => "JWT 已过期".to_string(),
        ErrorKind::InvalidIssuer => "JWT 发行者不匹配".to_string(),
        ErrorKind::InvalidSignature => "JWT 签名无效".to_string(),
        ErrorKind::InvalidAlgorithm => "JWT 算法不匹配".to_string(),
        ErrorKind::InvalidToken => "JWT 非法".to_string(),
        ErrorKind::ImmatureSignature => "JWT 尚未生效".to_string(),
        ErrorKind::MissingRequiredClaim(claim) => format!("JWT 缺少必填字段:{claim}"),
        _ => format!("JWT 校验失败:{error}"),
    };
    JwtError::VerifyFailed(message)
}

// Builds an SMS OutId of the form `{scene}_{last four chars of the phone}_{uuid}`.
fn build_sms_provider_out_id(scene: &str, national_phone_number: &str) -> String {
    // Take the last four characters (reverse, take, reverse back) without byte slicing.
    let phone_suffix = national_phone_number
        .chars()
        .rev()
        .take(4)
        .collect::()
        .chars()
        .rev()
        .collect::();
    format!("{scene}_{}_{}", phone_suffix, new_uuid_simple_string())
}

// Normalises the configured endpoint to `https://{host}/`.
// Any `https://`/`http://` prefix and surrounding slashes are stripped, so a plain-HTTP endpoint
// is upgraded to HTTPS.
fn build_aliyun_sms_url(endpoint: &str) -> Result {
    let endpoint = endpoint
        .trim()
        .trim_start_matches("https://")
        .trim_start_matches("http://")
        .trim_matches('/');
    if endpoint.is_empty() {
        return Err(SmsProviderError::InvalidConfig(
            "阿里云短信 endpoint 不能为空".to_string(),
        ));
    }
    Ok(format!("https://{endpoint}/"))
}

/// Current UTC time formatted as Aliyun's RPC `Timestamp` (`yyyy-MM-ddTHH:mm:ssZ`).
///
/// The value is truncated to whole seconds before formatting. The `Rfc3339` formatter appends a
/// fractional part whenever the nanosecond component is non-zero, which does not match
/// Aliyun's documented format.
fn current_aliyun_timestamp() -> String {
    OffsetDateTime::from_unix_timestamp(OffsetDateTime::now_utc().unix_timestamp())
        .ok()
        .and_then(|now| {
            now.format(&time::format_description::well_known::Rfc3339)
                .ok()
        })
        .unwrap_or_else(|| "1970-01-01T00:00:00Z".to_string())
}

// Builds the canonicalized query used for signing.
// Percent-encoded `key=value` pairs are joined with `&` in the map's sorted key order; any
// existing `Signature` entry is excluded.
fn canonicalize_aliyun_rpc_params(params: &BTreeMap) -> String {
    params
        .iter()
        .filter(|(key, _)| key.as_str() != "Signature")
        .map(|(key, value)| {
            format!(
                "{}={}",
                aliyun_percent_encode(key),
                aliyun_percent_encode(value)
            )
        })
        .collect::>()
        .join("&")
}

// Percent-encodes a value per Aliyun's POP rules (space -> %20, `*` -> %2A, `~` unescaped).
// urlencoding should already escape `+`/`*` and leave `~` intact, so the replacements are
// presumably defensive no-ops.
fn aliyun_percent_encode(value: &str) -> String {
    urlencoding::encode(value)
        .into_owned()
        .replace('+', "%20")
        .replace('*', "%2A")
        .replace("%7E", "~")
}

/// Reads an Aliyun JSON response body and deserializes it into the expected response type.
///
/// HTTP 4xx/5xx statuses are handled before the typed payload is deserialized. A non-JSON
/// error body (e.g. a gateway error page) is therefore still classified by
/// `map_http_status_to_sms_provider_error` instead of surfacing as a parse failure.
/// `fallback_message` prefixes every error message.
async fn parse_aliyun_json_response(
    response: reqwest::Response,
    fallback_message: &str,
) -> Result {
    let status = response.status();
    let body = response
        .text()
        .await
        .map_err(|error| SmsProviderError::Upstream(format!("{fallback_message}:{error}")))?;
    if status.is_client_error() || status.is_server_error() {
        // The error body is inspected best-effort as untyped JSON; non-JSON yields `None`.
        return Err(map_http_status_to_sms_provider_error(
            fallback_message,
            status,
            serde_json::from_str(&body).ok(),
        ));
    }
    serde_json::from_str(&body).map_err(|error| {
        SmsProviderError::Upstream(format!("{fallback_message}:响应解析失败:{error}"))
    })
}

/// Same as `parse_aliyun_json_response`, for the `CheckSmsVerifyCode` response, with the fixed
/// "验证码校验失败" message prefix.
///
/// The HTTP status is checked before deserializing the typed payload.
async fn parse_aliyun_json_response_for_verify(
    response: reqwest::Response,
) -> Result {
    let status = response.status();
    let body = response
        .text()
        .await
        .map_err(|error| SmsProviderError::Upstream(format!("验证码校验失败:{error}")))?;
    if status.is_client_error() || status.is_server_error() {
        // The error body is inspected best-effort as untyped JSON; non-JSON yields `None`.
        return Err(map_http_status_to_sms_provider_error(
            "验证码校验失败",
            status,
            serde_json::from_str(&body).ok(),
        ));
    }
    serde_json::from_str(&body).map_err(|error| {
        SmsProviderError::Upstream(format!("验证码校验失败:响应解析失败:{error}"))
    })
}

// Maps an HTTP error status to an SmsProviderError.
// - 4xx: the body's `Code`/`Message` (when present) are classified by
//   `map_aliyun_provider_error`.
// - Any other status passed in becomes `Upstream` with the provider message appended.
fn map_http_status_to_sms_provider_error(
    fallback_message: &str,
    status: StatusCode,
    payload: Option,
) -> SmsProviderError {
    let provider_message = payload
        .as_ref()
        .and_then(|value| value.get("Message").and_then(Value::as_str))
        .unwrap_or_default();
    let provider_code = payload
        .as_ref()
        .and_then(|value| value.get("Code").and_then(Value::as_str))
        .unwrap_or_default();
    if status.is_client_error() {
        return map_aliyun_provider_error(
            fallback_message,
            Some(provider_message.to_string()),
            Some(provider_code.to_string()),
        );
    }
    SmsProviderError::Upstream(build_provider_error_message(
        fallback_message,
        provider_message,
    ))
}

// Classifies an Aliyun business error by substring-matching the upper-cased provider code:
// - VERIFY/CODE/CHECK -> InvalidVerifyCode;
// - MOBILE/PHONE/SIGN/TEMPLATE/ACCESSKEY -> InvalidConfig;
// - anything else -> Upstream.
// Groups are tested in that order, so the first matching group wins.
fn map_aliyun_provider_error(
    fallback_message: &str,
    provider_message: Option,
    provider_code: Option,
) -> SmsProviderError {
    let provider_message = provider_message.unwrap_or_default();
    let provider_code = provider_code.unwrap_or_default();
    let normalized_code = provider_code.trim().to_ascii_uppercase();

    if normalized_code.contains("VERIFY")
        || normalized_code.contains("CODE")
        || normalized_code.contains("CHECK")
    {
        return SmsProviderError::InvalidVerifyCode;
    }

    if normalized_code.contains("MOBILE")
        || normalized_code.contains("PHONE")
        || normalized_code.contains("SIGN")
        || normalized_code.contains("TEMPLATE")
        || normalized_code.contains("ACCESSKEY")
    {
        return SmsProviderError::InvalidConfig(build_provider_error_message(
            fallback_message,
            &provider_message,
        ));
    }

    SmsProviderError::Upstream(build_provider_error_message(
        fallback_message,
        &provider_message,
    ))
}

// Returns `prefix`, or `prefix:message` when the trimmed provider message is non-empty.
fn build_provider_error_message(prefix: &str, provider_message: &str) -> String {
    let provider_message = provider_message.trim();
    if provider_message.is_empty() {
        prefix.to_string()
    } else {
        format!("{prefix}:{provider_message}")
    }
}

/// Masks a phone number, keeping the first three and up to the last four characters and
/// replacing everything in between with `*`.
///
/// Masking rules:
/// - At least one character is always masked, so 5–7 character inputs no longer expose every
///   digit; inputs of eight or more characters are masked exactly as before.
/// - Inputs of four characters or fewer are fully masked.
/// - An empty input yields a single `*`.
fn mask_phone_number(phone_number: &str) -> String {
    let len = phone_number.chars().count();
    if len <= 4 {
        return "*".repeat(len.max(1));
    }
    // `len >= 5` here, so the suffix keeps at least one character and at least one character
    // between the prefix and the suffix is masked.
    let prefix_len = 3;
    let suffix_len = 4.min(len - prefix_len - 1);
    let suffix_start = len - suffix_len;
    phone_number
        .chars()
        .enumerate()
        .map(|(index, ch)| {
            if index < prefix_len || index >= suffix_start {
                ch
            } else {
                '*'
            }
        })
        .collect()
}

impl fmt::Display for SmsProviderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidConfig(message) | Self::Upstream(message) => f.write_str(message),
            Self::InvalidVerifyCode => f.write_str("验证码错误"),
        }
    }
}

impl Error for SmsProviderError {}

impl fmt::Display for WechatProviderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Disabled => f.write_str("微信登录暂未启用"),
            Self::MissingCode => f.write_str("缺少微信授权 code"),
            Self::InvalidConfig(message)
            | Self::InvalidCallback(message)
            | Self::RequestFailed(message)
            | Self::DeserializeFailed(message)
            | Self::Upstream(message)
            | Self::MissingProfile(message) => f.write_str(message),
        }
    }
}

impl Error for WechatProviderError {}

// The `kind()` methods below map each platform error to the shared AuthPlatformErrorKind used
// by adapters.
impl JwtError {
    pub fn kind(&self) -> AuthPlatformErrorKind {
        match self {
            Self::InvalidConfig(_) => AuthPlatformErrorKind::InvalidConfig,
            Self::InvalidClaims(_) => AuthPlatformErrorKind::InvalidClaims,
            Self::SignFailed(_) => AuthPlatformErrorKind::SignFailed,
            Self::VerifyFailed(_) => AuthPlatformErrorKind::VerifyFailed,
        }
    }
}

impl RefreshCookieError {
    pub fn kind(&self) -> AuthPlatformErrorKind {
        match self {
            Self::InvalidConfig(_) => AuthPlatformErrorKind::CookieConfig,
        }
    }
}

impl PasswordHashError {
    pub fn kind(&self) -> AuthPlatformErrorKind {
        match self {
            Self::HashFailed(_) => AuthPlatformErrorKind::HashFailed,
            Self::VerifyFailed(_) => AuthPlatformErrorKind::VerifyFailed,
        }
    }
}

impl SmsProviderError {
    pub fn kind(&self) -> AuthPlatformErrorKind {
        match self {
            Self::InvalidConfig(_) => AuthPlatformErrorKind::InvalidConfig,
            Self::InvalidVerifyCode => AuthPlatformErrorKind::InvalidVerifyCode,
            Self::Upstream(_) => AuthPlatformErrorKind::Upstream,
        }
    }
}

impl WechatProviderError {
    pub fn kind(&self) -> AuthPlatformErrorKind {
        match self {
            Self::Disabled => AuthPlatformErrorKind::Disabled,
            Self::MissingCode => AuthPlatformErrorKind::MissingCode,
            Self::InvalidConfig(_) => AuthPlatformErrorKind::InvalidConfig,
            Self::InvalidCallback(_) => AuthPlatformErrorKind::InvalidCallback,
            Self::RequestFailed(_) => AuthPlatformErrorKind::RequestFailed,
            Self::DeserializeFailed(_) => AuthPlatformErrorKind::DeserializeFailed,
            Self::Upstream(_) => AuthPlatformErrorKind::Upstream,
            Self::MissingProfile(_) => AuthPlatformErrorKind::MissingProfile,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn auth_platform_error_kind_is_stable_for_adapter_mapping() {
        assert_eq!(
            JwtError::InvalidClaims("JWT roles 至少包含一个角色").kind(),
            AuthPlatformErrorKind::InvalidClaims
        );
        assert_eq!(
            PasswordHashError::VerifyFailed("密码校验失败".to_string()).kind(),
            AuthPlatformErrorKind::VerifyFailed
        );
        assert_eq!(
            SmsProviderError::InvalidVerifyCode.kind(),
            AuthPlatformErrorKind::InvalidVerifyCode
        );
        assert_eq!(
            WechatProviderError::MissingCode.kind(),
            AuthPlatformErrorKind::MissingCode
        );
    }

    #[test]
    fn mock_wechat_provider_builds_callback_authorization_url() {
        let provider = WechatProvider::new(WechatAuthConfig::new(
            true,
            "mock".to_string(),
            None,
            None,
            None,
            None,
            DEFAULT_WECHAT_AUTHORIZE_ENDPOINT.to_string(),
            DEFAULT_WECHAT_ACCESS_TOKEN_ENDPOINT.to_string(),
            DEFAULT_WECHAT_USER_INFO_ENDPOINT.to_string(),
            DEFAULT_WECHAT_JS_CODE_SESSION_ENDPOINT.to_string(),
            "wx-user-001".to_string(),
            Some("wx-union-001".to_string()),
            "微信测试用户".to_string(),
            Some("https://example.test/avatar.png".to_string()),
        ));

        let authorization_url = provider
            .build_authorization_url(
                "http://127.0.0.1:3000/api/auth/wechat/callback",
                "state_001",
                &WechatAuthScene::Desktop,
            )
            .expect("mock authorization url should build");

        assert!(authorization_url.contains("mock_code=wx-mock-code"));
        assert!(authorization_url.contains("state=state_001"));
    }

    #[tokio::test]
    async fn mock_wechat_provider_resolves_identity_profile() {
        let provider = WechatProvider::new(WechatAuthConfig::new(
            true,
            "mock".to_string(),
            None,
            None,
            None,
            None,
            DEFAULT_WECHAT_AUTHORIZE_ENDPOINT.to_string(),
            DEFAULT_WECHAT_ACCESS_TOKEN_ENDPOINT.to_string(),
            DEFAULT_WECHAT_USER_INFO_ENDPOINT.to_string(),
            DEFAULT_WECHAT_JS_CODE_SESSION_ENDPOINT.to_string(),
            "wx-user-001".to_string(),
            Some("wx-union-001".to_string()),
            "微信测试用户".to_string(),
            None,
        ));

        let profile = provider
            .resolve_callback_profile(None, Some("wx-code-001"))
            .await
            .expect("mock profile should resolve");

        assert_eq!(profile.provider_uid, "wx-code-001");
        assert_eq!(profile.provider_union_id.as_deref(), Some("wx-union-001"));
        assert_eq!(profile.display_name.as_deref(), Some("微信测试用户"));
    }

    fn build_jwt_config() -> JwtConfig {
        JwtConfig::new(
            "https://auth.genarrative.local".to_string(),
            "genarrative-dev-secret".to_string(),
            DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
        )
        .expect("jwt config should be valid")
    }

    fn build_claims_input() -> AccessTokenClaimsInput {
        AccessTokenClaimsInput {
            user_id: "usr_123".to_string(),
            session_id: "sess_456".to_string(),
            provider: AuthProvider::Wechat,
            roles: vec!["user".to_string()],
            token_version: 3,
            phone_verified: false,
            binding_status: BindingStatus::PendingBindPhone,
            display_name: Some("微信旅人".to_string()),
        }
    }

    fn build_refresh_cookie_config() -> RefreshCookieConfig {
        RefreshCookieConfig::new(
            DEFAULT_REFRESH_COOKIE_NAME.to_string(),
            DEFAULT_REFRESH_COOKIE_PATH.to_string(),
            false,
            RefreshCookieSameSite::Lax,
            DEFAULT_REFRESH_SESSION_TTL_DAYS,
        )
        .expect("refresh cookie config should be valid")
    }

    fn build_mock_sms_config() -> SmsAuthConfig {
        SmsAuthConfig::new(
            SmsAuthProviderKind::Mock,
            DEFAULT_SMS_ENDPOINT.to_string(),
            None,
            None,
            String::new(),
            String::new(),
            DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string(),
            DEFAULT_SMS_COUNTRY_CODE.to_string(),
            None,
            DEFAULT_SMS_CODE_LENGTH,
            DEFAULT_SMS_CODE_TYPE,
            DEFAULT_SMS_VALID_TIME_SECONDS,
            DEFAULT_SMS_INTERVAL_SECONDS,
            DEFAULT_SMS_DUPLICATE_POLICY,
            DEFAULT_SMS_CASE_AUTH_POLICY,
            false,
            DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(),
        )
        .expect("mock sms config should be valid")
    }

    #[test]
    fn round_trip_sign_and_verify_access_token() {
        let config = build_jwt_config();
        let claims =
            AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
                .expect("claims should build");

        let token = sign_access_token(&claims, &config).expect("token should sign");
        let verified = verify_access_token(&token, &config).expect("token should verify");

        assert_eq!(verified, claims);
        assert_eq!(verified.user_id(), "usr_123");
        assert_eq!(verified.session_id(), "sess_456");
        assert_eq!(verified.token_version(), 3);
    }

    #[test]
    fn verify_rejects_invalid_issuer() {
        let config = build_jwt_config();
        let claims =
            AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
                .expect("claims should build");
        let token = sign_access_token(&claims, &config).expect("token should sign");
        let other_config = JwtConfig::new(
            "https://auth.other.local".to_string(),
            "genarrative-dev-secret".to_string(),
            DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
        )
        .expect("other config should be valid");

        let error =
            verify_access_token(&token, &other_config).expect_err("issuer should mismatch");

        assert_eq!(
            error,
            JwtError::VerifyFailed("JWT 发行者不匹配".to_string())
        );
    }

    #[test]
    fn build_claims_rejects_empty_roles() {
        let error = AccessTokenClaims::from_input(
            AccessTokenClaimsInput {
                roles: Vec::new(),
                ..build_claims_input()
            },
            &build_jwt_config(),
            OffsetDateTime::now_utc(),
        )
        .expect_err("empty roles should be rejected");

        assert_eq!(error, JwtError::InvalidClaims("JWT roles 至少包含一个角色"));
    }

    #[test]
    fn read_refresh_session_token_returns_matching_cookie() {
        let token = read_refresh_session_token(
            "theme=dark; genarrative_refresh_session=refresh-token-01; locale=zh-CN",
            &build_refresh_cookie_config(),
        );

        assert_eq!(token.as_deref(), Some("refresh-token-01"));
    }

    #[test]
    fn read_refresh_session_token_decodes_urlencoded_value() {
        let token = read_refresh_session_token(
            "genarrative_refresh_session=refresh%2Ftoken%3D01",
            &build_refresh_cookie_config(),
        );

        assert_eq!(token.as_deref(), Some("refresh/token=01"));
    }

    #[test]
    fn read_refresh_session_token_returns_none_when_missing() {
        let token =
            read_refresh_session_token("theme=dark; locale=zh-CN", &build_refresh_cookie_config());

        assert!(token.is_none());
    }

    // An entry without `=` before the refresh cookie must not abort the lookup.
    #[test]
    fn read_refresh_session_token_skips_entries_without_separator() {
        let token = read_refresh_session_token(
            "legacy_flag; genarrative_refresh_session=refresh-token-01",
            &build_refresh_cookie_config(),
        );

        assert_eq!(token.as_deref(), Some("refresh-token-01"));
    }

    #[tokio::test]
    async fn hash_and_verify_password_round_trip() {
        let password_hash = hash_password("secret123")
            .await
            .expect("password hash should build");

        let is_valid = verify_password(&password_hash, "secret123")
            .await
            .expect("password hash should verify");

        assert!(is_valid);
    }

    #[test]
    fn build_refresh_session_cookie_respects_config() {
        let cookie =
            build_refresh_session_set_cookie("refresh/token=01", &build_refresh_cookie_config());

        assert!(cookie.contains("genarrative_refresh_session=refresh%2Ftoken%3D01"));
        assert!(cookie.contains("Path=/api/auth"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Lax"));
        assert!(cookie.contains("Max-Age=2592000"));
    }

    #[test]
    fn hash_refresh_session_token_matches_sha256_hex() {
        let hash = hash_refresh_session_token("refresh-token-01");

        assert_eq!(
            hash,
            "9fab76f9100ec6c151c8caa0c42ab10e10fbc7618f15e24cf3dffc93e19c4c4e"
        );
    }

    #[test]
    fn build_refresh_session_clear_cookie_respects_config() {
        let cookie = build_refresh_session_clear_cookie(&build_refresh_cookie_config());

        assert!(cookie.contains("genarrative_refresh_session="));
        assert!(cookie.contains("Path=/api/auth"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Lax"));
        assert!(cookie.contains("Max-Age=0"));
    }

    #[test]
    fn sms_auth_provider_kind_parses_supported_values() {
        assert_eq!(
            SmsAuthProviderKind::parse("mock"),
            Some(SmsAuthProviderKind::Mock)
        );
        assert_eq!(
            SmsAuthProviderKind::parse("aliyun"),
            Some(SmsAuthProviderKind::Aliyun)
        );
        assert_eq!(SmsAuthProviderKind::parse("other"), None);
    }

    #[tokio::test]
    async fn mock_sms_provider_sends_and_verifies_code() {
        let provider =
            SmsAuthProvider::new(build_mock_sms_config()).expect("provider should build");

        let send_result = provider
            .send_code(SmsSendCodeRequest {
                national_phone_number: "13800138000".to_string(),
                scene: "login".to_string(),
            })
            .await
            .expect("send code should succeed");

        assert_eq!(send_result.cooldown_seconds, DEFAULT_SMS_INTERVAL_SECONDS);
        assert_eq!(
            send_result.expires_in_seconds,
            DEFAULT_SMS_VALID_TIME_SECONDS
        );
        assert_eq!(
            send_result.provider_request_id.as_deref(),
            Some("mock-request-id")
        );
        assert!(send_result.provider_out_id.is_some());

        provider
            .verify_code(SmsVerifyCodeRequest {
                national_phone_number: "13800138000".to_string(),
                verify_code: DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(),
                provider_out_id: send_result.provider_out_id,
            })
            .await
            .expect("verify code should succeed");
    }

    #[tokio::test]
    async fn mock_sms_provider_rejects_wrong_code() {
        let provider =
            SmsAuthProvider::new(build_mock_sms_config()).expect("provider should build");

        let error = provider
            .verify_code(SmsVerifyCodeRequest {
                national_phone_number: "13800138000".to_string(),
                verify_code: "000000".to_string(),
                provider_out_id: None,
            })
            .await
            .expect_err("wrong verify code should fail");

        assert_eq!(error, SmsProviderError::InvalidVerifyCode);
    }

    #[test]
    fn aliyun_sms_config_requires_access_key() {
        let error = SmsAuthConfig::new(
            SmsAuthProviderKind::Aliyun,
            DEFAULT_SMS_ENDPOINT.to_string(),
            None,
            None,
            "测试签名".to_string(),
            "SMS_001".to_string(),
            DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string(),
            DEFAULT_SMS_COUNTRY_CODE.to_string(),
            None,
            DEFAULT_SMS_CODE_LENGTH,
            DEFAULT_SMS_CODE_TYPE,
            DEFAULT_SMS_VALID_TIME_SECONDS,
            DEFAULT_SMS_INTERVAL_SECONDS,
            DEFAULT_SMS_DUPLICATE_POLICY,
            DEFAULT_SMS_CASE_AUTH_POLICY,
            false,
            DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(),
        )
        .expect_err("aliyun config without access key should fail");

        assert_eq!(
            error,
            SmsProviderError::InvalidConfig("阿里云短信 AccessKey 未配置".to_string())
        );
    }

    #[test]
    fn canonicalize_aliyun_rpc_params_keeps_sorted_percent_encoded_order() {
        let mut params = BTreeMap::new();
        params.insert(
            "TemplateParam".to_string(),
            "{\"code\":\"##code##\"}".to_string(),
        );
        params.insert("Action".to_string(), "SendSmsVerifyCode".to_string());
        params.insert("PhoneNumber".to_string(), "13800138000".to_string());

        assert_eq!(
            canonicalize_aliyun_rpc_params(&params),
            "Action=SendSmsVerifyCode&PhoneNumber=13800138000&TemplateParam=%7B%22code%22%3A%22%23%23code%23%23%22%7D"
        );
    }

    // The RPC Timestamp must be whole-second UTC (`yyyy-MM-ddTHH:mm:ssZ`).
    #[test]
    fn current_aliyun_timestamp_uses_whole_seconds_in_utc() {
        let timestamp = current_aliyun_timestamp();

        assert_eq!(timestamp.len(), "1970-01-01T00:00:00Z".len());
        assert!(timestamp.ends_with('Z'));
        assert!(!timestamp.contains('.'));
    }

    // Masking never reveals every character and keeps the 11-digit output unchanged.
    #[test]
    fn mask_phone_number_always_hides_at_least_one_character() {
        assert_eq!(mask_phone_number("13800138000"), "138****8000");
        assert_eq!(mask_phone_number("12345678"), "123*5678");
        assert_eq!(mask_phone_number("1234567"), "123*567");
        assert_eq!(mask_phone_number("12345"), "123*5");
        assert_eq!(mask_phone_number("1234"), "****");
        assert_eq!(mask_phone_number(""), "*");
    }

    #[test]
    fn aliyun_send_response_deserializes_pascal_case_fields() {
        let payload = serde_json::from_str::(
            r#"{ "Code": "OK", "Message": "成功", "RequestId": "req_123", "Success": true, "Model": { "BizId": "biz_456", "OutId": "out_789", "RequestId": "req_model_001" } }"#,
        )
        .expect("aliyun send response should deserialize");

        assert_eq!(payload.code.as_deref(), Some("OK"));
        assert_eq!(payload.message.as_deref(), Some("成功"));
        assert_eq!(payload.request_id.as_deref(), Some("req_123"));
        assert_eq!(payload.success, Some(true));
        assert_eq!(
            payload
                .model
                .as_ref()
                .and_then(|model| model.out_id.as_deref()),
            Some("out_789")
        );
        assert_eq!(
            payload
                .model
                .as_ref()
                .and_then(|model| model.request_id.as_deref()),
            Some("req_model_001")
        );
    }

    #[test]
    fn aliyun_verify_response_deserializes_pascal_case_fields() {
        let payload = serde_json::from_str::(
            r#"{ "Code": "OK", "Message": "成功", "Success": true, "Model": { "OutId": "out_789", "VerifyResult": "PASS" } }"#,
        )
        .expect("aliyun verify response should deserialize");

        assert_eq!(payload.code.as_deref(), Some("OK"));
        assert_eq!(payload.message.as_deref(), Some("成功"));
        assert_eq!(payload.success, Some(true));
        assert_eq!(
            payload
                .model
                .as_ref()
                .and_then(|model| model.verify_result.as_deref()),
            Some("PASS")
        );
    }
}