use std::{ collections::{BTreeMap, HashSet}, error::Error, fmt, }; use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString}; use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD}; use hmac::{Hmac, Mac}; use jsonwebtoken::{ Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::ErrorKind, }; use rand_core::OsRng; use reqwest::{Client, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::Value; use sha1::Sha1; use sha2::{Digest, Sha256}; use shared_kernel::{new_uuid_simple_string, normalize_optional_string, normalize_required_string}; use time::{Duration, OffsetDateTime}; use tracing::{info, warn}; pub const ACCESS_TOKEN_ALGORITHM: Algorithm = Algorithm::HS256; pub const DEFAULT_ACCESS_TOKEN_TTL_SECONDS: u64 = 2 * 60 * 60; pub const DEFAULT_REFRESH_COOKIE_NAME: &str = "genarrative_refresh_session"; pub const DEFAULT_REFRESH_COOKIE_PATH: &str = "/api/auth"; pub const DEFAULT_REFRESH_SESSION_TTL_DAYS: u32 = 30; pub const DEFAULT_SMS_ENDPOINT: &str = "dypnsapi.aliyuncs.com"; pub const DEFAULT_SMS_COUNTRY_CODE: &str = "86"; pub const DEFAULT_SMS_TEMPLATE_PARAM_KEY: &str = "code"; pub const DEFAULT_SMS_MOCK_VERIFY_CODE: &str = "123456"; pub const DEFAULT_SMS_CODE_LENGTH: u8 = 6; pub const DEFAULT_SMS_CODE_TYPE: u8 = 1; pub const DEFAULT_SMS_VALID_TIME_SECONDS: u64 = 300; pub const DEFAULT_SMS_INTERVAL_SECONDS: u64 = 60; pub const DEFAULT_SMS_DUPLICATE_POLICY: u8 = 1; pub const DEFAULT_SMS_CASE_AUTH_POLICY: u8 = 1; type HmacSha1 = Hmac; // 鉴权 provider 直接冻结成文档中约定的枚举,避免后续在多个 crate 内重复发明字符串字面量。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AuthProvider { Password, Phone, Wechat, } // 绑定状态只保留当前 JWT 需要透传的最小快照,不把完整账号状态枚举直接泄漏到 token 中。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum BindingStatus { Active, PendingBindPhone, } // 用于签发 access token 的领域输入,和最终 JWT claims 解耦,避免业务层手动拼 iat/exp/iss。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct AccessTokenClaimsInput { pub user_id: String, pub session_id: String, pub provider: AuthProvider, pub roles: Vec, pub token_version: u64, pub phone_verified: bool, pub binding_status: BindingStatus, pub display_name: Option, } // 直接映射最终 JWT payload,字段名与文档冻结口径保持一致。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AccessTokenClaims { pub iss: String, pub sub: String, pub sid: String, pub provider: AuthProvider, pub roles: Vec, pub ver: u64, pub phone_verified: bool, pub binding_status: BindingStatus, #[serde(skip_serializing_if = "Option::is_none")] pub display_name: Option, pub iat: u64, pub exp: u64, } // 统一承载 JWT 配置,避免 secret、issuer、ttl 在 api-server 与后续模块里散落。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct JwtConfig { issuer: String, secret: String, access_token_ttl_seconds: u64, } // refresh cookie 的 SameSite 固定约束成枚举,避免各层直接使用大小写不一致的字符串。 #[derive(Clone, Debug, PartialEq, Eq)] pub enum RefreshCookieSameSite { Lax, Strict, None, } // refresh cookie 的平台配置统一收口到 platform-auth,避免 api-server 直接散落 cookie 细节。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct RefreshCookieConfig { cookie_name: String, cookie_path: String, cookie_secure: bool, cookie_same_site: RefreshCookieSameSite, refresh_session_ttl_days: u32, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum SmsAuthProviderKind { Mock, Aliyun, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmsAuthConfig { pub provider: SmsAuthProviderKind, pub endpoint: String, pub access_key_id: Option, pub access_key_secret: Option, pub sign_name: String, pub template_code: String, pub template_param_key: String, pub country_code: String, pub scheme_name: Option, pub code_length: u8, pub code_type: u8, pub valid_time_seconds: u64, pub interval_seconds: u64, pub duplicate_policy: u8, pub case_auth_policy: u8, pub return_verify_code: bool, pub mock_verify_code: String, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmsSendCodeRequest { pub national_phone_number: String, pub scene: String, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmsSendCodeResult { pub cooldown_seconds: u64, pub expires_in_seconds: u64, pub provider_request_id: Option, pub provider_out_id: Option, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmsVerifyCodeRequest { pub national_phone_number: String, pub verify_code: String, pub provider_out_id: Option, } #[derive(Clone, Debug)] pub enum SmsAuthProvider { Mock(MockSmsAuthProvider), Aliyun(AliyunSmsAuthProvider), } #[derive(Clone, Debug)] pub struct MockSmsAuthProvider { config: SmsAuthConfig, } #[derive(Clone, Debug)] pub struct AliyunSmsAuthProvider { client: Client, config: SmsAuthConfig, } #[derive(Debug, PartialEq, Eq)] pub enum JwtError { InvalidConfig(&'static str), InvalidClaims(&'static str), SignFailed(String), VerifyFailed(String), } #[derive(Debug, PartialEq, Eq)] pub enum RefreshCookieError { InvalidConfig(&'static str), } #[derive(Debug, PartialEq, Eq)] pub enum PasswordHashError { HashFailed(String), VerifyFailed(String), } #[derive(Debug, PartialEq, Eq)] pub enum SmsProviderError { InvalidConfig(String), InvalidVerifyCode, Upstream(String), } #[derive(Debug, Deserialize)] struct AliyunSendSmsVerifyCodeResponse { // 阿里云 RPC 原始 JSON 使用首字母大写字段名,这里必须显式映射,避免把成功响应误判成空值。 #[serde(default, rename = "Code")] code: Option, #[serde(default, rename = "Message")] message: Option, #[serde(default, rename = "RequestId")] request_id: Option, #[serde(default, rename = "Success")] success: Option, #[serde(default, rename = "Model")] model: Option, } #[derive(Debug, Deserialize)] struct AliyunSendSmsVerifyCodeModel { #[serde(default, rename = "BizId")] _biz_id: Option, #[serde(default, rename = "OutId")] out_id: Option, #[serde(default, rename = "RequestId")] request_id: Option, } #[derive(Debug, Deserialize)] struct AliyunCheckSmsVerifyCodeResponse { // 校验接口同样返回首字母大写字段名,保持和发送接口一致的显式映射。 #[serde(default, rename = "Code")] code: Option, #[serde(default, rename = "Message")] message: Option, #[serde(default, rename = "Success")] success: Option, #[serde(default, rename = "Model")] model: Option, } #[derive(Debug, Deserialize)] struct AliyunCheckSmsVerifyCodeModel { #[serde(default, rename = "OutId")] _out_id: Option, #[serde(default, rename = "VerifyResult")] verify_result: Option, } impl JwtConfig { pub fn new( issuer: String, secret: String, access_token_ttl_seconds: u64, ) -> Result { let issuer = normalize_required_string(&issuer) .ok_or(JwtError::InvalidConfig("JWT issuer 不能为空"))?; let secret = normalize_required_string(&secret) .ok_or(JwtError::InvalidConfig("JWT secret 不能为空"))?; if access_token_ttl_seconds == 0 { return Err(JwtError::InvalidConfig( "JWT access token 过期时间必须大于 0", )); } Ok(Self { issuer, secret, access_token_ttl_seconds, }) } pub fn issuer(&self) -> &str { &self.issuer } pub fn access_token_ttl_seconds(&self) -> u64 { self.access_token_ttl_seconds } } impl RefreshCookieSameSite { pub fn parse(raw: &str) -> Option { match raw.trim().to_ascii_lowercase().as_str() { "lax" => Some(Self::Lax), "strict" => Some(Self::Strict), "none" => Some(Self::None), _ => None, } } pub fn as_str(&self) -> &'static str { match self { Self::Lax => "Lax", Self::Strict => "Strict", Self::None => "None", } } } impl RefreshCookieConfig { pub fn new( cookie_name: String, cookie_path: String, cookie_secure: bool, cookie_same_site: RefreshCookieSameSite, refresh_session_ttl_days: u32, ) -> Result { let cookie_name = normalize_required_string(&cookie_name).ok_or( RefreshCookieError::InvalidConfig("refresh cookie 名称不能为空"), )?; let cookie_path = normalize_required_string(&cookie_path).ok_or( RefreshCookieError::InvalidConfig("refresh cookie path 不能为空"), )?; if refresh_session_ttl_days == 0 { return Err(RefreshCookieError::InvalidConfig( "refresh session TTL 天数必须大于 0", )); } Ok(Self { cookie_name, cookie_path, cookie_secure, cookie_same_site, refresh_session_ttl_days, }) } pub fn cookie_name(&self) -> &str { &self.cookie_name } pub fn cookie_path(&self) -> &str { &self.cookie_path } pub fn cookie_secure(&self) -> bool { self.cookie_secure } pub fn cookie_same_site(&self) -> &RefreshCookieSameSite { &self.cookie_same_site } pub fn refresh_session_ttl_days(&self) -> u32 { self.refresh_session_ttl_days } } impl SmsAuthProviderKind { pub fn parse(raw: &str) -> Option { match raw.trim().to_ascii_lowercase().as_str() { "mock" => Some(Self::Mock), "aliyun" => Some(Self::Aliyun), _ => None, } } pub fn as_str(&self) -> &'static str { match self { Self::Mock => "mock", Self::Aliyun => "aliyun", } } } impl SmsAuthConfig { pub fn new( provider: SmsAuthProviderKind, endpoint: String, access_key_id: Option, access_key_secret: Option, sign_name: String, template_code: String, template_param_key: String, country_code: String, scheme_name: Option, code_length: u8, code_type: u8, valid_time_seconds: u64, interval_seconds: u64, duplicate_policy: u8, case_auth_policy: u8, return_verify_code: bool, mock_verify_code: String, ) -> Result { let endpoint = normalize_required_string(&endpoint) .unwrap_or_else(|| DEFAULT_SMS_ENDPOINT.to_string()); let template_param_key = normalize_required_string(&template_param_key) .unwrap_or_else(|| DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string()); let country_code = normalize_required_string(&country_code) .unwrap_or_else(|| DEFAULT_SMS_COUNTRY_CODE.to_string()); let scheme_name = normalize_optional_string(scheme_name); let mock_verify_code = normalize_required_string(&mock_verify_code) .unwrap_or_else(|| DEFAULT_SMS_MOCK_VERIFY_CODE.to_string()); if !(4..=8).contains(&code_length) { return Err(SmsProviderError::InvalidConfig( "短信验证码长度必须在 4 到 8 之间".to_string(), )); } if !(1..=7).contains(&code_type) { return Err(SmsProviderError::InvalidConfig( "短信验证码类型取值非法".to_string(), )); } if interval_seconds == 0 || valid_time_seconds == 0 { return Err(SmsProviderError::InvalidConfig( "短信验证码有效期和发送间隔必须大于 0".to_string(), )); } if !(1..=2).contains(&duplicate_policy) { return Err(SmsProviderError::InvalidConfig( "短信验证码重复策略取值非法".to_string(), )); } if !(1..=2).contains(&case_auth_policy) { return Err(SmsProviderError::InvalidConfig( "短信验证码大小写校验策略取值非法".to_string(), )); } match provider { SmsAuthProviderKind::Mock => {} SmsAuthProviderKind::Aliyun => { if normalize_required_string(&sign_name).is_none() { return Err(SmsProviderError::InvalidConfig( "阿里云短信签名不能为空".to_string(), )); } if normalize_required_string(&template_code).is_none() { return Err(SmsProviderError::InvalidConfig( "阿里云短信模板编码不能为空".to_string(), )); } if access_key_id .as_deref() .and_then(normalize_required_string) .is_none() || access_key_secret .as_deref() .and_then(normalize_required_string) .is_none() { return Err(SmsProviderError::InvalidConfig( "阿里云短信 AccessKey 未配置".to_string(), )); } } } Ok(Self { provider, endpoint, access_key_id: access_key_id.and_then(|value| normalize_required_string(&value)), access_key_secret: access_key_secret .and_then(|value| normalize_required_string(&value)), sign_name: sign_name.trim().to_string(), template_code: template_code.trim().to_string(), template_param_key, country_code, scheme_name, code_length, code_type, valid_time_seconds, interval_seconds, duplicate_policy, case_auth_policy, return_verify_code, mock_verify_code, }) } } impl SmsAuthProvider { pub fn new(config: SmsAuthConfig) -> Result { match config.provider { SmsAuthProviderKind::Mock => Ok(Self::Mock(MockSmsAuthProvider { config })), SmsAuthProviderKind::Aliyun => Ok(Self::Aliyun(AliyunSmsAuthProvider { client: Client::new(), config, })), } } pub fn kind(&self) -> SmsAuthProviderKind { match self { Self::Mock(_) => SmsAuthProviderKind::Mock, Self::Aliyun(_) => SmsAuthProviderKind::Aliyun, } } pub async fn send_code( &self, request: SmsSendCodeRequest, ) -> Result { match self { Self::Mock(provider) => provider.send_code(request).await, Self::Aliyun(provider) => provider.send_code(request).await, } } pub async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> { match self { Self::Mock(provider) => provider.verify_code(request).await, Self::Aliyun(provider) => provider.verify_code(request).await, } } } impl MockSmsAuthProvider { async fn send_code( &self, request: SmsSendCodeRequest, ) -> Result { let provider_out_id = build_sms_provider_out_id(&request.scene, &request.national_phone_number); Ok(SmsSendCodeResult { cooldown_seconds: self.config.interval_seconds, expires_in_seconds: self.config.valid_time_seconds, provider_request_id: Some("mock-request-id".to_string()), provider_out_id: Some(provider_out_id), }) } async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> { if request.verify_code.trim() != self.config.mock_verify_code { return Err(SmsProviderError::InvalidVerifyCode); } Ok(()) } } impl AliyunSmsAuthProvider { async fn send_code( &self, request: SmsSendCodeRequest, ) -> Result { let provider_out_id = build_sms_provider_out_id(&request.scene, &request.national_phone_number); let phone_masked = mask_phone_number(&request.national_phone_number); let template_param = serde_json::json!({ self.config.template_param_key.clone(): "##code##", "min": self.config.valid_time_seconds, }) .to_string(); info!( provider = "aliyun", scene = request.scene.as_str(), phone_masked = phone_masked.as_str(), endpoint = self.config.endpoint.as_str(), sign_name = self.config.sign_name.as_str(), template_code = self.config.template_code.as_str(), code_length = self.config.code_length, valid_time_seconds = self.config.valid_time_seconds, interval_seconds = self.config.interval_seconds, provider_out_id = provider_out_id.as_str(), "准备调用阿里云短信发送接口" ); let mut query = BTreeMap::new(); query.insert("Action".to_string(), "SendSmsVerifyCode".to_string()); query.insert("Format".to_string(), "json".to_string()); query.insert("Version".to_string(), "2017-05-25".to_string()); query.insert("Timestamp".to_string(), current_aliyun_timestamp()); query.insert("SignatureNonce".to_string(), new_uuid_simple_string()); query.insert("SignatureMethod".to_string(), "HMAC-SHA1".to_string()); query.insert("SignatureVersion".to_string(), "1.0".to_string()); query.insert( "AccessKeyId".to_string(), self.config.access_key_id.clone().unwrap_or_default(), ); query.insert( "PhoneNumber".to_string(), request.national_phone_number.trim().to_string(), ); query.insert("CountryCode".to_string(), self.config.country_code.clone()); query.insert("SignName".to_string(), self.config.sign_name.clone()); query.insert( "TemplateCode".to_string(), self.config.template_code.clone(), ); query.insert("TemplateParam".to_string(), template_param); query.insert( "CodeLength".to_string(), self.config.code_length.to_string(), ); query.insert("CodeType".to_string(), self.config.code_type.to_string()); query.insert( "ValidTime".to_string(), self.config.valid_time_seconds.to_string(), ); query.insert( "Interval".to_string(), self.config.interval_seconds.to_string(), ); query.insert( "DuplicatePolicy".to_string(), self.config.duplicate_policy.to_string(), ); query.insert( "ReturnVerifyCode".to_string(), self.config.return_verify_code.to_string(), ); query.insert("OutId".to_string(), provider_out_id.clone()); if let Some(scheme_name) = self.config.scheme_name.clone() { query.insert("SchemeName".to_string(), scheme_name); } self.sign_query(&mut query)?; let payload = self .client .post(build_aliyun_sms_url(&self.config.endpoint)?) .form(&query) .send() .await .map_err(|error| SmsProviderError::Upstream(format!("短信验证码发送失败:{error}")))?; let http_status = payload.status(); let body = parse_aliyun_json_response(payload, "短信验证码发送失败").await?; info!( provider = "aliyun", scene = request.scene.as_str(), phone_masked = phone_masked.as_str(), http_status = http_status.as_u16(), provider_code = body.code.as_deref().unwrap_or("unknown"), provider_message = body.message.as_deref().unwrap_or("unknown"), provider_request_id = body .request_id .as_deref() .or_else(|| body .model .as_ref() .and_then(|model| model.request_id.as_deref())) .unwrap_or("unknown"), provider_out_id = body .model .as_ref() .and_then(|model| model.out_id.as_deref()) .unwrap_or("unknown"), success = body.success.unwrap_or(false), "阿里云短信发送接口返回响应" ); if !body.success.unwrap_or(false) || body.code.as_deref() != Some("OK") { warn!( provider = "aliyun", scene = request.scene.as_str(), phone_masked = phone_masked.as_str(), http_status = http_status.as_u16(), provider_code = body.code.as_deref().unwrap_or("unknown"), provider_message = body.message.as_deref().unwrap_or("unknown"), provider_request_id = body .request_id .as_deref() .or_else(|| body .model .as_ref() .and_then(|model| model.request_id.as_deref())) .unwrap_or("unknown"), provider_out_id = body .model .as_ref() .and_then(|model| model.out_id.as_deref()) .unwrap_or("unknown"), "阿里云短信发送接口返回业务失败" ); return Err(map_aliyun_provider_error( "短信验证码发送失败", body.message, body.code, )); } Ok(SmsSendCodeResult { cooldown_seconds: self.config.interval_seconds, expires_in_seconds: self.config.valid_time_seconds, provider_request_id: body.request_id.or_else(|| { body.model .as_ref() .and_then(|model| model.request_id.clone()) }), provider_out_id: body.model.and_then(|model| model.out_id), }) } async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> { let mut query = BTreeMap::new(); query.insert("Action".to_string(), "CheckSmsVerifyCode".to_string()); query.insert("Format".to_string(), "json".to_string()); query.insert("Version".to_string(), "2017-05-25".to_string()); query.insert("Timestamp".to_string(), current_aliyun_timestamp()); query.insert("SignatureNonce".to_string(), new_uuid_simple_string()); query.insert("SignatureMethod".to_string(), "HMAC-SHA1".to_string()); query.insert("SignatureVersion".to_string(), "1.0".to_string()); query.insert( "AccessKeyId".to_string(), self.config.access_key_id.clone().unwrap_or_default(), ); query.insert( "PhoneNumber".to_string(), request.national_phone_number.trim().to_string(), ); query.insert("CountryCode".to_string(), self.config.country_code.clone()); query.insert( "VerifyCode".to_string(), request.verify_code.trim().to_string(), ); query.insert( "CaseAuthPolicy".to_string(), self.config.case_auth_policy.to_string(), ); if let Some(scheme_name) = self.config.scheme_name.clone() { query.insert("SchemeName".to_string(), scheme_name); } if let Some(provider_out_id) = request.provider_out_id { query.insert("OutId".to_string(), provider_out_id); } self.sign_query(&mut query)?; let payload = self .client .post(build_aliyun_sms_url(&self.config.endpoint)?) .form(&query) .send() .await .map_err(|error| SmsProviderError::Upstream(format!("验证码校验失败:{error}")))?; let body = parse_aliyun_json_response_for_verify(payload).await?; if !body.success.unwrap_or(false) || body.code.as_deref() != Some("OK") { return Err(map_aliyun_provider_error( "验证码校验失败", body.message, body.code, )); } if body.model.and_then(|model| model.verify_result).as_deref() != Some("PASS") { return Err(SmsProviderError::InvalidVerifyCode); } Ok(()) } fn sign_query(&self, query: &mut BTreeMap) -> Result<(), SmsProviderError> { let access_key_secret = self.config.access_key_secret.as_deref().ok_or_else(|| { SmsProviderError::InvalidConfig("阿里云短信 AccessKeySecret 未配置".to_string()) })?; let canonicalized = canonicalize_aliyun_rpc_params(query); let string_to_sign = format!( "POST&{}&{}", aliyun_percent_encode("/"), aliyun_percent_encode(&canonicalized) ); let mut signer = HmacSha1::new_from_slice(format!("{access_key_secret}&").as_bytes()) .map_err(|error| { SmsProviderError::InvalidConfig(format!("初始化短信签名器失败:{error}")) })?; signer.update(string_to_sign.as_bytes()); let signature = BASE64_STANDARD.encode(signer.finalize().into_bytes()); query.insert("Signature".to_string(), signature); Ok(()) } } impl AccessTokenClaims { pub fn from_input( input: AccessTokenClaimsInput, config: &JwtConfig, issued_at: OffsetDateTime, ) -> Result { let user_id = normalize_required_field(input.user_id, "JWT sub 不能为空")?; let session_id = normalize_required_field(input.session_id, "JWT sid 不能为空")?; let roles = normalize_roles(input.roles)?; let display_name = normalize_optional_field(input.display_name); let issued_at_unix = issued_at.unix_timestamp(); if issued_at_unix < 0 { return Err(JwtError::InvalidClaims("JWT iat 不能早于 Unix epoch")); } let expires_at = issued_at .checked_add(Duration::seconds( i64::try_from(config.access_token_ttl_seconds()).map_err(|_| { JwtError::InvalidConfig("JWT access token 过期时间超出 i64 上限") })?, )) .ok_or(JwtError::InvalidConfig("JWT 过期时间计算溢出"))?; let expires_at_unix = expires_at.unix_timestamp(); if expires_at_unix <= issued_at_unix { return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat")); } let claims = Self { iss: config.issuer().to_string(), sub: user_id, sid: session_id, provider: input.provider, roles, ver: input.token_version, phone_verified: input.phone_verified, binding_status: input.binding_status, display_name, iat: issued_at_unix as u64, exp: expires_at_unix as u64, }; claims.validate_for_config(config)?; Ok(claims) } pub fn user_id(&self) -> &str { &self.sub } pub fn session_id(&self) -> &str { &self.sid } pub fn token_version(&self) -> u64 { self.ver } pub fn validate_for_config(&self, config: &JwtConfig) -> Result<(), JwtError> { if self.iss.trim() != config.issuer() { return Err(JwtError::InvalidClaims("JWT iss 与当前配置不一致")); } normalize_required_field(self.sub.clone(), "JWT sub 不能为空")?; normalize_required_field(self.sid.clone(), "JWT sid 不能为空")?; normalize_roles(self.roles.clone())?; if self.exp <= self.iat { return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat")); } Ok(()) } } pub fn sign_access_token( claims: &AccessTokenClaims, config: &JwtConfig, ) -> Result { claims.validate_for_config(config)?; let header = Header { alg: ACCESS_TOKEN_ALGORITHM, typ: Some("JWT".to_string()), ..Header::default() }; encode( &header, claims, &EncodingKey::from_secret(config.secret.as_bytes()), ) .map_err(|error| JwtError::SignFailed(format!("JWT 签发失败:{error}"))) } pub fn verify_access_token(token: &str, config: &JwtConfig) -> Result { let token = token.trim(); if token.is_empty() { return Err(JwtError::VerifyFailed("JWT 不能为空".to_string())); } let mut validation = Validation::new(ACCESS_TOKEN_ALGORITHM); validation.required_spec_claims = HashSet::from([ "exp".to_string(), "iat".to_string(), "iss".to_string(), "sub".to_string(), ]); validation.set_issuer(&[config.issuer()]); let decoded = decode::( token, &DecodingKey::from_secret(config.secret.as_bytes()), &validation, ) .map_err(map_verify_error)?; decoded.claims.validate_for_config(config)?; Ok(decoded.claims) } pub fn read_refresh_session_token( cookie_header: &str, config: &RefreshCookieConfig, ) -> Option { let cookie_header = cookie_header.trim(); if cookie_header.is_empty() { return None; } for entry in cookie_header.split(';') { let entry = entry.trim(); if entry.is_empty() { continue; } let (raw_name, raw_value) = entry.split_once('=')?; if raw_name.trim() != config.cookie_name() { continue; } let raw_value = raw_value.trim(); if raw_value.is_empty() { return None; } return urlencoding::decode(raw_value) .ok() .map(|decoded| decoded.into_owned()); } None } pub async fn hash_password(password: &str) -> Result { let salt = SaltString::generate(&mut OsRng); Argon2::default() .hash_password(password.as_bytes(), &salt) .map(|hash| hash.to_string()) .map_err(|error| PasswordHashError::HashFailed(format!("密码哈希失败:{error}"))) } pub async fn verify_password( password_hash: &str, password: &str, ) -> Result { let parsed_hash = PasswordHash::new(password_hash) .map_err(|error| PasswordHashError::VerifyFailed(format!("密码哈希格式非法:{error}")))?; Ok(Argon2::default() .verify_password(password.as_bytes(), &parsed_hash) .is_ok()) } pub fn create_refresh_session_token() -> String { new_uuid_simple_string() } pub fn hash_refresh_session_token(token: &str) -> String { let mut hasher = Sha256::new(); hasher.update(token.as_bytes()); format!("{:x}", hasher.finalize()) } pub fn build_refresh_session_set_cookie(token: &str, config: &RefreshCookieConfig) -> String { let mut parts = vec![ format!( "{}={}", config.cookie_name(), urlencoding::encode(token).into_owned() ), format!("Path={}", config.cookie_path()), "HttpOnly".to_string(), format!("SameSite={}", config.cookie_same_site().as_str()), format!( "Max-Age={}", u64::from(config.refresh_session_ttl_days()) * 24 * 60 * 60 ), ]; if config.cookie_secure() { parts.push("Secure".to_string()); } parts.join("; ") } pub fn build_refresh_session_clear_cookie(config: &RefreshCookieConfig) -> String { let mut parts = vec![ format!("{}=", config.cookie_name()), format!("Path={}", config.cookie_path()), "HttpOnly".to_string(), format!("SameSite={}", config.cookie_same_site().as_str()), "Max-Age=0".to_string(), ]; if config.cookie_secure() { parts.push("Secure".to_string()); } parts.join("; ") } impl fmt::Display for JwtError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidConfig(message) | Self::InvalidClaims(message) => f.write_str(message), Self::SignFailed(message) | Self::VerifyFailed(message) => f.write_str(message), } } } impl Error for JwtError {} impl fmt::Display for RefreshCookieError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidConfig(message) => f.write_str(message), } } } impl Error for RefreshCookieError {} impl fmt::Display for PasswordHashError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::HashFailed(message) | Self::VerifyFailed(message) => f.write_str(message), } } } impl Error for PasswordHashError {} fn normalize_required_field( value: String, error_message: &'static str, ) -> Result { normalize_required_string(&value).ok_or(JwtError::InvalidClaims(error_message)) } fn normalize_optional_field(value: Option) -> Option { normalize_optional_string(value) } fn normalize_roles(roles: Vec) -> Result, JwtError> { let roles = roles .into_iter() .map(|role| role.trim().to_string()) .filter(|role| !role.is_empty()) .collect::>(); if roles.is_empty() { return Err(JwtError::InvalidClaims("JWT roles 至少包含一个角色")); } Ok(roles) } fn map_verify_error(error: jsonwebtoken::errors::Error) -> JwtError { let message = match error.kind() { ErrorKind::ExpiredSignature => "JWT 已过期".to_string(), ErrorKind::InvalidIssuer => "JWT 发行者不匹配".to_string(), ErrorKind::InvalidSignature => "JWT 签名无效".to_string(), ErrorKind::InvalidAlgorithm => "JWT 算法不匹配".to_string(), ErrorKind::InvalidToken => "JWT 非法".to_string(), ErrorKind::ImmatureSignature => "JWT 尚未生效".to_string(), ErrorKind::MissingRequiredClaim(claim) => format!("JWT 缺少必填字段:{claim}"), _ => format!("JWT 校验失败:{error}"), }; JwtError::VerifyFailed(message) } fn build_sms_provider_out_id(scene: &str, national_phone_number: &str) -> String { let phone_suffix = national_phone_number .chars() .rev() .take(4) .collect::() .chars() .rev() .collect::(); format!("{scene}_{}_{}", phone_suffix, new_uuid_simple_string()) } fn build_aliyun_sms_url(endpoint: &str) -> Result { let endpoint = endpoint .trim() .trim_start_matches("https://") .trim_start_matches("http://") .trim_matches('/'); if endpoint.is_empty() { return Err(SmsProviderError::InvalidConfig( "阿里云短信 endpoint 不能为空".to_string(), )); } Ok(format!("https://{endpoint}/")) } fn current_aliyun_timestamp() -> String { OffsetDateTime::now_utc() .format(&time::format_description::well_known::Rfc3339) .unwrap_or_else(|_| "1970-01-01T00:00:00Z".to_string()) } fn canonicalize_aliyun_rpc_params(params: &BTreeMap) -> String { params .iter() .filter(|(key, _)| key.as_str() != "Signature") .map(|(key, value)| { format!( "{}={}", aliyun_percent_encode(key), aliyun_percent_encode(value) ) }) .collect::>() .join("&") } fn aliyun_percent_encode(value: &str) -> String { urlencoding::encode(value) .into_owned() .replace('+', "%20") .replace('*', "%2A") .replace("%7E", "~") } async fn parse_aliyun_json_response( response: reqwest::Response, fallback_message: &str, ) -> Result { let status = response.status(); let body = response .text() .await .map_err(|error| SmsProviderError::Upstream(format!("{fallback_message}:{error}")))?; let payload = serde_json::from_str::(&body).map_err(|error| { SmsProviderError::Upstream(format!("{fallback_message}:响应解析失败:{error}")) })?; if status.is_client_error() || status.is_server_error() { return Err(map_http_status_to_sms_provider_error( fallback_message, status, serde_json::from_str::(&body).ok(), )); } Ok(payload) } async fn parse_aliyun_json_response_for_verify( response: reqwest::Response, ) -> Result { let status = response.status(); let body = response .text() .await .map_err(|error| SmsProviderError::Upstream(format!("验证码校验失败:{error}")))?; let payload = serde_json::from_str::(&body).map_err(|error| { SmsProviderError::Upstream(format!("验证码校验失败:响应解析失败:{error}")) })?; if status.is_client_error() || status.is_server_error() { return Err(map_http_status_to_sms_provider_error( "验证码校验失败", status, serde_json::from_str::(&body).ok(), )); } Ok(payload) } fn map_http_status_to_sms_provider_error( fallback_message: &str, status: StatusCode, payload: Option, ) -> SmsProviderError { let provider_message = payload .as_ref() .and_then(|value| value.get("Message").and_then(Value::as_str)) .unwrap_or_default(); let provider_code = payload .as_ref() .and_then(|value| value.get("Code").and_then(Value::as_str)) .unwrap_or_default(); if status.is_client_error() { return map_aliyun_provider_error( fallback_message, Some(provider_message.to_string()), Some(provider_code.to_string()), ); } SmsProviderError::Upstream(build_provider_error_message( fallback_message, provider_message, )) } fn map_aliyun_provider_error( fallback_message: &str, provider_message: Option, provider_code: Option, ) -> SmsProviderError { let provider_message = provider_message.unwrap_or_default(); let provider_code = provider_code.unwrap_or_default(); let normalized_code = provider_code.trim().to_ascii_uppercase(); if normalized_code.contains("VERIFY") || normalized_code.contains("CODE") || normalized_code.contains("CHECK") { return SmsProviderError::InvalidVerifyCode; } if normalized_code.contains("MOBILE") || normalized_code.contains("PHONE") || normalized_code.contains("SIGN") || normalized_code.contains("TEMPLATE") || normalized_code.contains("ACCESSKEY") { return SmsProviderError::InvalidConfig(build_provider_error_message( fallback_message, &provider_message, )); } SmsProviderError::Upstream(build_provider_error_message( fallback_message, &provider_message, )) } fn build_provider_error_message(prefix: &str, provider_message: &str) -> String { let provider_message = provider_message.trim(); if provider_message.is_empty() { prefix.to_string() } else { format!("{prefix}:{provider_message}") } } fn mask_phone_number(phone_number: &str) -> String { let chars: Vec = phone_number.chars().collect(); if chars.len() <= 4 { return "*".repeat(chars.len().max(1)); } let prefix_len = chars.len().min(3); let suffix_len = 4.min(chars.len().saturating_sub(prefix_len)); let mask_len = chars.len().saturating_sub(prefix_len + suffix_len); let mut masked = String::new(); masked.extend(chars.iter().take(prefix_len)); masked.push_str(&"*".repeat(mask_len.max(1))); if suffix_len > 0 { masked.extend(chars.iter().skip(chars.len() - suffix_len)); } masked } impl fmt::Display for SmsProviderError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidConfig(message) | Self::Upstream(message) => f.write_str(message), Self::InvalidVerifyCode => f.write_str("验证码错误"), } } } impl Error for SmsProviderError {} #[cfg(test)] mod tests { use super::*; fn build_jwt_config() -> JwtConfig { JwtConfig::new( "https://auth.genarrative.local".to_string(), "genarrative-dev-secret".to_string(), DEFAULT_ACCESS_TOKEN_TTL_SECONDS, ) .expect("jwt config should be valid") } fn build_claims_input() -> AccessTokenClaimsInput { AccessTokenClaimsInput { user_id: "usr_123".to_string(), session_id: "sess_456".to_string(), provider: AuthProvider::Wechat, roles: vec!["user".to_string()], token_version: 3, phone_verified: false, binding_status: BindingStatus::PendingBindPhone, display_name: Some("微信旅人".to_string()), } } fn build_refresh_cookie_config() -> RefreshCookieConfig { RefreshCookieConfig::new( DEFAULT_REFRESH_COOKIE_NAME.to_string(), DEFAULT_REFRESH_COOKIE_PATH.to_string(), false, RefreshCookieSameSite::Lax, DEFAULT_REFRESH_SESSION_TTL_DAYS, ) .expect("refresh cookie config should be valid") } fn build_mock_sms_config() -> SmsAuthConfig { SmsAuthConfig::new( SmsAuthProviderKind::Mock, DEFAULT_SMS_ENDPOINT.to_string(), None, None, String::new(), String::new(), DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string(), DEFAULT_SMS_COUNTRY_CODE.to_string(), None, DEFAULT_SMS_CODE_LENGTH, DEFAULT_SMS_CODE_TYPE, DEFAULT_SMS_VALID_TIME_SECONDS, DEFAULT_SMS_INTERVAL_SECONDS, DEFAULT_SMS_DUPLICATE_POLICY, DEFAULT_SMS_CASE_AUTH_POLICY, false, DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(), ) .expect("mock sms config should be valid") } #[test] fn round_trip_sign_and_verify_access_token() { let config = build_jwt_config(); let claims = AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc()) .expect("claims should build"); let token = sign_access_token(&claims, &config).expect("token should sign"); let verified = verify_access_token(&token, &config).expect("token should verify"); assert_eq!(verified, claims); assert_eq!(verified.user_id(), "usr_123"); assert_eq!(verified.session_id(), "sess_456"); assert_eq!(verified.token_version(), 3); } #[test] fn verify_rejects_invalid_issuer() { let config = build_jwt_config(); let claims = AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc()) .expect("claims should build"); let token = sign_access_token(&claims, &config).expect("token should sign"); let other_config = JwtConfig::new( "https://auth.other.local".to_string(), "genarrative-dev-secret".to_string(), DEFAULT_ACCESS_TOKEN_TTL_SECONDS, ) .expect("other config should be valid"); let error = verify_access_token(&token, &other_config).expect_err("issuer should mismatch"); assert_eq!( error, JwtError::VerifyFailed("JWT 发行者不匹配".to_string()) ); } #[test] fn build_claims_rejects_empty_roles() { let error = AccessTokenClaims::from_input( AccessTokenClaimsInput { roles: Vec::new(), ..build_claims_input() }, &build_jwt_config(), OffsetDateTime::now_utc(), ) .expect_err("empty roles should be rejected"); assert_eq!(error, JwtError::InvalidClaims("JWT roles 至少包含一个角色")); } #[test] fn read_refresh_session_token_returns_matching_cookie() { let token = read_refresh_session_token( "theme=dark; genarrative_refresh_session=refresh-token-01; locale=zh-CN", &build_refresh_cookie_config(), ); assert_eq!(token.as_deref(), Some("refresh-token-01")); } #[test] fn read_refresh_session_token_decodes_urlencoded_value() { let token = read_refresh_session_token( "genarrative_refresh_session=refresh%2Ftoken%3D01", &build_refresh_cookie_config(), ); assert_eq!(token.as_deref(), Some("refresh/token=01")); } #[test] fn read_refresh_session_token_returns_none_when_missing() { let token = read_refresh_session_token("theme=dark; locale=zh-CN", &build_refresh_cookie_config()); assert!(token.is_none()); } #[tokio::test] async fn hash_and_verify_password_round_trip() { let password_hash = hash_password("secret123") .await .expect("password hash should build"); let is_valid = verify_password(&password_hash, "secret123") .await .expect("password hash should verify"); assert!(is_valid); } #[test] fn build_refresh_session_cookie_respects_config() { let cookie = build_refresh_session_set_cookie("refresh/token=01", &build_refresh_cookie_config()); assert!(cookie.contains("genarrative_refresh_session=refresh%2Ftoken%3D01")); assert!(cookie.contains("Path=/api/auth")); assert!(cookie.contains("HttpOnly")); assert!(cookie.contains("SameSite=Lax")); assert!(cookie.contains("Max-Age=2592000")); } #[test] fn hash_refresh_session_token_matches_sha256_hex() { let hash = hash_refresh_session_token("refresh-token-01"); assert_eq!( hash, "9fab76f9100ec6c151c8caa0c42ab10e10fbc7618f15e24cf3dffc93e19c4c4e" ); } #[test] fn build_refresh_session_clear_cookie_respects_config() { let cookie = build_refresh_session_clear_cookie(&build_refresh_cookie_config()); assert!(cookie.contains("genarrative_refresh_session=")); assert!(cookie.contains("Path=/api/auth")); assert!(cookie.contains("HttpOnly")); assert!(cookie.contains("SameSite=Lax")); assert!(cookie.contains("Max-Age=0")); } #[test] fn sms_auth_provider_kind_parses_supported_values() { assert_eq!( SmsAuthProviderKind::parse("mock"), Some(SmsAuthProviderKind::Mock) ); assert_eq!( SmsAuthProviderKind::parse("aliyun"), Some(SmsAuthProviderKind::Aliyun) ); assert_eq!(SmsAuthProviderKind::parse("other"), None); } #[tokio::test] async fn mock_sms_provider_sends_and_verifies_code() { let provider = SmsAuthProvider::new(build_mock_sms_config()).expect("provider should build"); let send_result = provider .send_code(SmsSendCodeRequest { national_phone_number: "13800138000".to_string(), scene: "login".to_string(), }) .await .expect("send code should succeed"); assert_eq!(send_result.cooldown_seconds, DEFAULT_SMS_INTERVAL_SECONDS); assert_eq!( send_result.expires_in_seconds, DEFAULT_SMS_VALID_TIME_SECONDS ); assert_eq!( send_result.provider_request_id.as_deref(), Some("mock-request-id") ); assert!(send_result.provider_out_id.is_some()); provider .verify_code(SmsVerifyCodeRequest { national_phone_number: "13800138000".to_string(), verify_code: DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(), provider_out_id: send_result.provider_out_id, }) .await .expect("verify code should succeed"); } #[tokio::test] async fn mock_sms_provider_rejects_wrong_code() { let provider = SmsAuthProvider::new(build_mock_sms_config()).expect("provider should build"); let error = provider .verify_code(SmsVerifyCodeRequest { national_phone_number: "13800138000".to_string(), verify_code: "000000".to_string(), provider_out_id: None, }) .await .expect_err("wrong verify code should fail"); assert_eq!(error, SmsProviderError::InvalidVerifyCode); } #[test] fn aliyun_sms_config_requires_access_key() { let error = SmsAuthConfig::new( SmsAuthProviderKind::Aliyun, DEFAULT_SMS_ENDPOINT.to_string(), None, None, "测试签名".to_string(), "SMS_001".to_string(), DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string(), DEFAULT_SMS_COUNTRY_CODE.to_string(), None, DEFAULT_SMS_CODE_LENGTH, DEFAULT_SMS_CODE_TYPE, DEFAULT_SMS_VALID_TIME_SECONDS, DEFAULT_SMS_INTERVAL_SECONDS, DEFAULT_SMS_DUPLICATE_POLICY, DEFAULT_SMS_CASE_AUTH_POLICY, false, DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(), ) .expect_err("aliyun config without access key should fail"); assert_eq!( error, SmsProviderError::InvalidConfig("阿里云短信 AccessKey 未配置".to_string()) ); } #[test] fn canonicalize_aliyun_rpc_params_keeps_sorted_percent_encoded_order() { let mut params = BTreeMap::new(); params.insert( "TemplateParam".to_string(), "{\"code\":\"##code##\"}".to_string(), ); params.insert("Action".to_string(), "SendSmsVerifyCode".to_string()); params.insert("PhoneNumber".to_string(), "13800138000".to_string()); assert_eq!( canonicalize_aliyun_rpc_params(¶ms), "Action=SendSmsVerifyCode&PhoneNumber=13800138000&TemplateParam=%7B%22code%22%3A%22%23%23code%23%23%22%7D" ); } #[test] fn aliyun_send_response_deserializes_pascal_case_fields() { let payload = serde_json::from_str::( r#"{ "Code": "OK", "Message": "成功", "RequestId": "req_123", "Success": true, "Model": { "BizId": "biz_456", "OutId": "out_789", "RequestId": "req_model_001" } }"#, ) .expect("aliyun send response should deserialize"); assert_eq!(payload.code.as_deref(), Some("OK")); assert_eq!(payload.message.as_deref(), Some("成功")); assert_eq!(payload.request_id.as_deref(), Some("req_123")); assert_eq!(payload.success, Some(true)); assert_eq!( payload .model .as_ref() .and_then(|model| model.out_id.as_deref()), Some("out_789") ); assert_eq!( payload .model .as_ref() .and_then(|model| model.request_id.as_deref()), Some("req_model_001") ); } #[test] fn aliyun_verify_response_deserializes_pascal_case_fields() { let payload = serde_json::from_str::( r#"{ "Code": "OK", "Message": "成功", "Success": true, "Model": { "OutId": "out_789", "VerifyResult": "PASS" } }"#, ) .expect("aliyun verify response should deserialize"); assert_eq!(payload.code.as_deref(), Some("OK")); assert_eq!(payload.message.as_deref(), Some("成功")); assert_eq!(payload.success, Some(true)); assert_eq!( payload .model .as_ref() .and_then(|model| model.verify_result.as_deref()), Some("PASS") ); } }