2270 lines
75 KiB
Rust
2270 lines
75 KiB
Rust
use std::{
|
||
collections::{BTreeMap, HashSet},
|
||
error::Error,
|
||
fmt,
|
||
};
|
||
|
||
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
|
||
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
|
||
use hmac::{Hmac, Mac};
|
||
use jsonwebtoken::{
|
||
Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::ErrorKind,
|
||
};
|
||
use rand_core::OsRng;
|
||
use reqwest::{Client, StatusCode};
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json::Value;
|
||
use sha1::Sha1;
|
||
use sha2::{Digest, Sha256};
|
||
use shared_kernel::{new_uuid_simple_string, normalize_optional_string, normalize_required_string};
|
||
use time::{Duration, OffsetDateTime};
|
||
use tracing::{info, warn};
|
||
use url::Url;
|
||
|
||
pub const ACCESS_TOKEN_ALGORITHM: Algorithm = Algorithm::HS256;
|
||
pub const DEFAULT_ACCESS_TOKEN_TTL_SECONDS: u64 = 2 * 60 * 60;
|
||
pub const DEFAULT_REFRESH_COOKIE_NAME: &str = "genarrative_refresh_session";
|
||
pub const DEFAULT_REFRESH_COOKIE_PATH: &str = "/api/auth";
|
||
pub const DEFAULT_REFRESH_SESSION_TTL_DAYS: u32 = 30;
|
||
pub const DEFAULT_SMS_ENDPOINT: &str = "dypnsapi.aliyuncs.com";
|
||
pub const DEFAULT_SMS_COUNTRY_CODE: &str = "86";
|
||
pub const DEFAULT_SMS_TEMPLATE_PARAM_KEY: &str = "code";
|
||
pub const DEFAULT_SMS_MOCK_VERIFY_CODE: &str = "123456";
|
||
pub const DEFAULT_SMS_CODE_LENGTH: u8 = 6;
|
||
pub const DEFAULT_SMS_CODE_TYPE: u8 = 1;
|
||
pub const DEFAULT_SMS_VALID_TIME_SECONDS: u64 = 300;
|
||
pub const DEFAULT_SMS_INTERVAL_SECONDS: u64 = 60;
|
||
pub const DEFAULT_SMS_DUPLICATE_POLICY: u8 = 1;
|
||
pub const DEFAULT_SMS_CASE_AUTH_POLICY: u8 = 1;
|
||
pub const DEFAULT_WECHAT_AUTHORIZE_ENDPOINT: &str = "https://open.weixin.qq.com/connect/qrconnect";
|
||
pub const DEFAULT_WECHAT_IN_APP_AUTHORIZE_ENDPOINT: &str =
|
||
"https://open.weixin.qq.com/connect/oauth2/authorize";
|
||
pub const DEFAULT_WECHAT_ACCESS_TOKEN_ENDPOINT: &str =
|
||
"https://api.weixin.qq.com/sns/oauth2/access_token";
|
||
pub const DEFAULT_WECHAT_USER_INFO_ENDPOINT: &str = "https://api.weixin.qq.com/sns/userinfo";
|
||
pub const DEFAULT_WECHAT_JS_CODE_SESSION_ENDPOINT: &str =
|
||
"https://api.weixin.qq.com/sns/jscode2session";
|
||
|
||
type HmacSha1 = Hmac<Sha1>;
|
||
|
||
// 鉴权 provider 直接冻结成文档中约定的枚举,避免后续在多个 crate 内重复发明字符串字面量。
|
||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||
#[serde(rename_all = "snake_case")]
|
||
pub enum AuthProvider {
|
||
Password,
|
||
Phone,
|
||
Wechat,
|
||
}
|
||
|
||
// 绑定状态只保留当前 JWT 需要透传的最小快照,不把完整账号状态枚举直接泄漏到 token 中。
|
||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||
#[serde(rename_all = "snake_case")]
|
||
pub enum BindingStatus {
|
||
Active,
|
||
PendingBindPhone,
|
||
}
|
||
|
||
// 用于签发 access token 的领域输入,和最终 JWT claims 解耦,避免业务层手动拼 iat/exp/iss。
|
||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||
pub struct AccessTokenClaimsInput {
|
||
pub user_id: String,
|
||
pub session_id: String,
|
||
pub provider: AuthProvider,
|
||
pub roles: Vec<String>,
|
||
pub token_version: u64,
|
||
pub phone_verified: bool,
|
||
pub binding_status: BindingStatus,
|
||
pub display_name: Option<String>,
|
||
}
|
||
|
||
// 直接映射最终 JWT payload,字段名与文档冻结口径保持一致。
|
||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||
pub struct AccessTokenClaims {
|
||
pub iss: String,
|
||
pub sub: String,
|
||
pub sid: String,
|
||
pub provider: AuthProvider,
|
||
pub roles: Vec<String>,
|
||
pub ver: u64,
|
||
pub phone_verified: bool,
|
||
pub binding_status: BindingStatus,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub display_name: Option<String>,
|
||
pub iat: u64,
|
||
pub exp: u64,
|
||
}
|
||
|
||
/// JWT issuing configuration shared by the API server and other modules, so the secret,
/// issuer and TTL are not scattered across call sites.
///
/// `Debug` is implemented by hand so the signing secret never ends up in logs, panic
/// messages or test failure output; all other fields print normally.
#[derive(Clone, PartialEq, Eq)]
pub struct JwtConfig {
    issuer: String,
    secret: String,
    access_token_ttl_seconds: u64,
}

impl fmt::Debug for JwtConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("JwtConfig")
            .field("issuer", &self.issuer)
            // The HMAC secret must never be rendered.
            .field("secret", &"<redacted>")
            .field("access_token_ttl_seconds", &self.access_token_ttl_seconds)
            .finish()
    }
}

/// `SameSite` attribute of the refresh cookie.
///
/// An enum keeps every layer from passing around inconsistently-cased strings.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RefreshCookieSameSite {
    Lax,
    Strict,
    None,
}

/// Refresh-cookie platform settings.
///
/// Collected here so the API server does not scatter cookie details across handlers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RefreshCookieConfig {
    cookie_name: String,
    cookie_path: String,
    cookie_secure: bool,
    cookie_same_site: RefreshCookieSameSite,
    refresh_session_ttl_days: u32,
}

/// Which SMS backend is used.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SmsAuthProviderKind {
    /// In-process fake that accepts a fixed verification code.
    Mock,
    /// Aliyun SMS verification-code API.
    Aliyun,
}

/// SMS verification settings; build validated instances through `SmsAuthConfig::new`.
///
/// `Debug` is implemented by hand so `access_key_secret` is never rendered. Only whether a
/// secret is present is shown.
#[derive(Clone, PartialEq, Eq)]
pub struct SmsAuthConfig {
    pub provider: SmsAuthProviderKind,
    pub endpoint: String,
    pub access_key_id: Option<String>,
    pub access_key_secret: Option<String>,
    pub sign_name: String,
    pub template_code: String,
    /// Key under which the code placeholder is put in the template parameters.
    pub template_param_key: String,
    pub country_code: String,
    pub scheme_name: Option<String>,
    pub code_length: u8,
    pub code_type: u8,
    pub valid_time_seconds: u64,
    pub interval_seconds: u64,
    pub duplicate_policy: u8,
    pub case_auth_policy: u8,
    pub return_verify_code: bool,
    /// Code the mock provider accepts.
    pub mock_verify_code: String,
}

impl fmt::Debug for SmsAuthConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SmsAuthConfig")
            .field("provider", &self.provider)
            .field("endpoint", &self.endpoint)
            .field("access_key_id", &self.access_key_id)
            // Keep "configured or not" visible without exposing the credential itself.
            .field(
                "access_key_secret",
                &self.access_key_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("sign_name", &self.sign_name)
            .field("template_code", &self.template_code)
            .field("template_param_key", &self.template_param_key)
            .field("country_code", &self.country_code)
            .field("scheme_name", &self.scheme_name)
            .field("code_length", &self.code_length)
            .field("code_type", &self.code_type)
            .field("valid_time_seconds", &self.valid_time_seconds)
            .field("interval_seconds", &self.interval_seconds)
            .field("duplicate_policy", &self.duplicate_policy)
            .field("case_auth_policy", &self.case_auth_policy)
            .field("return_verify_code", &self.return_verify_code)
            .field("mock_verify_code", &self.mock_verify_code)
            .finish()
    }
}

/// Request to send a verification code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmsSendCodeRequest {
    pub national_phone_number: String,
    /// Business scene; also used to derive the provider out-id.
    pub scene: String,
}

/// Outcome of a successful send.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmsSendCodeResult {
    pub cooldown_seconds: u64,
    pub expires_in_seconds: u64,
    pub provider_request_id: Option<String>,
    pub provider_out_id: Option<String>,
}

/// Request to check a user-submitted verification code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmsVerifyCodeRequest {
    pub national_phone_number: String,
    pub verify_code: String,
    pub provider_out_id: Option<String>,
}

/// Where a WeChat login was started from; selects the authorize endpoint and OAuth scope.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WechatAuthScene {
    /// Desktop browser (QR-code login, `snsapi_login` scope).
    Desktop,
    /// Inside the WeChat client (`snsapi_userinfo` scope).
    WechatInApp,
}

/// Raw WeChat login settings; `WechatProvider::new` decides which backend they produce.
///
/// `Debug` is implemented by hand so `app_secret` and `mini_program_app_secret` are never
/// rendered. Only whether each is present is shown.
#[derive(Clone, PartialEq, Eq)]
pub struct WechatAuthConfig {
    pub enabled: bool,
    /// Backend name; `mock` (case-insensitive) selects the mock provider.
    pub provider: String,
    pub app_id: Option<String>,
    pub app_secret: Option<String>,
    pub mini_program_app_id: Option<String>,
    pub mini_program_app_secret: Option<String>,
    pub authorize_endpoint: String,
    pub access_token_endpoint: String,
    pub user_info_endpoint: String,
    pub js_code_session_endpoint: String,
    pub mock_user_id: String,
    pub mock_union_id: Option<String>,
    pub mock_display_name: String,
    pub mock_avatar_url: Option<String>,
}

impl fmt::Debug for WechatAuthConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Show whether a secret is configured, never its value.
        let redact = |secret: &Option<String>| secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("WechatAuthConfig")
            .field("enabled", &self.enabled)
            .field("provider", &self.provider)
            .field("app_id", &self.app_id)
            .field("app_secret", &redact(&self.app_secret))
            .field("mini_program_app_id", &self.mini_program_app_id)
            .field(
                "mini_program_app_secret",
                &redact(&self.mini_program_app_secret),
            )
            .field("authorize_endpoint", &self.authorize_endpoint)
            .field("access_token_endpoint", &self.access_token_endpoint)
            .field("user_info_endpoint", &self.user_info_endpoint)
            .field("js_code_session_endpoint", &self.js_code_session_endpoint)
            .field("mock_user_id", &self.mock_user_id)
            .field("mock_union_id", &self.mock_union_id)
            .field("mock_display_name", &self.mock_display_name)
            .field("mock_avatar_url", &self.mock_avatar_url)
            .finish()
    }
}

/// Identity resolved from a WeChat login.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WechatIdentityProfile {
    /// The user's openid.
    pub provider_uid: String,
    pub provider_union_id: Option<String>,
    pub display_name: Option<String>,
    pub avatar_url: Option<String>,
}

#[derive(Clone, Debug)]
|
||
pub enum WechatProvider {
|
||
Disabled,
|
||
Mock(MockWechatProvider),
|
||
Real(RealWechatProvider),
|
||
}
|
||
|
||
#[derive(Clone, Debug)]
|
||
pub struct MockWechatProvider {
|
||
mock_user_id: String,
|
||
mock_union_id: Option<String>,
|
||
mock_display_name: String,
|
||
mock_avatar_url: Option<String>,
|
||
}
|
||
|
||
#[derive(Clone, Debug)]
|
||
pub struct RealWechatProvider {
|
||
client: Client,
|
||
app_id: Option<String>,
|
||
app_secret: Option<String>,
|
||
mini_program_app_id: Option<String>,
|
||
mini_program_app_secret: Option<String>,
|
||
authorize_endpoint: String,
|
||
access_token_endpoint: String,
|
||
user_info_endpoint: String,
|
||
js_code_session_endpoint: String,
|
||
}
|
||
|
||
#[derive(Clone, Debug)]
|
||
pub enum SmsAuthProvider {
|
||
Mock(MockSmsAuthProvider),
|
||
Aliyun(AliyunSmsAuthProvider),
|
||
}
|
||
|
||
#[derive(Clone, Debug)]
|
||
pub struct MockSmsAuthProvider {
|
||
config: SmsAuthConfig,
|
||
}
|
||
|
||
#[derive(Clone, Debug)]
|
||
pub struct AliyunSmsAuthProvider {
|
||
client: Client,
|
||
config: SmsAuthConfig,
|
||
}
|
||
|
||
/// Failures while configuring, signing or verifying JWTs.
#[derive(PartialEq, Eq, Debug)]
pub enum JwtError {
    /// The JWT configuration was rejected.
    InvalidConfig(&'static str),
    /// The claims input was rejected.
    InvalidClaims(&'static str),
    /// Encoding the token failed; carries the underlying message.
    SignFailed(String),
    /// Decoding or validating the token failed; carries the underlying message.
    VerifyFailed(String),
}

/// Failures while building the refresh-cookie configuration.
#[derive(PartialEq, Eq, Debug)]
pub enum RefreshCookieError {
    InvalidConfig(&'static str),
}

/// Failures while hashing or verifying passwords.
#[derive(PartialEq, Eq, Debug)]
pub enum PasswordHashError {
    HashFailed(String),
    VerifyFailed(String),
}

/// Failures reported by an SMS backend.
#[derive(PartialEq, Eq, Debug)]
pub enum SmsProviderError {
    /// The SMS configuration was rejected.
    InvalidConfig(String),
    /// The submitted code did not match.
    InvalidVerifyCode,
    /// The upstream service reported an error.
    Upstream(String),
}

/// Failures reported by a WeChat backend.
#[derive(PartialEq, Eq, Debug)]
pub enum WechatProviderError {
    /// WeChat login is not available.
    Disabled,
    /// No authorization code was supplied.
    MissingCode,
    InvalidConfig(String),
    InvalidCallback(String),
    /// The HTTP request could not be sent.
    RequestFailed(String),
    /// The upstream response could not be decoded.
    DeserializeFailed(String),
    /// WeChat returned an error payload.
    Upstream(String),
    /// The response lacked a required identity field.
    MissingProfile(String),
}

/// Coarse category of an auth-platform error.
///
/// The API server maps these to HTTP statuses and error envelopes.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum AuthPlatformErrorKind {
    InvalidConfig,
    InvalidClaims,
    SignFailed,
    VerifyFailed,
    CookieConfig,
    HashFailed,
    InvalidVerifyCode,
    Disabled,
    MissingCode,
    InvalidCallback,
    RequestFailed,
    DeserializeFailed,
    MissingProfile,
    Upstream,
}

#[derive(Debug, Deserialize)]
|
||
struct WechatAccessTokenResponse {
|
||
access_token: Option<String>,
|
||
openid: Option<String>,
|
||
unionid: Option<String>,
|
||
errmsg: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct WechatUserInfoResponse {
|
||
openid: Option<String>,
|
||
unionid: Option<String>,
|
||
nickname: Option<String>,
|
||
headimgurl: Option<String>,
|
||
errmsg: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct WechatJsCodeSessionResponse {
|
||
openid: Option<String>,
|
||
unionid: Option<String>,
|
||
errcode: Option<i64>,
|
||
errmsg: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct AliyunSendSmsVerifyCodeResponse {
|
||
// 阿里云 RPC 原始 JSON 使用首字母大写字段名,这里必须显式映射,避免把成功响应误判成空值。
|
||
#[serde(default, rename = "Code")]
|
||
code: Option<String>,
|
||
#[serde(default, rename = "Message")]
|
||
message: Option<String>,
|
||
#[serde(default, rename = "RequestId")]
|
||
request_id: Option<String>,
|
||
#[serde(default, rename = "Success")]
|
||
success: Option<bool>,
|
||
#[serde(default, rename = "Model")]
|
||
model: Option<AliyunSendSmsVerifyCodeModel>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct AliyunSendSmsVerifyCodeModel {
|
||
#[serde(default, rename = "BizId")]
|
||
_biz_id: Option<String>,
|
||
#[serde(default, rename = "OutId")]
|
||
out_id: Option<String>,
|
||
#[serde(default, rename = "RequestId")]
|
||
request_id: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct AliyunCheckSmsVerifyCodeResponse {
|
||
// 校验接口同样返回首字母大写字段名,保持和发送接口一致的显式映射。
|
||
#[serde(default, rename = "Code")]
|
||
code: Option<String>,
|
||
#[serde(default, rename = "Message")]
|
||
message: Option<String>,
|
||
#[serde(default, rename = "Success")]
|
||
success: Option<bool>,
|
||
#[serde(default, rename = "Model")]
|
||
model: Option<AliyunCheckSmsVerifyCodeModel>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct AliyunCheckSmsVerifyCodeModel {
|
||
#[serde(default, rename = "OutId")]
|
||
_out_id: Option<String>,
|
||
#[serde(default, rename = "VerifyResult")]
|
||
verify_result: Option<String>,
|
||
}
|
||
|
||
impl JwtConfig {
|
||
pub fn new(
|
||
issuer: String,
|
||
secret: String,
|
||
access_token_ttl_seconds: u64,
|
||
) -> Result<Self, JwtError> {
|
||
let issuer = normalize_required_string(&issuer)
|
||
.ok_or(JwtError::InvalidConfig("JWT issuer 不能为空"))?;
|
||
let secret = normalize_required_string(&secret)
|
||
.ok_or(JwtError::InvalidConfig("JWT secret 不能为空"))?;
|
||
|
||
if access_token_ttl_seconds == 0 {
|
||
return Err(JwtError::InvalidConfig(
|
||
"JWT access token 过期时间必须大于 0",
|
||
));
|
||
}
|
||
|
||
Ok(Self {
|
||
issuer,
|
||
secret,
|
||
access_token_ttl_seconds,
|
||
})
|
||
}
|
||
|
||
pub fn issuer(&self) -> &str {
|
||
&self.issuer
|
||
}
|
||
|
||
pub fn access_token_ttl_seconds(&self) -> u64 {
|
||
self.access_token_ttl_seconds
|
||
}
|
||
}
|
||
|
||
impl RefreshCookieSameSite {
|
||
pub fn parse(raw: &str) -> Option<Self> {
|
||
match raw.trim().to_ascii_lowercase().as_str() {
|
||
"lax" => Some(Self::Lax),
|
||
"strict" => Some(Self::Strict),
|
||
"none" => Some(Self::None),
|
||
_ => None,
|
||
}
|
||
}
|
||
|
||
pub fn as_str(&self) -> &'static str {
|
||
match self {
|
||
Self::Lax => "Lax",
|
||
Self::Strict => "Strict",
|
||
Self::None => "None",
|
||
}
|
||
}
|
||
}
|
||
|
||
impl RefreshCookieConfig {
|
||
pub fn new(
|
||
cookie_name: String,
|
||
cookie_path: String,
|
||
cookie_secure: bool,
|
||
cookie_same_site: RefreshCookieSameSite,
|
||
refresh_session_ttl_days: u32,
|
||
) -> Result<Self, RefreshCookieError> {
|
||
let cookie_name = normalize_required_string(&cookie_name).ok_or(
|
||
RefreshCookieError::InvalidConfig("refresh cookie 名称不能为空"),
|
||
)?;
|
||
let cookie_path = normalize_required_string(&cookie_path).ok_or(
|
||
RefreshCookieError::InvalidConfig("refresh cookie path 不能为空"),
|
||
)?;
|
||
|
||
if refresh_session_ttl_days == 0 {
|
||
return Err(RefreshCookieError::InvalidConfig(
|
||
"refresh session TTL 天数必须大于 0",
|
||
));
|
||
}
|
||
|
||
Ok(Self {
|
||
cookie_name,
|
||
cookie_path,
|
||
cookie_secure,
|
||
cookie_same_site,
|
||
refresh_session_ttl_days,
|
||
})
|
||
}
|
||
|
||
pub fn cookie_name(&self) -> &str {
|
||
&self.cookie_name
|
||
}
|
||
|
||
pub fn cookie_path(&self) -> &str {
|
||
&self.cookie_path
|
||
}
|
||
|
||
pub fn cookie_secure(&self) -> bool {
|
||
self.cookie_secure
|
||
}
|
||
|
||
pub fn cookie_same_site(&self) -> &RefreshCookieSameSite {
|
||
&self.cookie_same_site
|
||
}
|
||
|
||
pub fn refresh_session_ttl_days(&self) -> u32 {
|
||
self.refresh_session_ttl_days
|
||
}
|
||
}
|
||
|
||
impl SmsAuthProviderKind {
|
||
pub fn parse(raw: &str) -> Option<Self> {
|
||
match raw.trim().to_ascii_lowercase().as_str() {
|
||
"mock" => Some(Self::Mock),
|
||
"aliyun" => Some(Self::Aliyun),
|
||
_ => None,
|
||
}
|
||
}
|
||
|
||
pub fn as_str(&self) -> &'static str {
|
||
match self {
|
||
Self::Mock => "mock",
|
||
Self::Aliyun => "aliyun",
|
||
}
|
||
}
|
||
}
|
||
|
||
impl SmsAuthConfig {
|
||
pub fn new(
|
||
provider: SmsAuthProviderKind,
|
||
endpoint: String,
|
||
access_key_id: Option<String>,
|
||
access_key_secret: Option<String>,
|
||
sign_name: String,
|
||
template_code: String,
|
||
template_param_key: String,
|
||
country_code: String,
|
||
scheme_name: Option<String>,
|
||
code_length: u8,
|
||
code_type: u8,
|
||
valid_time_seconds: u64,
|
||
interval_seconds: u64,
|
||
duplicate_policy: u8,
|
||
case_auth_policy: u8,
|
||
return_verify_code: bool,
|
||
mock_verify_code: String,
|
||
) -> Result<Self, SmsProviderError> {
|
||
let endpoint = normalize_required_string(&endpoint)
|
||
.unwrap_or_else(|| DEFAULT_SMS_ENDPOINT.to_string());
|
||
let template_param_key = normalize_required_string(&template_param_key)
|
||
.unwrap_or_else(|| DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string());
|
||
let country_code = normalize_required_string(&country_code)
|
||
.unwrap_or_else(|| DEFAULT_SMS_COUNTRY_CODE.to_string());
|
||
let scheme_name = normalize_optional_string(scheme_name);
|
||
let mock_verify_code = normalize_required_string(&mock_verify_code)
|
||
.unwrap_or_else(|| DEFAULT_SMS_MOCK_VERIFY_CODE.to_string());
|
||
|
||
if !(4..=8).contains(&code_length) {
|
||
return Err(SmsProviderError::InvalidConfig(
|
||
"短信验证码长度必须在 4 到 8 之间".to_string(),
|
||
));
|
||
}
|
||
if !(1..=7).contains(&code_type) {
|
||
return Err(SmsProviderError::InvalidConfig(
|
||
"短信验证码类型取值非法".to_string(),
|
||
));
|
||
}
|
||
if interval_seconds == 0 || valid_time_seconds == 0 {
|
||
return Err(SmsProviderError::InvalidConfig(
|
||
"短信验证码有效期和发送间隔必须大于 0".to_string(),
|
||
));
|
||
}
|
||
if !(1..=2).contains(&duplicate_policy) {
|
||
return Err(SmsProviderError::InvalidConfig(
|
||
"短信验证码重复策略取值非法".to_string(),
|
||
));
|
||
}
|
||
if !(1..=2).contains(&case_auth_policy) {
|
||
return Err(SmsProviderError::InvalidConfig(
|
||
"短信验证码大小写校验策略取值非法".to_string(),
|
||
));
|
||
}
|
||
|
||
match provider {
|
||
SmsAuthProviderKind::Mock => {}
|
||
SmsAuthProviderKind::Aliyun => {
|
||
if normalize_required_string(&sign_name).is_none() {
|
||
return Err(SmsProviderError::InvalidConfig(
|
||
"阿里云短信签名不能为空".to_string(),
|
||
));
|
||
}
|
||
if normalize_required_string(&template_code).is_none() {
|
||
return Err(SmsProviderError::InvalidConfig(
|
||
"阿里云短信模板编码不能为空".to_string(),
|
||
));
|
||
}
|
||
if access_key_id
|
||
.as_deref()
|
||
.and_then(normalize_required_string)
|
||
.is_none()
|
||
|| access_key_secret
|
||
.as_deref()
|
||
.and_then(normalize_required_string)
|
||
.is_none()
|
||
{
|
||
return Err(SmsProviderError::InvalidConfig(
|
||
"阿里云短信 AccessKey 未配置".to_string(),
|
||
));
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(Self {
|
||
provider,
|
||
endpoint,
|
||
access_key_id: access_key_id.and_then(|value| normalize_required_string(&value)),
|
||
access_key_secret: access_key_secret
|
||
.and_then(|value| normalize_required_string(&value)),
|
||
sign_name: sign_name.trim().to_string(),
|
||
template_code: template_code.trim().to_string(),
|
||
template_param_key,
|
||
country_code,
|
||
scheme_name,
|
||
code_length,
|
||
code_type,
|
||
valid_time_seconds,
|
||
interval_seconds,
|
||
duplicate_policy,
|
||
case_auth_policy,
|
||
return_verify_code,
|
||
mock_verify_code,
|
||
})
|
||
}
|
||
}
|
||
|
||
impl SmsAuthProvider {
|
||
pub fn new(config: SmsAuthConfig) -> Result<Self, SmsProviderError> {
|
||
match config.provider {
|
||
SmsAuthProviderKind::Mock => Ok(Self::Mock(MockSmsAuthProvider { config })),
|
||
SmsAuthProviderKind::Aliyun => Ok(Self::Aliyun(AliyunSmsAuthProvider {
|
||
client: Client::new(),
|
||
config,
|
||
})),
|
||
}
|
||
}
|
||
|
||
pub fn kind(&self) -> SmsAuthProviderKind {
|
||
match self {
|
||
Self::Mock(_) => SmsAuthProviderKind::Mock,
|
||
Self::Aliyun(_) => SmsAuthProviderKind::Aliyun,
|
||
}
|
||
}
|
||
|
||
pub async fn send_code(
|
||
&self,
|
||
request: SmsSendCodeRequest,
|
||
) -> Result<SmsSendCodeResult, SmsProviderError> {
|
||
match self {
|
||
Self::Mock(provider) => provider.send_code(request).await,
|
||
Self::Aliyun(provider) => provider.send_code(request).await,
|
||
}
|
||
}
|
||
|
||
pub async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> {
|
||
match self {
|
||
Self::Mock(provider) => provider.verify_code(request).await,
|
||
Self::Aliyun(provider) => provider.verify_code(request).await,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl WechatAuthConfig {
|
||
#[allow(clippy::too_many_arguments)]
|
||
pub fn new(
|
||
enabled: bool,
|
||
provider: String,
|
||
app_id: Option<String>,
|
||
app_secret: Option<String>,
|
||
mini_program_app_id: Option<String>,
|
||
mini_program_app_secret: Option<String>,
|
||
authorize_endpoint: String,
|
||
access_token_endpoint: String,
|
||
user_info_endpoint: String,
|
||
js_code_session_endpoint: String,
|
||
mock_user_id: String,
|
||
mock_union_id: Option<String>,
|
||
mock_display_name: String,
|
||
mock_avatar_url: Option<String>,
|
||
) -> Self {
|
||
Self {
|
||
enabled,
|
||
provider,
|
||
app_id,
|
||
app_secret,
|
||
mini_program_app_id,
|
||
mini_program_app_secret,
|
||
authorize_endpoint,
|
||
access_token_endpoint,
|
||
user_info_endpoint,
|
||
js_code_session_endpoint,
|
||
mock_user_id,
|
||
mock_union_id,
|
||
mock_display_name,
|
||
mock_avatar_url,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl WechatProvider {
|
||
pub fn new(config: WechatAuthConfig) -> Self {
|
||
if !config.enabled {
|
||
return Self::Disabled;
|
||
}
|
||
|
||
if config.provider.trim().eq_ignore_ascii_case("mock") {
|
||
return Self::Mock(MockWechatProvider {
|
||
mock_user_id: config.mock_user_id,
|
||
mock_union_id: config.mock_union_id,
|
||
mock_display_name: config.mock_display_name,
|
||
mock_avatar_url: config.mock_avatar_url,
|
||
});
|
||
}
|
||
|
||
let has_web_oauth_config = config
|
||
.app_id
|
||
.as_ref()
|
||
.is_some_and(|value| !value.is_empty())
|
||
&& config
|
||
.app_secret
|
||
.as_ref()
|
||
.is_some_and(|value| !value.is_empty());
|
||
let has_mini_program_config = config
|
||
.mini_program_app_id
|
||
.as_ref()
|
||
.is_some_and(|value| !value.is_empty())
|
||
&& config
|
||
.mini_program_app_secret
|
||
.as_ref()
|
||
.is_some_and(|value| !value.is_empty());
|
||
if !has_web_oauth_config && !has_mini_program_config {
|
||
return Self::Disabled;
|
||
};
|
||
|
||
Self::Real(RealWechatProvider {
|
||
client: Client::new(),
|
||
app_id: config.app_id,
|
||
app_secret: config.app_secret,
|
||
mini_program_app_id: config.mini_program_app_id,
|
||
mini_program_app_secret: config.mini_program_app_secret,
|
||
authorize_endpoint: config.authorize_endpoint,
|
||
access_token_endpoint: config.access_token_endpoint,
|
||
user_info_endpoint: config.user_info_endpoint,
|
||
js_code_session_endpoint: config.js_code_session_endpoint,
|
||
})
|
||
}
|
||
|
||
pub fn build_authorization_url(
|
||
&self,
|
||
callback_url: &str,
|
||
state: &str,
|
||
scene: &WechatAuthScene,
|
||
) -> Result<String, WechatProviderError> {
|
||
match self {
|
||
Self::Disabled => Err(WechatProviderError::Disabled),
|
||
Self::Mock(_) => build_mock_wechat_authorization_url(callback_url, state),
|
||
Self::Real(provider) => provider.build_authorization_url(callback_url, state, scene),
|
||
}
|
||
}
|
||
|
||
pub async fn resolve_callback_profile(
|
||
&self,
|
||
code: Option<&str>,
|
||
mock_code: Option<&str>,
|
||
) -> Result<WechatIdentityProfile, WechatProviderError> {
|
||
match self {
|
||
Self::Disabled => Err(WechatProviderError::Disabled),
|
||
Self::Mock(provider) => Ok(provider.resolve_callback_profile(mock_code)),
|
||
Self::Real(provider) => provider.resolve_callback_profile(code).await,
|
||
}
|
||
}
|
||
|
||
pub async fn resolve_mini_program_login_profile(
|
||
&self,
|
||
code: Option<&str>,
|
||
) -> Result<WechatIdentityProfile, WechatProviderError> {
|
||
match self {
|
||
Self::Disabled => Err(WechatProviderError::Disabled),
|
||
Self::Mock(provider) => Ok(provider.resolve_callback_profile(code)),
|
||
Self::Real(provider) => provider.resolve_mini_program_login_profile(code).await,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl MockWechatProvider {
|
||
fn resolve_callback_profile(&self, mock_code: Option<&str>) -> WechatIdentityProfile {
|
||
let provider_uid = mock_code
|
||
.map(str::trim)
|
||
.filter(|value| !value.is_empty())
|
||
.unwrap_or(self.mock_user_id.as_str())
|
||
.to_string();
|
||
WechatIdentityProfile {
|
||
provider_uid,
|
||
provider_union_id: self.mock_union_id.clone(),
|
||
display_name: Some(self.mock_display_name.clone()),
|
||
avatar_url: self.mock_avatar_url.clone(),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl RealWechatProvider {
|
||
fn build_authorization_url(
|
||
&self,
|
||
callback_url: &str,
|
||
state: &str,
|
||
scene: &WechatAuthScene,
|
||
) -> Result<String, WechatProviderError> {
|
||
let endpoint = match scene {
|
||
WechatAuthScene::Desktop => &self.authorize_endpoint,
|
||
WechatAuthScene::WechatInApp => DEFAULT_WECHAT_IN_APP_AUTHORIZE_ENDPOINT,
|
||
};
|
||
let mut url = Url::parse(endpoint).map_err(|error| {
|
||
WechatProviderError::InvalidConfig(format!("微信授权地址非法:{error}"))
|
||
})?;
|
||
let app_id = self.app_id.as_ref().ok_or_else(|| {
|
||
WechatProviderError::InvalidConfig("微信开放平台 AppID 未配置".to_string())
|
||
})?;
|
||
url.query_pairs_mut()
|
||
.append_pair("appid", app_id)
|
||
.append_pair("redirect_uri", callback_url)
|
||
.append_pair("response_type", "code")
|
||
.append_pair(
|
||
"scope",
|
||
match scene {
|
||
WechatAuthScene::Desktop => "snsapi_login",
|
||
WechatAuthScene::WechatInApp => "snsapi_userinfo",
|
||
},
|
||
)
|
||
.append_pair("state", state);
|
||
Ok(format!("{url}#wechat_redirect"))
|
||
}
|
||
|
||
async fn resolve_callback_profile(
|
||
&self,
|
||
code: Option<&str>,
|
||
) -> Result<WechatIdentityProfile, WechatProviderError> {
|
||
let code = code
|
||
.map(str::trim)
|
||
.filter(|value| !value.is_empty())
|
||
.ok_or(WechatProviderError::MissingCode)?;
|
||
|
||
let app_id = self.app_id.as_ref().ok_or_else(|| {
|
||
WechatProviderError::InvalidConfig("微信开放平台 AppID 未配置".to_string())
|
||
})?;
|
||
let app_secret = self.app_secret.as_ref().ok_or_else(|| {
|
||
WechatProviderError::InvalidConfig("微信开放平台 AppSecret 未配置".to_string())
|
||
})?;
|
||
let mut access_token_url = Url::parse(&self.access_token_endpoint).map_err(|error| {
|
||
WechatProviderError::InvalidConfig(format!("微信 access_token 地址非法:{error}"))
|
||
})?;
|
||
access_token_url
|
||
.query_pairs_mut()
|
||
.append_pair("appid", app_id)
|
||
.append_pair("secret", app_secret)
|
||
.append_pair("code", code)
|
||
.append_pair("grant_type", "authorization_code");
|
||
|
||
let access_token_payload = self
|
||
.client
|
||
.get(access_token_url.as_str())
|
||
.send()
|
||
.await
|
||
.map_err(|error| {
|
||
warn!(error = %error, "微信 access_token 请求失败");
|
||
WechatProviderError::RequestFailed(
|
||
"微信登录失败:access_token 请求失败".to_string(),
|
||
)
|
||
})?
|
||
.json::<WechatAccessTokenResponse>()
|
||
.await
|
||
.map_err(|error| {
|
||
warn!(error = %error, "微信 access_token 响应解析失败");
|
||
WechatProviderError::DeserializeFailed(
|
||
"微信登录失败:access_token 响应非法".to_string(),
|
||
)
|
||
})?;
|
||
|
||
let access_token = access_token_payload
|
||
.access_token
|
||
.filter(|value| !value.trim().is_empty())
|
||
.ok_or_else(|| {
|
||
WechatProviderError::Upstream(format!(
|
||
"微信登录失败:{}",
|
||
access_token_payload
|
||
.errmsg
|
||
.unwrap_or_else(|| "缺少 access_token".to_string())
|
||
))
|
||
})?;
|
||
let openid = access_token_payload
|
||
.openid
|
||
.filter(|value| !value.trim().is_empty())
|
||
.ok_or_else(|| {
|
||
WechatProviderError::MissingProfile("微信登录失败:缺少 openid".to_string())
|
||
})?;
|
||
|
||
let mut user_info_url = Url::parse(&self.user_info_endpoint).map_err(|error| {
|
||
WechatProviderError::InvalidConfig(format!("微信用户信息地址非法:{error}"))
|
||
})?;
|
||
user_info_url
|
||
.query_pairs_mut()
|
||
.append_pair("access_token", &access_token)
|
||
.append_pair("openid", &openid)
|
||
.append_pair("lang", "zh_CN");
|
||
|
||
let user_info_payload = self
|
||
.client
|
||
.get(user_info_url.as_str())
|
||
.send()
|
||
.await
|
||
.map_err(|error| {
|
||
warn!(error = %error, "微信用户信息请求失败");
|
||
WechatProviderError::RequestFailed("微信登录失败:用户信息请求失败".to_string())
|
||
})?
|
||
.json::<WechatUserInfoResponse>()
|
||
.await
|
||
.map_err(|error| {
|
||
warn!(error = %error, "微信用户信息响应解析失败");
|
||
WechatProviderError::DeserializeFailed("微信登录失败:用户信息响应非法".to_string())
|
||
})?;
|
||
|
||
let provider_uid = user_info_payload
|
||
.openid
|
||
.filter(|value| !value.trim().is_empty())
|
||
.ok_or_else(|| {
|
||
WechatProviderError::Upstream(format!(
|
||
"微信登录失败:{}",
|
||
user_info_payload
|
||
.errmsg
|
||
.unwrap_or_else(|| "缺少 openid".to_string())
|
||
))
|
||
})?;
|
||
|
||
Ok(WechatIdentityProfile {
|
||
provider_uid,
|
||
provider_union_id: user_info_payload.unionid.or(access_token_payload.unionid),
|
||
display_name: user_info_payload.nickname,
|
||
avatar_url: user_info_payload.headimgurl,
|
||
})
|
||
}
|
||
|
||
async fn resolve_mini_program_login_profile(
|
||
&self,
|
||
code: Option<&str>,
|
||
) -> Result<WechatIdentityProfile, WechatProviderError> {
|
||
let code = code
|
||
.map(str::trim)
|
||
.filter(|value| !value.is_empty())
|
||
.ok_or(WechatProviderError::MissingCode)?;
|
||
let app_id = self
|
||
.mini_program_app_id
|
||
.as_ref()
|
||
.or(self.app_id.as_ref())
|
||
.ok_or_else(|| {
|
||
WechatProviderError::InvalidConfig("微信小程序 AppID 未配置".to_string())
|
||
})?;
|
||
let app_secret = self
|
||
.mini_program_app_secret
|
||
.as_ref()
|
||
.or(self.app_secret.as_ref())
|
||
.ok_or_else(|| {
|
||
WechatProviderError::InvalidConfig("微信小程序 AppSecret 未配置".to_string())
|
||
})?;
|
||
|
||
let mut js_code_session_url =
|
||
Url::parse(&self.js_code_session_endpoint).map_err(|error| {
|
||
WechatProviderError::InvalidConfig(format!("微信 jscode2session 地址非法:{error}"))
|
||
})?;
|
||
js_code_session_url
|
||
.query_pairs_mut()
|
||
.append_pair("appid", app_id)
|
||
.append_pair("secret", app_secret)
|
||
.append_pair("js_code", code)
|
||
.append_pair("grant_type", "authorization_code");
|
||
|
||
let payload = self
|
||
.client
|
||
.get(js_code_session_url.as_str())
|
||
.send()
|
||
.await
|
||
.map_err(|error| {
|
||
warn!(error = %error, "微信小程序 jscode2session 请求失败");
|
||
WechatProviderError::RequestFailed(
|
||
"微信小程序登录失败:jscode2session 请求失败".to_string(),
|
||
)
|
||
})?
|
||
.json::<WechatJsCodeSessionResponse>()
|
||
.await
|
||
.map_err(|error| {
|
||
warn!(error = %error, "微信小程序 jscode2session 响应解析失败");
|
||
WechatProviderError::DeserializeFailed(
|
||
"微信小程序登录失败:jscode2session 响应非法".to_string(),
|
||
)
|
||
})?;
|
||
|
||
if let Some(errcode) = payload.errcode.filter(|value| *value != 0) {
|
||
return Err(WechatProviderError::Upstream(format!(
|
||
"微信小程序登录失败:{}",
|
||
payload
|
||
.errmsg
|
||
.unwrap_or_else(|| format!("jscode2session 返回错误 {errcode}"))
|
||
)));
|
||
}
|
||
|
||
let provider_uid = payload
|
||
.openid
|
||
.filter(|value| !value.trim().is_empty())
|
||
.ok_or_else(|| {
|
||
WechatProviderError::MissingProfile("微信小程序登录失败:缺少 openid".to_string())
|
||
})?;
|
||
|
||
Ok(WechatIdentityProfile {
|
||
provider_uid,
|
||
provider_union_id: payload.unionid,
|
||
display_name: None,
|
||
avatar_url: None,
|
||
})
|
||
}
|
||
}
|
||
|
||
fn build_mock_wechat_authorization_url(
|
||
callback_url: &str,
|
||
state: &str,
|
||
) -> Result<String, WechatProviderError> {
|
||
let mut callback = Url::parse(callback_url).map_err(|error| {
|
||
WechatProviderError::InvalidCallback(format!("微信回调地址非法:{error}"))
|
||
})?;
|
||
callback
|
||
.query_pairs_mut()
|
||
.append_pair("mock_code", "wx-mock-code")
|
||
.append_pair("state", state);
|
||
Ok(callback.to_string())
|
||
}
|
||
|
||
impl MockSmsAuthProvider {
|
||
async fn send_code(
|
||
&self,
|
||
request: SmsSendCodeRequest,
|
||
) -> Result<SmsSendCodeResult, SmsProviderError> {
|
||
let provider_out_id =
|
||
build_sms_provider_out_id(&request.scene, &request.national_phone_number);
|
||
|
||
Ok(SmsSendCodeResult {
|
||
cooldown_seconds: self.config.interval_seconds,
|
||
expires_in_seconds: self.config.valid_time_seconds,
|
||
provider_request_id: Some("mock-request-id".to_string()),
|
||
provider_out_id: Some(provider_out_id),
|
||
})
|
||
}
|
||
|
||
async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> {
|
||
if request.verify_code.trim() != self.config.mock_verify_code {
|
||
return Err(SmsProviderError::InvalidVerifyCode);
|
||
}
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
impl AliyunSmsAuthProvider {
|
||
async fn send_code(
|
||
&self,
|
||
request: SmsSendCodeRequest,
|
||
) -> Result<SmsSendCodeResult, SmsProviderError> {
|
||
let provider_out_id =
|
||
build_sms_provider_out_id(&request.scene, &request.national_phone_number);
|
||
let phone_masked = mask_phone_number(&request.national_phone_number);
|
||
let template_param = serde_json::json!({
|
||
self.config.template_param_key.clone(): "##code##",
|
||
"min": self.config.valid_time_seconds,
|
||
})
|
||
.to_string();
|
||
info!(
|
||
provider = "aliyun",
|
||
scene = request.scene.as_str(),
|
||
phone_masked = phone_masked.as_str(),
|
||
endpoint = self.config.endpoint.as_str(),
|
||
sign_name = self.config.sign_name.as_str(),
|
||
template_code = self.config.template_code.as_str(),
|
||
code_length = self.config.code_length,
|
||
valid_time_seconds = self.config.valid_time_seconds,
|
||
interval_seconds = self.config.interval_seconds,
|
||
provider_out_id = provider_out_id.as_str(),
|
||
"准备调用阿里云短信发送接口"
|
||
);
|
||
|
||
let mut query = BTreeMap::new();
|
||
query.insert("Action".to_string(), "SendSmsVerifyCode".to_string());
|
||
query.insert("Format".to_string(), "json".to_string());
|
||
query.insert("Version".to_string(), "2017-05-25".to_string());
|
||
query.insert("Timestamp".to_string(), current_aliyun_timestamp());
|
||
query.insert("SignatureNonce".to_string(), new_uuid_simple_string());
|
||
query.insert("SignatureMethod".to_string(), "HMAC-SHA1".to_string());
|
||
query.insert("SignatureVersion".to_string(), "1.0".to_string());
|
||
query.insert(
|
||
"AccessKeyId".to_string(),
|
||
self.config.access_key_id.clone().unwrap_or_default(),
|
||
);
|
||
query.insert(
|
||
"PhoneNumber".to_string(),
|
||
request.national_phone_number.trim().to_string(),
|
||
);
|
||
query.insert("CountryCode".to_string(), self.config.country_code.clone());
|
||
query.insert("SignName".to_string(), self.config.sign_name.clone());
|
||
query.insert(
|
||
"TemplateCode".to_string(),
|
||
self.config.template_code.clone(),
|
||
);
|
||
query.insert("TemplateParam".to_string(), template_param);
|
||
query.insert(
|
||
"CodeLength".to_string(),
|
||
self.config.code_length.to_string(),
|
||
);
|
||
query.insert("CodeType".to_string(), self.config.code_type.to_string());
|
||
query.insert(
|
||
"ValidTime".to_string(),
|
||
self.config.valid_time_seconds.to_string(),
|
||
);
|
||
query.insert(
|
||
"Interval".to_string(),
|
||
self.config.interval_seconds.to_string(),
|
||
);
|
||
query.insert(
|
||
"DuplicatePolicy".to_string(),
|
||
self.config.duplicate_policy.to_string(),
|
||
);
|
||
query.insert(
|
||
"ReturnVerifyCode".to_string(),
|
||
self.config.return_verify_code.to_string(),
|
||
);
|
||
query.insert("OutId".to_string(), provider_out_id.clone());
|
||
if let Some(scheme_name) = self.config.scheme_name.clone() {
|
||
query.insert("SchemeName".to_string(), scheme_name);
|
||
}
|
||
self.sign_query(&mut query)?;
|
||
|
||
let payload = self
|
||
.client
|
||
.post(build_aliyun_sms_url(&self.config.endpoint)?)
|
||
.form(&query)
|
||
.send()
|
||
.await
|
||
.map_err(|error| SmsProviderError::Upstream(format!("短信验证码发送失败:{error}")))?;
|
||
let http_status = payload.status();
|
||
|
||
let body = parse_aliyun_json_response(payload, "短信验证码发送失败").await?;
|
||
info!(
|
||
provider = "aliyun",
|
||
scene = request.scene.as_str(),
|
||
phone_masked = phone_masked.as_str(),
|
||
http_status = http_status.as_u16(),
|
||
provider_code = body.code.as_deref().unwrap_or("unknown"),
|
||
provider_message = body.message.as_deref().unwrap_or("unknown"),
|
||
provider_request_id = body
|
||
.request_id
|
||
.as_deref()
|
||
.or_else(|| body
|
||
.model
|
||
.as_ref()
|
||
.and_then(|model| model.request_id.as_deref()))
|
||
.unwrap_or("unknown"),
|
||
provider_out_id = body
|
||
.model
|
||
.as_ref()
|
||
.and_then(|model| model.out_id.as_deref())
|
||
.unwrap_or("unknown"),
|
||
success = body.success.unwrap_or(false),
|
||
"阿里云短信发送接口返回响应"
|
||
);
|
||
if !body.success.unwrap_or(false) || body.code.as_deref() != Some("OK") {
|
||
warn!(
|
||
provider = "aliyun",
|
||
scene = request.scene.as_str(),
|
||
phone_masked = phone_masked.as_str(),
|
||
http_status = http_status.as_u16(),
|
||
provider_code = body.code.as_deref().unwrap_or("unknown"),
|
||
provider_message = body.message.as_deref().unwrap_or("unknown"),
|
||
provider_request_id = body
|
||
.request_id
|
||
.as_deref()
|
||
.or_else(|| body
|
||
.model
|
||
.as_ref()
|
||
.and_then(|model| model.request_id.as_deref()))
|
||
.unwrap_or("unknown"),
|
||
provider_out_id = body
|
||
.model
|
||
.as_ref()
|
||
.and_then(|model| model.out_id.as_deref())
|
||
.unwrap_or("unknown"),
|
||
"阿里云短信发送接口返回业务失败"
|
||
);
|
||
return Err(map_aliyun_provider_error(
|
||
"短信验证码发送失败",
|
||
body.message,
|
||
body.code,
|
||
));
|
||
}
|
||
|
||
Ok(SmsSendCodeResult {
|
||
cooldown_seconds: self.config.interval_seconds,
|
||
expires_in_seconds: self.config.valid_time_seconds,
|
||
provider_request_id: body.request_id.or_else(|| {
|
||
body.model
|
||
.as_ref()
|
||
.and_then(|model| model.request_id.clone())
|
||
}),
|
||
provider_out_id: body.model.and_then(|model| model.out_id),
|
||
})
|
||
}
|
||
|
||
async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> {
|
||
let mut query = BTreeMap::new();
|
||
query.insert("Action".to_string(), "CheckSmsVerifyCode".to_string());
|
||
query.insert("Format".to_string(), "json".to_string());
|
||
query.insert("Version".to_string(), "2017-05-25".to_string());
|
||
query.insert("Timestamp".to_string(), current_aliyun_timestamp());
|
||
query.insert("SignatureNonce".to_string(), new_uuid_simple_string());
|
||
query.insert("SignatureMethod".to_string(), "HMAC-SHA1".to_string());
|
||
query.insert("SignatureVersion".to_string(), "1.0".to_string());
|
||
query.insert(
|
||
"AccessKeyId".to_string(),
|
||
self.config.access_key_id.clone().unwrap_or_default(),
|
||
);
|
||
query.insert(
|
||
"PhoneNumber".to_string(),
|
||
request.national_phone_number.trim().to_string(),
|
||
);
|
||
query.insert("CountryCode".to_string(), self.config.country_code.clone());
|
||
query.insert(
|
||
"VerifyCode".to_string(),
|
||
request.verify_code.trim().to_string(),
|
||
);
|
||
query.insert(
|
||
"CaseAuthPolicy".to_string(),
|
||
self.config.case_auth_policy.to_string(),
|
||
);
|
||
if let Some(scheme_name) = self.config.scheme_name.clone() {
|
||
query.insert("SchemeName".to_string(), scheme_name);
|
||
}
|
||
if let Some(provider_out_id) = request.provider_out_id {
|
||
query.insert("OutId".to_string(), provider_out_id);
|
||
}
|
||
self.sign_query(&mut query)?;
|
||
|
||
let payload = self
|
||
.client
|
||
.post(build_aliyun_sms_url(&self.config.endpoint)?)
|
||
.form(&query)
|
||
.send()
|
||
.await
|
||
.map_err(|error| SmsProviderError::Upstream(format!("验证码校验失败:{error}")))?;
|
||
|
||
let body = parse_aliyun_json_response_for_verify(payload).await?;
|
||
if !body.success.unwrap_or(false) || body.code.as_deref() != Some("OK") {
|
||
return Err(map_aliyun_provider_error(
|
||
"验证码校验失败",
|
||
body.message,
|
||
body.code,
|
||
));
|
||
}
|
||
if body.model.and_then(|model| model.verify_result).as_deref() != Some("PASS") {
|
||
return Err(SmsProviderError::InvalidVerifyCode);
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
fn sign_query(&self, query: &mut BTreeMap<String, String>) -> Result<(), SmsProviderError> {
|
||
let access_key_secret = self.config.access_key_secret.as_deref().ok_or_else(|| {
|
||
SmsProviderError::InvalidConfig("阿里云短信 AccessKeySecret 未配置".to_string())
|
||
})?;
|
||
let canonicalized = canonicalize_aliyun_rpc_params(query);
|
||
let string_to_sign = format!(
|
||
"POST&{}&{}",
|
||
aliyun_percent_encode("/"),
|
||
aliyun_percent_encode(&canonicalized)
|
||
);
|
||
let mut signer = HmacSha1::new_from_slice(format!("{access_key_secret}&").as_bytes())
|
||
.map_err(|error| {
|
||
SmsProviderError::InvalidConfig(format!("初始化短信签名器失败:{error}"))
|
||
})?;
|
||
signer.update(string_to_sign.as_bytes());
|
||
let signature = BASE64_STANDARD.encode(signer.finalize().into_bytes());
|
||
query.insert("Signature".to_string(), signature);
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
impl AccessTokenClaims {
|
||
pub fn from_input(
|
||
input: AccessTokenClaimsInput,
|
||
config: &JwtConfig,
|
||
issued_at: OffsetDateTime,
|
||
) -> Result<Self, JwtError> {
|
||
let user_id = normalize_required_field(input.user_id, "JWT sub 不能为空")?;
|
||
let session_id = normalize_required_field(input.session_id, "JWT sid 不能为空")?;
|
||
let roles = normalize_roles(input.roles)?;
|
||
let display_name = normalize_optional_field(input.display_name);
|
||
|
||
let issued_at_unix = issued_at.unix_timestamp();
|
||
if issued_at_unix < 0 {
|
||
return Err(JwtError::InvalidClaims("JWT iat 不能早于 Unix epoch"));
|
||
}
|
||
|
||
let expires_at = issued_at
|
||
.checked_add(Duration::seconds(
|
||
i64::try_from(config.access_token_ttl_seconds()).map_err(|_| {
|
||
JwtError::InvalidConfig("JWT access token 过期时间超出 i64 上限")
|
||
})?,
|
||
))
|
||
.ok_or(JwtError::InvalidConfig("JWT 过期时间计算溢出"))?;
|
||
|
||
let expires_at_unix = expires_at.unix_timestamp();
|
||
if expires_at_unix <= issued_at_unix {
|
||
return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"));
|
||
}
|
||
|
||
let claims = Self {
|
||
iss: config.issuer().to_string(),
|
||
sub: user_id,
|
||
sid: session_id,
|
||
provider: input.provider,
|
||
roles,
|
||
ver: input.token_version,
|
||
phone_verified: input.phone_verified,
|
||
binding_status: input.binding_status,
|
||
display_name,
|
||
iat: issued_at_unix as u64,
|
||
exp: expires_at_unix as u64,
|
||
};
|
||
|
||
claims.validate_for_config(config)?;
|
||
Ok(claims)
|
||
}
|
||
|
||
pub fn user_id(&self) -> &str {
|
||
&self.sub
|
||
}
|
||
|
||
pub fn session_id(&self) -> &str {
|
||
&self.sid
|
||
}
|
||
|
||
pub fn token_version(&self) -> u64 {
|
||
self.ver
|
||
}
|
||
|
||
pub fn validate_for_config(&self, config: &JwtConfig) -> Result<(), JwtError> {
|
||
if self.iss.trim() != config.issuer() {
|
||
return Err(JwtError::InvalidClaims("JWT iss 与当前配置不一致"));
|
||
}
|
||
|
||
normalize_required_field(self.sub.clone(), "JWT sub 不能为空")?;
|
||
normalize_required_field(self.sid.clone(), "JWT sid 不能为空")?;
|
||
normalize_roles(self.roles.clone())?;
|
||
|
||
if self.exp <= self.iat {
|
||
return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"));
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
pub fn sign_access_token(
|
||
claims: &AccessTokenClaims,
|
||
config: &JwtConfig,
|
||
) -> Result<String, JwtError> {
|
||
claims.validate_for_config(config)?;
|
||
|
||
let header = Header {
|
||
alg: ACCESS_TOKEN_ALGORITHM,
|
||
typ: Some("JWT".to_string()),
|
||
..Header::default()
|
||
};
|
||
|
||
encode(
|
||
&header,
|
||
claims,
|
||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||
)
|
||
.map_err(|error| JwtError::SignFailed(format!("JWT 签发失败:{error}")))
|
||
}
|
||
|
||
pub fn verify_access_token(token: &str, config: &JwtConfig) -> Result<AccessTokenClaims, JwtError> {
|
||
let token = token.trim();
|
||
if token.is_empty() {
|
||
return Err(JwtError::VerifyFailed("JWT 不能为空".to_string()));
|
||
}
|
||
|
||
let mut validation = Validation::new(ACCESS_TOKEN_ALGORITHM);
|
||
validation.required_spec_claims = HashSet::from([
|
||
"exp".to_string(),
|
||
"iat".to_string(),
|
||
"iss".to_string(),
|
||
"sub".to_string(),
|
||
]);
|
||
validation.set_issuer(&[config.issuer()]);
|
||
|
||
let decoded = decode::<AccessTokenClaims>(
|
||
token,
|
||
&DecodingKey::from_secret(config.secret.as_bytes()),
|
||
&validation,
|
||
)
|
||
.map_err(map_verify_error)?;
|
||
|
||
decoded.claims.validate_for_config(config)?;
|
||
Ok(decoded.claims)
|
||
}
|
||
|
||
pub fn read_refresh_session_token(
|
||
cookie_header: &str,
|
||
config: &RefreshCookieConfig,
|
||
) -> Option<String> {
|
||
let cookie_header = cookie_header.trim();
|
||
if cookie_header.is_empty() {
|
||
return None;
|
||
}
|
||
|
||
for entry in cookie_header.split(';') {
|
||
let entry = entry.trim();
|
||
if entry.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
let (raw_name, raw_value) = entry.split_once('=')?;
|
||
if raw_name.trim() != config.cookie_name() {
|
||
continue;
|
||
}
|
||
|
||
let raw_value = raw_value.trim();
|
||
if raw_value.is_empty() {
|
||
return None;
|
||
}
|
||
|
||
return urlencoding::decode(raw_value)
|
||
.ok()
|
||
.map(|decoded| decoded.into_owned());
|
||
}
|
||
|
||
None
|
||
}
|
||
|
||
pub async fn hash_password(password: &str) -> Result<String, PasswordHashError> {
|
||
let salt = SaltString::generate(&mut OsRng);
|
||
Argon2::default()
|
||
.hash_password(password.as_bytes(), &salt)
|
||
.map(|hash| hash.to_string())
|
||
.map_err(|error| PasswordHashError::HashFailed(format!("密码哈希失败:{error}")))
|
||
}
|
||
|
||
pub async fn verify_password(
|
||
password_hash: &str,
|
||
password: &str,
|
||
) -> Result<bool, PasswordHashError> {
|
||
let parsed_hash = PasswordHash::new(password_hash)
|
||
.map_err(|error| PasswordHashError::VerifyFailed(format!("密码哈希格式非法:{error}")))?;
|
||
|
||
Ok(Argon2::default()
|
||
.verify_password(password.as_bytes(), &parsed_hash)
|
||
.is_ok())
|
||
}
|
||
|
||
pub fn create_refresh_session_token() -> String {
|
||
new_uuid_simple_string()
|
||
}
|
||
|
||
pub fn hash_refresh_session_token(token: &str) -> String {
|
||
let mut hasher = Sha256::new();
|
||
hasher.update(token.as_bytes());
|
||
format!("{:x}", hasher.finalize())
|
||
}
|
||
|
||
pub fn build_refresh_session_set_cookie(token: &str, config: &RefreshCookieConfig) -> String {
|
||
let mut parts = vec![
|
||
format!(
|
||
"{}={}",
|
||
config.cookie_name(),
|
||
urlencoding::encode(token).into_owned()
|
||
),
|
||
format!("Path={}", config.cookie_path()),
|
||
"HttpOnly".to_string(),
|
||
format!("SameSite={}", config.cookie_same_site().as_str()),
|
||
format!(
|
||
"Max-Age={}",
|
||
u64::from(config.refresh_session_ttl_days()) * 24 * 60 * 60
|
||
),
|
||
];
|
||
|
||
if config.cookie_secure() {
|
||
parts.push("Secure".to_string());
|
||
}
|
||
|
||
parts.join("; ")
|
||
}
|
||
|
||
pub fn build_refresh_session_clear_cookie(config: &RefreshCookieConfig) -> String {
|
||
let mut parts = vec![
|
||
format!("{}=", config.cookie_name()),
|
||
format!("Path={}", config.cookie_path()),
|
||
"HttpOnly".to_string(),
|
||
format!("SameSite={}", config.cookie_same_site().as_str()),
|
||
"Max-Age=0".to_string(),
|
||
];
|
||
|
||
if config.cookie_secure() {
|
||
parts.push("Secure".to_string());
|
||
}
|
||
|
||
parts.join("; ")
|
||
}
|
||
|
||
impl fmt::Display for JwtError {
|
||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||
match self {
|
||
Self::InvalidConfig(message) | Self::InvalidClaims(message) => f.write_str(message),
|
||
Self::SignFailed(message) | Self::VerifyFailed(message) => f.write_str(message),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Error for JwtError {}
|
||
|
||
impl fmt::Display for RefreshCookieError {
|
||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||
match self {
|
||
Self::InvalidConfig(message) => f.write_str(message),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Error for RefreshCookieError {}
|
||
|
||
impl fmt::Display for PasswordHashError {
|
||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||
match self {
|
||
Self::HashFailed(message) | Self::VerifyFailed(message) => f.write_str(message),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Error for PasswordHashError {}
|
||
|
||
fn normalize_required_field(
|
||
value: String,
|
||
error_message: &'static str,
|
||
) -> Result<String, JwtError> {
|
||
normalize_required_string(&value).ok_or(JwtError::InvalidClaims(error_message))
|
||
}
|
||
|
||
fn normalize_optional_field(value: Option<String>) -> Option<String> {
|
||
normalize_optional_string(value)
|
||
}
|
||
|
||
fn normalize_roles(roles: Vec<String>) -> Result<Vec<String>, JwtError> {
|
||
let roles = roles
|
||
.into_iter()
|
||
.map(|role| role.trim().to_string())
|
||
.filter(|role| !role.is_empty())
|
||
.collect::<Vec<_>>();
|
||
|
||
if roles.is_empty() {
|
||
return Err(JwtError::InvalidClaims("JWT roles 至少包含一个角色"));
|
||
}
|
||
|
||
Ok(roles)
|
||
}
|
||
|
||
fn map_verify_error(error: jsonwebtoken::errors::Error) -> JwtError {
|
||
let message = match error.kind() {
|
||
ErrorKind::ExpiredSignature => "JWT 已过期".to_string(),
|
||
ErrorKind::InvalidIssuer => "JWT 发行者不匹配".to_string(),
|
||
ErrorKind::InvalidSignature => "JWT 签名无效".to_string(),
|
||
ErrorKind::InvalidAlgorithm => "JWT 算法不匹配".to_string(),
|
||
ErrorKind::InvalidToken => "JWT 非法".to_string(),
|
||
ErrorKind::ImmatureSignature => "JWT 尚未生效".to_string(),
|
||
ErrorKind::MissingRequiredClaim(claim) => format!("JWT 缺少必填字段:{claim}"),
|
||
_ => format!("JWT 校验失败:{error}"),
|
||
};
|
||
|
||
JwtError::VerifyFailed(message)
|
||
}
|
||
|
||
fn build_sms_provider_out_id(scene: &str, national_phone_number: &str) -> String {
|
||
let phone_suffix = national_phone_number
|
||
.chars()
|
||
.rev()
|
||
.take(4)
|
||
.collect::<String>()
|
||
.chars()
|
||
.rev()
|
||
.collect::<String>();
|
||
format!("{scene}_{}_{}", phone_suffix, new_uuid_simple_string())
|
||
}
|
||
|
||
fn build_aliyun_sms_url(endpoint: &str) -> Result<String, SmsProviderError> {
|
||
let endpoint = endpoint
|
||
.trim()
|
||
.trim_start_matches("https://")
|
||
.trim_start_matches("http://")
|
||
.trim_matches('/');
|
||
if endpoint.is_empty() {
|
||
return Err(SmsProviderError::InvalidConfig(
|
||
"阿里云短信 endpoint 不能为空".to_string(),
|
||
));
|
||
}
|
||
Ok(format!("https://{endpoint}/"))
|
||
}
|
||
|
||
fn current_aliyun_timestamp() -> String {
|
||
OffsetDateTime::now_utc()
|
||
.format(&time::format_description::well_known::Rfc3339)
|
||
.unwrap_or_else(|_| "1970-01-01T00:00:00Z".to_string())
|
||
}
|
||
|
||
fn canonicalize_aliyun_rpc_params(params: &BTreeMap<String, String>) -> String {
|
||
params
|
||
.iter()
|
||
.filter(|(key, _)| key.as_str() != "Signature")
|
||
.map(|(key, value)| {
|
||
format!(
|
||
"{}={}",
|
||
aliyun_percent_encode(key),
|
||
aliyun_percent_encode(value)
|
||
)
|
||
})
|
||
.collect::<Vec<_>>()
|
||
.join("&")
|
||
}
|
||
|
||
fn aliyun_percent_encode(value: &str) -> String {
|
||
urlencoding::encode(value)
|
||
.into_owned()
|
||
.replace('+', "%20")
|
||
.replace('*', "%2A")
|
||
.replace("%7E", "~")
|
||
}
|
||
|
||
async fn parse_aliyun_json_response(
|
||
response: reqwest::Response,
|
||
fallback_message: &str,
|
||
) -> Result<AliyunSendSmsVerifyCodeResponse, SmsProviderError> {
|
||
let status = response.status();
|
||
let body = response
|
||
.text()
|
||
.await
|
||
.map_err(|error| SmsProviderError::Upstream(format!("{fallback_message}:{error}")))?;
|
||
let payload =
|
||
serde_json::from_str::<AliyunSendSmsVerifyCodeResponse>(&body).map_err(|error| {
|
||
SmsProviderError::Upstream(format!("{fallback_message}:响应解析失败:{error}"))
|
||
})?;
|
||
if status.is_client_error() || status.is_server_error() {
|
||
return Err(map_http_status_to_sms_provider_error(
|
||
fallback_message,
|
||
status,
|
||
serde_json::from_str::<Value>(&body).ok(),
|
||
));
|
||
}
|
||
|
||
Ok(payload)
|
||
}
|
||
|
||
async fn parse_aliyun_json_response_for_verify(
|
||
response: reqwest::Response,
|
||
) -> Result<AliyunCheckSmsVerifyCodeResponse, SmsProviderError> {
|
||
let status = response.status();
|
||
let body = response
|
||
.text()
|
||
.await
|
||
.map_err(|error| SmsProviderError::Upstream(format!("验证码校验失败:{error}")))?;
|
||
let payload =
|
||
serde_json::from_str::<AliyunCheckSmsVerifyCodeResponse>(&body).map_err(|error| {
|
||
SmsProviderError::Upstream(format!("验证码校验失败:响应解析失败:{error}"))
|
||
})?;
|
||
if status.is_client_error() || status.is_server_error() {
|
||
return Err(map_http_status_to_sms_provider_error(
|
||
"验证码校验失败",
|
||
status,
|
||
serde_json::from_str::<Value>(&body).ok(),
|
||
));
|
||
}
|
||
|
||
Ok(payload)
|
||
}
|
||
|
||
fn map_http_status_to_sms_provider_error(
|
||
fallback_message: &str,
|
||
status: StatusCode,
|
||
payload: Option<Value>,
|
||
) -> SmsProviderError {
|
||
let provider_message = payload
|
||
.as_ref()
|
||
.and_then(|value| value.get("Message").and_then(Value::as_str))
|
||
.unwrap_or_default();
|
||
let provider_code = payload
|
||
.as_ref()
|
||
.and_then(|value| value.get("Code").and_then(Value::as_str))
|
||
.unwrap_or_default();
|
||
|
||
if status.is_client_error() {
|
||
return map_aliyun_provider_error(
|
||
fallback_message,
|
||
Some(provider_message.to_string()),
|
||
Some(provider_code.to_string()),
|
||
);
|
||
}
|
||
|
||
SmsProviderError::Upstream(build_provider_error_message(
|
||
fallback_message,
|
||
provider_message,
|
||
))
|
||
}
|
||
|
||
fn map_aliyun_provider_error(
|
||
fallback_message: &str,
|
||
provider_message: Option<String>,
|
||
provider_code: Option<String>,
|
||
) -> SmsProviderError {
|
||
let provider_message = provider_message.unwrap_or_default();
|
||
let provider_code = provider_code.unwrap_or_default();
|
||
let normalized_code = provider_code.trim().to_ascii_uppercase();
|
||
|
||
if normalized_code.contains("VERIFY")
|
||
|| normalized_code.contains("CODE")
|
||
|| normalized_code.contains("CHECK")
|
||
{
|
||
return SmsProviderError::InvalidVerifyCode;
|
||
}
|
||
|
||
if normalized_code.contains("MOBILE")
|
||
|| normalized_code.contains("PHONE")
|
||
|| normalized_code.contains("SIGN")
|
||
|| normalized_code.contains("TEMPLATE")
|
||
|| normalized_code.contains("ACCESSKEY")
|
||
{
|
||
return SmsProviderError::InvalidConfig(build_provider_error_message(
|
||
fallback_message,
|
||
&provider_message,
|
||
));
|
||
}
|
||
|
||
SmsProviderError::Upstream(build_provider_error_message(
|
||
fallback_message,
|
||
&provider_message,
|
||
))
|
||
}
|
||
|
||
/// Combines a local error prefix with the provider's message.
///
/// Returns just `prefix` when the trimmed provider message is empty; otherwise
/// returns the prefix followed by the trimmed message.
fn build_provider_error_message(prefix: &str, provider_message: &str) -> String {
    match provider_message.trim() {
        "" => prefix.to_string(),
        detail => format!("{prefix}:{detail}"),
    }
}
|
||
|
||
/// Masks a phone number for logging.
///
/// At most the first 3 and last 4 characters stay visible.
/// - Inputs of up to 4 characters are fully masked (an empty input yields `"*"`).
/// - Longer inputs mask at least `max(len - 7, ceil(len / 3))` characters, and the
///   output keeps the input's character count.
///
/// A standard 11-digit number renders as `138****8000`, as before. Shorter inputs
/// (5–10 characters), which previously had every digit revealed or nearly so,
/// now always hide a middle section.
fn mask_phone_number(phone_number: &str) -> String {
    let chars: Vec<char> = phone_number.chars().collect();
    let len = chars.len();
    if len <= 4 {
        return "*".repeat(len.max(1));
    }

    // Hide at least a third of the number; for long numbers the 3 + 4 visible cap
    // dominates.
    let mask_len = len.saturating_sub(7).max(len.div_ceil(3));
    let visible = len - mask_len;
    // visible <= 7, so prefix <= 3 and suffix <= 4.
    let prefix_len = (visible / 2).min(3);
    let suffix_len = visible - prefix_len;

    let mut masked = String::with_capacity(phone_number.len());
    masked.extend(&chars[..prefix_len]);
    masked.push_str(&"*".repeat(mask_len));
    masked.extend(&chars[len - suffix_len..]);
    masked
}
|
||
|
||
impl fmt::Display for SmsProviderError {
|
||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||
match self {
|
||
Self::InvalidConfig(message) | Self::Upstream(message) => f.write_str(message),
|
||
Self::InvalidVerifyCode => f.write_str("验证码错误"),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Error for SmsProviderError {}
|
||
|
||
impl fmt::Display for WechatProviderError {
|
||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||
match self {
|
||
Self::Disabled => f.write_str("微信登录暂未启用"),
|
||
Self::MissingCode => f.write_str("缺少微信授权 code"),
|
||
Self::InvalidConfig(message)
|
||
| Self::InvalidCallback(message)
|
||
| Self::RequestFailed(message)
|
||
| Self::DeserializeFailed(message)
|
||
| Self::Upstream(message)
|
||
| Self::MissingProfile(message) => f.write_str(message),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Error for WechatProviderError {}
|
||
|
||
impl JwtError {
|
||
pub fn kind(&self) -> AuthPlatformErrorKind {
|
||
match self {
|
||
Self::InvalidConfig(_) => AuthPlatformErrorKind::InvalidConfig,
|
||
Self::InvalidClaims(_) => AuthPlatformErrorKind::InvalidClaims,
|
||
Self::SignFailed(_) => AuthPlatformErrorKind::SignFailed,
|
||
Self::VerifyFailed(_) => AuthPlatformErrorKind::VerifyFailed,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl RefreshCookieError {
|
||
pub fn kind(&self) -> AuthPlatformErrorKind {
|
||
match self {
|
||
Self::InvalidConfig(_) => AuthPlatformErrorKind::CookieConfig,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl PasswordHashError {
|
||
pub fn kind(&self) -> AuthPlatformErrorKind {
|
||
match self {
|
||
Self::HashFailed(_) => AuthPlatformErrorKind::HashFailed,
|
||
Self::VerifyFailed(_) => AuthPlatformErrorKind::VerifyFailed,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl SmsProviderError {
|
||
pub fn kind(&self) -> AuthPlatformErrorKind {
|
||
match self {
|
||
Self::InvalidConfig(_) => AuthPlatformErrorKind::InvalidConfig,
|
||
Self::InvalidVerifyCode => AuthPlatformErrorKind::InvalidVerifyCode,
|
||
Self::Upstream(_) => AuthPlatformErrorKind::Upstream,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl WechatProviderError {
|
||
pub fn kind(&self) -> AuthPlatformErrorKind {
|
||
match self {
|
||
Self::Disabled => AuthPlatformErrorKind::Disabled,
|
||
Self::MissingCode => AuthPlatformErrorKind::MissingCode,
|
||
Self::InvalidConfig(_) => AuthPlatformErrorKind::InvalidConfig,
|
||
Self::InvalidCallback(_) => AuthPlatformErrorKind::InvalidCallback,
|
||
Self::RequestFailed(_) => AuthPlatformErrorKind::RequestFailed,
|
||
Self::DeserializeFailed(_) => AuthPlatformErrorKind::DeserializeFailed,
|
||
Self::Upstream(_) => AuthPlatformErrorKind::Upstream,
|
||
Self::MissingProfile(_) => AuthPlatformErrorKind::MissingProfile,
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
|
||
#[test]
|
||
fn auth_platform_error_kind_is_stable_for_adapter_mapping() {
|
||
assert_eq!(
|
||
JwtError::InvalidClaims("JWT roles 至少包含一个角色").kind(),
|
||
AuthPlatformErrorKind::InvalidClaims
|
||
);
|
||
assert_eq!(
|
||
PasswordHashError::VerifyFailed("密码校验失败".to_string()).kind(),
|
||
AuthPlatformErrorKind::VerifyFailed
|
||
);
|
||
assert_eq!(
|
||
SmsProviderError::InvalidVerifyCode.kind(),
|
||
AuthPlatformErrorKind::InvalidVerifyCode
|
||
);
|
||
assert_eq!(
|
||
WechatProviderError::MissingCode.kind(),
|
||
AuthPlatformErrorKind::MissingCode
|
||
);
|
||
}
|
||
|
||
#[test]
|
||
fn mock_wechat_provider_builds_callback_authorization_url() {
|
||
let provider = WechatProvider::new(WechatAuthConfig::new(
|
||
true,
|
||
"mock".to_string(),
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
DEFAULT_WECHAT_AUTHORIZE_ENDPOINT.to_string(),
|
||
DEFAULT_WECHAT_ACCESS_TOKEN_ENDPOINT.to_string(),
|
||
DEFAULT_WECHAT_USER_INFO_ENDPOINT.to_string(),
|
||
DEFAULT_WECHAT_JS_CODE_SESSION_ENDPOINT.to_string(),
|
||
"wx-user-001".to_string(),
|
||
Some("wx-union-001".to_string()),
|
||
"微信测试用户".to_string(),
|
||
Some("https://example.test/avatar.png".to_string()),
|
||
));
|
||
|
||
let authorization_url = provider
|
||
.build_authorization_url(
|
||
"http://127.0.0.1:3000/api/auth/wechat/callback",
|
||
"state_001",
|
||
&WechatAuthScene::Desktop,
|
||
)
|
||
.expect("mock authorization url should build");
|
||
|
||
assert!(authorization_url.contains("mock_code=wx-mock-code"));
|
||
assert!(authorization_url.contains("state=state_001"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn mock_wechat_provider_resolves_identity_profile() {
|
||
let provider = WechatProvider::new(WechatAuthConfig::new(
|
||
true,
|
||
"mock".to_string(),
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
DEFAULT_WECHAT_AUTHORIZE_ENDPOINT.to_string(),
|
||
DEFAULT_WECHAT_ACCESS_TOKEN_ENDPOINT.to_string(),
|
||
DEFAULT_WECHAT_USER_INFO_ENDPOINT.to_string(),
|
||
DEFAULT_WECHAT_JS_CODE_SESSION_ENDPOINT.to_string(),
|
||
"wx-user-001".to_string(),
|
||
Some("wx-union-001".to_string()),
|
||
"微信测试用户".to_string(),
|
||
None,
|
||
));
|
||
|
||
let profile = provider
|
||
.resolve_callback_profile(None, Some("wx-code-001"))
|
||
.await
|
||
.expect("mock profile should resolve");
|
||
|
||
assert_eq!(profile.provider_uid, "wx-code-001");
|
||
assert_eq!(profile.provider_union_id.as_deref(), Some("wx-union-001"));
|
||
assert_eq!(profile.display_name.as_deref(), Some("微信测试用户"));
|
||
}
|
||
|
||
fn build_jwt_config() -> JwtConfig {
|
||
JwtConfig::new(
|
||
"https://auth.genarrative.local".to_string(),
|
||
"genarrative-dev-secret".to_string(),
|
||
DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
|
||
)
|
||
.expect("jwt config should be valid")
|
||
}
|
||
|
||
fn build_claims_input() -> AccessTokenClaimsInput {
|
||
AccessTokenClaimsInput {
|
||
user_id: "usr_123".to_string(),
|
||
session_id: "sess_456".to_string(),
|
||
provider: AuthProvider::Wechat,
|
||
roles: vec!["user".to_string()],
|
||
token_version: 3,
|
||
phone_verified: false,
|
||
binding_status: BindingStatus::PendingBindPhone,
|
||
display_name: Some("微信旅人".to_string()),
|
||
}
|
||
}
|
||
|
||
fn build_refresh_cookie_config() -> RefreshCookieConfig {
|
||
RefreshCookieConfig::new(
|
||
DEFAULT_REFRESH_COOKIE_NAME.to_string(),
|
||
DEFAULT_REFRESH_COOKIE_PATH.to_string(),
|
||
false,
|
||
RefreshCookieSameSite::Lax,
|
||
DEFAULT_REFRESH_SESSION_TTL_DAYS,
|
||
)
|
||
.expect("refresh cookie config should be valid")
|
||
}
|
||
|
||
fn build_mock_sms_config() -> SmsAuthConfig {
|
||
SmsAuthConfig::new(
|
||
SmsAuthProviderKind::Mock,
|
||
DEFAULT_SMS_ENDPOINT.to_string(),
|
||
None,
|
||
None,
|
||
String::new(),
|
||
String::new(),
|
||
DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string(),
|
||
DEFAULT_SMS_COUNTRY_CODE.to_string(),
|
||
None,
|
||
DEFAULT_SMS_CODE_LENGTH,
|
||
DEFAULT_SMS_CODE_TYPE,
|
||
DEFAULT_SMS_VALID_TIME_SECONDS,
|
||
DEFAULT_SMS_INTERVAL_SECONDS,
|
||
DEFAULT_SMS_DUPLICATE_POLICY,
|
||
DEFAULT_SMS_CASE_AUTH_POLICY,
|
||
false,
|
||
DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(),
|
||
)
|
||
.expect("mock sms config should be valid")
|
||
}
|
||
|
||
#[test]
|
||
fn round_trip_sign_and_verify_access_token() {
|
||
let config = build_jwt_config();
|
||
let claims =
|
||
AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
|
||
.expect("claims should build");
|
||
|
||
let token = sign_access_token(&claims, &config).expect("token should sign");
|
||
let verified = verify_access_token(&token, &config).expect("token should verify");
|
||
|
||
assert_eq!(verified, claims);
|
||
assert_eq!(verified.user_id(), "usr_123");
|
||
assert_eq!(verified.session_id(), "sess_456");
|
||
assert_eq!(verified.token_version(), 3);
|
||
}
|
||
|
||
#[test]
|
||
fn verify_rejects_invalid_issuer() {
|
||
let config = build_jwt_config();
|
||
let claims =
|
||
AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
|
||
.expect("claims should build");
|
||
let token = sign_access_token(&claims, &config).expect("token should sign");
|
||
let other_config = JwtConfig::new(
|
||
"https://auth.other.local".to_string(),
|
||
"genarrative-dev-secret".to_string(),
|
||
DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
|
||
)
|
||
.expect("other config should be valid");
|
||
|
||
let error = verify_access_token(&token, &other_config).expect_err("issuer should mismatch");
|
||
|
||
assert_eq!(
|
||
error,
|
||
JwtError::VerifyFailed("JWT 发行者不匹配".to_string())
|
||
);
|
||
}
|
||
|
||
#[test]
|
||
fn build_claims_rejects_empty_roles() {
|
||
let error = AccessTokenClaims::from_input(
|
||
AccessTokenClaimsInput {
|
||
roles: Vec::new(),
|
||
..build_claims_input()
|
||
},
|
||
&build_jwt_config(),
|
||
OffsetDateTime::now_utc(),
|
||
)
|
||
.expect_err("empty roles should be rejected");
|
||
|
||
assert_eq!(error, JwtError::InvalidClaims("JWT roles 至少包含一个角色"));
|
||
}
|
||
|
||
#[test]
|
||
fn read_refresh_session_token_returns_matching_cookie() {
|
||
let token = read_refresh_session_token(
|
||
"theme=dark; genarrative_refresh_session=refresh-token-01; locale=zh-CN",
|
||
&build_refresh_cookie_config(),
|
||
);
|
||
|
||
assert_eq!(token.as_deref(), Some("refresh-token-01"));
|
||
}
|
||
|
||
#[test]
|
||
fn read_refresh_session_token_decodes_urlencoded_value() {
|
||
let token = read_refresh_session_token(
|
||
"genarrative_refresh_session=refresh%2Ftoken%3D01",
|
||
&build_refresh_cookie_config(),
|
||
);
|
||
|
||
assert_eq!(token.as_deref(), Some("refresh/token=01"));
|
||
}
|
||
|
||
#[test]
|
||
fn read_refresh_session_token_returns_none_when_missing() {
|
||
let token =
|
||
read_refresh_session_token("theme=dark; locale=zh-CN", &build_refresh_cookie_config());
|
||
|
||
assert!(token.is_none());
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn hash_and_verify_password_round_trip() {
|
||
let password_hash = hash_password("secret123")
|
||
.await
|
||
.expect("password hash should build");
|
||
|
||
let is_valid = verify_password(&password_hash, "secret123")
|
||
.await
|
||
.expect("password hash should verify");
|
||
|
||
assert!(is_valid);
|
||
}
|
||
|
||
#[test]
|
||
fn build_refresh_session_cookie_respects_config() {
|
||
let cookie =
|
||
build_refresh_session_set_cookie("refresh/token=01", &build_refresh_cookie_config());
|
||
|
||
assert!(cookie.contains("genarrative_refresh_session=refresh%2Ftoken%3D01"));
|
||
assert!(cookie.contains("Path=/api/auth"));
|
||
assert!(cookie.contains("HttpOnly"));
|
||
assert!(cookie.contains("SameSite=Lax"));
|
||
assert!(cookie.contains("Max-Age=2592000"));
|
||
}
|
||
|
||
#[test]
fn hash_refresh_session_token_matches_sha256_hex() {
    // Pinned lowercase-hex digest expected for the token "refresh-token-01".
    const EXPECTED_DIGEST_HEX: &str =
        "9fab76f9100ec6c151c8caa0c42ab10e10fbc7618f15e24cf3dffc93e19c4c4e";

    let digest = hash_refresh_session_token("refresh-token-01");
    assert_eq!(digest, EXPECTED_DIGEST_HEX);
}
|
||
|
||
#[test]
fn build_refresh_session_clear_cookie_respects_config() {
    let clearing_cookie = build_refresh_session_clear_cookie(&build_refresh_cookie_config());

    // The clearing cookie must target the same name/path/flags as the issuing cookie,
    // with `Max-Age=0` so the client discards it.
    let required_fragments = [
        "genarrative_refresh_session=",
        "Path=/api/auth",
        "HttpOnly",
        "SameSite=Lax",
        "Max-Age=0",
    ];
    required_fragments
        .iter()
        .for_each(|fragment| assert!(clearing_cookie.contains(fragment)));
}
|
||
|
||
#[test]
fn sms_auth_provider_kind_parses_supported_values() {
    // Each raw provider name paired with the kind it should parse to; unknown names yield None.
    let cases = [
        ("mock", Some(SmsAuthProviderKind::Mock)),
        ("aliyun", Some(SmsAuthProviderKind::Aliyun)),
        ("other", None),
    ];

    for (raw_name, expected_kind) in cases {
        assert_eq!(SmsAuthProviderKind::parse(raw_name), expected_kind);
    }
}
|
||
|
||
#[tokio::test]
|
||
async fn mock_sms_provider_sends_and_verifies_code() {
|
||
let provider =
|
||
SmsAuthProvider::new(build_mock_sms_config()).expect("provider should build");
|
||
let send_result = provider
|
||
.send_code(SmsSendCodeRequest {
|
||
national_phone_number: "13800138000".to_string(),
|
||
scene: "login".to_string(),
|
||
})
|
||
.await
|
||
.expect("send code should succeed");
|
||
|
||
assert_eq!(send_result.cooldown_seconds, DEFAULT_SMS_INTERVAL_SECONDS);
|
||
assert_eq!(
|
||
send_result.expires_in_seconds,
|
||
DEFAULT_SMS_VALID_TIME_SECONDS
|
||
);
|
||
assert_eq!(
|
||
send_result.provider_request_id.as_deref(),
|
||
Some("mock-request-id")
|
||
);
|
||
assert!(send_result.provider_out_id.is_some());
|
||
|
||
provider
|
||
.verify_code(SmsVerifyCodeRequest {
|
||
national_phone_number: "13800138000".to_string(),
|
||
verify_code: DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(),
|
||
provider_out_id: send_result.provider_out_id,
|
||
})
|
||
.await
|
||
.expect("verify code should succeed");
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn mock_sms_provider_rejects_wrong_code() {
|
||
let provider =
|
||
SmsAuthProvider::new(build_mock_sms_config()).expect("provider should build");
|
||
|
||
let error = provider
|
||
.verify_code(SmsVerifyCodeRequest {
|
||
national_phone_number: "13800138000".to_string(),
|
||
verify_code: "000000".to_string(),
|
||
provider_out_id: None,
|
||
})
|
||
.await
|
||
.expect_err("wrong verify code should fail");
|
||
|
||
assert_eq!(error, SmsProviderError::InvalidVerifyCode);
|
||
}
|
||
|
||
#[test]
fn aliyun_sms_config_requires_access_key() {
    // Aliyun provider configured with defaults but no credentials. The two `None` values
    // right after the endpoint are presumably the AccessKey id/secret that the expected
    // error ("Aliyun SMS AccessKey not configured") refers to. TODO confirm against
    // `SmsAuthConfig::new`'s parameter list.
    let outcome = SmsAuthConfig::new(
        SmsAuthProviderKind::Aliyun,
        DEFAULT_SMS_ENDPOINT.to_string(),
        None,
        None,
        "测试签名".to_string(),
        "SMS_001".to_string(),
        DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string(),
        DEFAULT_SMS_COUNTRY_CODE.to_string(),
        None,
        DEFAULT_SMS_CODE_LENGTH,
        DEFAULT_SMS_CODE_TYPE,
        DEFAULT_SMS_VALID_TIME_SECONDS,
        DEFAULT_SMS_INTERVAL_SECONDS,
        DEFAULT_SMS_DUPLICATE_POLICY,
        DEFAULT_SMS_CASE_AUTH_POLICY,
        false,
        DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(),
    );
    let expected = SmsProviderError::InvalidConfig("阿里云短信 AccessKey 未配置".to_string());

    let error = outcome.expect_err("aliyun config without access key should fail");
    assert_eq!(error, expected);
}
|
||
|
||
/// Checks that `canonicalize_aliyun_rpc_params` emits `key=value` pairs joined by `&`,
/// ordered by key, with reserved characters in values percent-encoded.
#[test]
fn canonicalize_aliyun_rpc_params_keeps_sorted_percent_encoded_order() {
    // Inserted out of key order on purpose; `BTreeMap` iterates its keys in sorted order.
    let mut params = BTreeMap::new();
    params.insert(
        "TemplateParam".to_string(),
        "{\"code\":\"##code##\"}".to_string(),
    );
    params.insert("Action".to_string(), "SendSmsVerifyCode".to_string());
    params.insert("PhoneNumber".to_string(), "13800138000".to_string());

    // `{`, `"`, `:`, `#` and `}` in the JSON value are expected as `%7B`, `%22`, `%3A`,
    // `%23` and `%7D` respectively.
    assert_eq!(
        canonicalize_aliyun_rpc_params(&params),
        "Action=SendSmsVerifyCode&PhoneNumber=13800138000&TemplateParam=%7B%22code%22%3A%22%23%23code%23%23%22%7D"
    );
}
|
||
|
||
#[test]
fn aliyun_send_response_deserializes_pascal_case_fields() {
    // Sample send response using Aliyun's PascalCase field names, including a nested `Model`.
    const RAW_RESPONSE: &str = r#"{
        "Code": "OK",
        "Message": "成功",
        "RequestId": "req_123",
        "Success": true,
        "Model": {
            "BizId": "biz_456",
            "OutId": "out_789",
            "RequestId": "req_model_001"
        }
    }"#;

    let payload = serde_json::from_str::<AliyunSendSmsVerifyCodeResponse>(RAW_RESPONSE)
        .expect("aliyun send response should deserialize");

    // Top-level fields.
    assert_eq!(payload.code.as_deref(), Some("OK"));
    assert_eq!(payload.message.as_deref(), Some("成功"));
    assert_eq!(payload.request_id.as_deref(), Some("req_123"));
    assert_eq!(payload.success, Some(true));

    // Nested `Model`: its own `RequestId` must not be confused with the top-level one.
    let model = payload.model.as_ref();
    assert_eq!(model.and_then(|m| m.out_id.as_deref()), Some("out_789"));
    assert_eq!(
        model.and_then(|m| m.request_id.as_deref()),
        Some("req_model_001")
    );
}
|
||
|
||
#[test]
fn aliyun_verify_response_deserializes_pascal_case_fields() {
    // Sample check-code response using Aliyun's PascalCase field names.
    const RAW_RESPONSE: &str = r#"{
        "Code": "OK",
        "Message": "成功",
        "Success": true,
        "Model": {
            "OutId": "out_789",
            "VerifyResult": "PASS"
        }
    }"#;

    let payload = serde_json::from_str::<AliyunCheckSmsVerifyCodeResponse>(RAW_RESPONSE)
        .expect("aliyun verify response should deserialize");

    assert_eq!(payload.code.as_deref(), Some("OK"));
    assert_eq!(payload.message.as_deref(), Some("成功"));
    assert_eq!(payload.success, Some(true));

    let verify_result = payload
        .model
        .as_ref()
        .and_then(|m| m.verify_result.as_deref());
    assert_eq!(verify_result, Some("PASS"));
}
|
||
}
|