Files
Genarrative/server-rs/crates/platform-auth/src/lib.rs

2270 lines
75 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::{
collections::{BTreeMap, HashSet},
error::Error,
fmt,
};
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use hmac::{Hmac, Mac};
use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::ErrorKind,
};
use rand_core::OsRng;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha1::Sha1;
use sha2::{Digest, Sha256};
use shared_kernel::{new_uuid_simple_string, normalize_optional_string, normalize_required_string};
use time::{Duration, OffsetDateTime};
use tracing::{info, warn};
use url::Url;
/// Signing algorithm used for access tokens (HMAC-SHA256).
pub const ACCESS_TOKEN_ALGORITHM: Algorithm = Algorithm::HS256;
/// Default access-token lifetime in seconds (2 hours).
pub const DEFAULT_ACCESS_TOKEN_TTL_SECONDS: u64 = 2 * 60 * 60;
/// Default name of the refresh-session cookie.
pub const DEFAULT_REFRESH_COOKIE_NAME: &str = "genarrative_refresh_session";
/// Default path of the refresh-session cookie.
pub const DEFAULT_REFRESH_COOKIE_PATH: &str = "/api/auth";
/// Default refresh-session lifetime in days.
pub const DEFAULT_REFRESH_SESSION_TTL_DAYS: u32 = 30;
/// Aliyun SMS endpoint host used by `SmsAuthConfig::new` when the configured endpoint is blank.
pub const DEFAULT_SMS_ENDPOINT: &str = "dypnsapi.aliyuncs.com";
/// Country code used by `SmsAuthConfig::new` when the configured one is blank.
pub const DEFAULT_SMS_COUNTRY_CODE: &str = "86";
/// `TemplateParam` key that carries the `##code##` placeholder when none is configured.
pub const DEFAULT_SMS_TEMPLATE_PARAM_KEY: &str = "code";
/// Code accepted by the mock SMS provider when no mock code is configured.
pub const DEFAULT_SMS_MOCK_VERIFY_CODE: &str = "123456";
/// Default verification-code length (`SmsAuthConfig::new` accepts 4..=8).
pub const DEFAULT_SMS_CODE_LENGTH: u8 = 6;
/// Default Aliyun `CodeType` value (`SmsAuthConfig::new` accepts 1..=7).
pub const DEFAULT_SMS_CODE_TYPE: u8 = 1;
/// Default code validity in seconds (sent to Aliyun as `ValidTime`).
pub const DEFAULT_SMS_VALID_TIME_SECONDS: u64 = 300;
/// Default send interval in seconds (sent to Aliyun as `Interval`).
pub const DEFAULT_SMS_INTERVAL_SECONDS: u64 = 60;
/// Default Aliyun `DuplicatePolicy` value (`SmsAuthConfig::new` accepts 1..=2).
pub const DEFAULT_SMS_DUPLICATE_POLICY: u8 = 1;
/// Default Aliyun `CaseAuthPolicy` value (`SmsAuthConfig::new` accepts 1..=2).
pub const DEFAULT_SMS_CASE_AUTH_POLICY: u8 = 1;
/// Default WeChat web-login authorize endpoint (`qrconnect`).
pub const DEFAULT_WECHAT_AUTHORIZE_ENDPOINT: &str = "https://open.weixin.qq.com/connect/qrconnect";
/// Authorize endpoint used for the in-WeChat scene; not taken from configuration.
pub const DEFAULT_WECHAT_IN_APP_AUTHORIZE_ENDPOINT: &str =
"https://open.weixin.qq.com/connect/oauth2/authorize";
/// Default WeChat OAuth endpoint that exchanges a code for an access token.
pub const DEFAULT_WECHAT_ACCESS_TOKEN_ENDPOINT: &str =
"https://api.weixin.qq.com/sns/oauth2/access_token";
/// Default WeChat user-info endpoint.
pub const DEFAULT_WECHAT_USER_INFO_ENDPOINT: &str = "https://api.weixin.qq.com/sns/userinfo";
/// Default WeChat mini-program `jscode2session` endpoint.
pub const DEFAULT_WECHAT_JS_CODE_SESSION_ENDPOINT: &str =
"https://api.weixin.qq.com/sns/jscode2session";
/// HMAC-SHA1 signer for Aliyun RPC request signatures (`SignatureMethod=HMAC-SHA1`).
type HmacSha1 = Hmac<Sha1>;
/// Authentication provider. It is fixed to the enum agreed in the design doc so that
/// crates do not each re-invent the same string literals. Serialized in snake_case.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthProvider {
Password,
Phone,
Wechat,
}
/// Binding status. Only the minimal snapshot the JWT needs to carry is kept here, so the
/// full account-status enum is not leaked into the token. Serialized in snake_case.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BindingStatus {
Active,
PendingBindPhone,
}
/// Domain input for issuing an access token. It is kept separate from the final JWT
/// claims so business code never assembles `iat`/`exp`/`iss` by hand.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AccessTokenClaimsInput {
/// User id; must not be blank (becomes the JWT `sub`).
pub user_id: String,
/// Session id; must not be blank (becomes the JWT `sid`).
pub session_id: String,
pub provider: AuthProvider,
pub roles: Vec<String>,
pub token_version: u64,
pub phone_verified: bool,
pub binding_status: BindingStatus,
/// Optional display name; normalized before it is placed into the claims.
pub display_name: Option<String>,
}
/// The final JWT payload. Field names map directly to the claim names frozen in the
/// design doc.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AccessTokenClaims {
/// Issuer.
pub iss: String,
/// Subject: the user id.
pub sub: String,
/// Session id.
pub sid: String,
pub provider: AuthProvider,
pub roles: Vec<String>,
/// Token version (presumably copied from `AccessTokenClaimsInput::token_version` — confirm in `from_input`).
pub ver: u64,
pub phone_verified: bool,
pub binding_status: BindingStatus,
/// Omitted from the serialized payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub display_name: Option<String>,
/// Issued-at, Unix seconds.
pub iat: u64,
/// Expiry, Unix seconds.
pub exp: u64,
}
/// JWT settings shared by api-server and later modules, so the secret, issuer and TTL
/// are not scattered across call sites. Build a validated instance with `JwtConfig::new`.
///
/// `Debug` is implemented by hand (instead of derived) so that the HMAC signing secret is
/// never written into logs or panic messages through `{:?}`.
#[derive(Clone, PartialEq, Eq)]
pub struct JwtConfig {
    issuer: String,
    secret: String,
    access_token_ttl_seconds: u64,
}

impl fmt::Debug for JwtConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("JwtConfig")
            .field("issuer", &self.issuer)
            // Never print the signing key; anyone holding it can forge access tokens.
            .field("secret", &"<redacted>")
            .field("access_token_ttl_seconds", &self.access_token_ttl_seconds)
            .finish()
    }
}
/// SameSite attribute of the refresh cookie. It is fixed to an enum so that layers do not
/// pass around strings with inconsistent casing.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RefreshCookieSameSite {
Lax,
Strict,
None,
}
/// Refresh-cookie settings. They are centralized in platform-auth so that api-server does
/// not scatter cookie details. `RefreshCookieConfig::new` rejects a blank name or path and
/// a zero TTL.
/// NOTE(review): `new` does not check that `SameSite=None` comes with `cookie_secure ==
/// true`; browsers generally require `Secure` for `SameSite=None` — confirm whether this
/// should be validated.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RefreshCookieConfig {
cookie_name: String,
cookie_path: String,
cookie_secure: bool,
cookie_same_site: RefreshCookieSameSite,
/// Refresh-session lifetime in days; always > 0 once validated by `new`.
refresh_session_ttl_days: u32,
}
/// Which SMS verification backend to use.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SmsAuthProviderKind {
Mock,
Aliyun,
}
/// Validated SMS verification settings; construct through `SmsAuthConfig::new`.
/// `new` fills a blank endpoint, template-param key, country code and mock code with the
/// `DEFAULT_SMS_*` constants, and range-checks the numeric settings.
/// NOTE(review): the derived `Debug` prints `access_key_secret`; avoid `{:?}` in logs.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmsAuthConfig {
pub provider: SmsAuthProviderKind,
/// Endpoint host passed to `build_aliyun_sms_url`.
pub endpoint: String,
/// Aliyun AccessKey pair; both halves must be non-blank for the Aliyun provider.
pub access_key_id: Option<String>,
pub access_key_secret: Option<String>,
/// Aliyun `SignName`; required for the Aliyun provider.
pub sign_name: String,
/// Aliyun `TemplateCode`; required for the Aliyun provider.
pub template_code: String,
/// Key inside `TemplateParam` that carries the `##code##` placeholder.
pub template_param_key: String,
/// Aliyun `CountryCode`.
pub country_code: String,
/// Aliyun `SchemeName`; sent only when present.
pub scheme_name: Option<String>,
/// Aliyun `CodeLength`, validated to 4..=8.
pub code_length: u8,
/// Aliyun `CodeType`, validated to 1..=7.
pub code_type: u8,
/// Aliyun `ValidTime` in seconds; also reported back as `expires_in_seconds`.
pub valid_time_seconds: u64,
/// Aliyun `Interval` in seconds; also reported back as `cooldown_seconds`.
pub interval_seconds: u64,
/// Aliyun `DuplicatePolicy`, validated to 1..=2.
pub duplicate_policy: u8,
/// Aliyun `CaseAuthPolicy` used when checking codes, validated to 1..=2.
pub case_auth_policy: u8,
/// Aliyun `ReturnVerifyCode` flag.
pub return_verify_code: bool,
/// The only code the mock provider accepts.
pub mock_verify_code: String,
}
/// Input for sending a verification code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmsSendCodeRequest {
/// National phone number; the country code is sent separately from config. Trimmed before sending.
pub national_phone_number: String,
/// Business scene; used to build the provider out id and in logs.
pub scene: String,
}
/// Result of a successful send.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmsSendCodeResult {
/// The configured send interval, in seconds.
pub cooldown_seconds: u64,
/// The configured code validity, in seconds.
pub expires_in_seconds: u64,
pub provider_request_id: Option<String>,
pub provider_out_id: Option<String>,
}
/// Input for checking a verification code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmsVerifyCodeRequest {
pub national_phone_number: String,
/// Compared after trimming.
pub verify_code: String,
/// Forwarded to Aliyun as `OutId` when present.
pub provider_out_id: Option<String>,
}
/// Where a WeChat login starts.
/// `Desktop` uses the configured authorize endpoint with scope `snsapi_login`.
/// `WechatInApp` uses the in-app endpoint with scope `snsapi_userinfo`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WechatAuthScene {
Desktop,
WechatInApp,
}
/// Raw WeChat login settings, stored exactly as passed to `WechatAuthConfig::new` (no
/// trimming). `WechatProvider::new` turns them into a provider.
/// NOTE(review): the derived `Debug` includes `app_secret` / `mini_program_app_secret`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WechatAuthConfig {
/// When `false`, the provider is `Disabled`.
pub enabled: bool,
/// `"mock"` (trimmed, ASCII case-insensitive) selects the mock provider.
pub provider: String,
/// Web OAuth AppID / AppSecret.
pub app_id: Option<String>,
pub app_secret: Option<String>,
/// Mini-program AppID / AppSecret.
pub mini_program_app_id: Option<String>,
pub mini_program_app_secret: Option<String>,
pub authorize_endpoint: String,
pub access_token_endpoint: String,
pub user_info_endpoint: String,
pub js_code_session_endpoint: String,
/// Identity values returned by the mock provider.
pub mock_user_id: String,
pub mock_union_id: Option<String>,
pub mock_display_name: String,
pub mock_avatar_url: Option<String>,
}
/// Identity resolved from a WeChat login. For the real provider, `provider_uid` is the openid.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WechatIdentityProfile {
pub provider_uid: String,
pub provider_union_id: Option<String>,
pub display_name: Option<String>,
pub avatar_url: Option<String>,
}
/// WeChat login backend. It is `Disabled` when login is off or no credential pair is
/// configured (see `WechatProvider::new`).
#[derive(Clone, Debug)]
pub enum WechatProvider {
Disabled,
Mock(MockWechatProvider),
Real(RealWechatProvider),
}
/// Returns a fixed profile built from the `mock_*` settings.
#[derive(Clone, Debug)]
pub struct MockWechatProvider {
mock_user_id: String,
mock_union_id: Option<String>,
mock_display_name: String,
mock_avatar_url: Option<String>,
}
/// Calls the WeChat HTTP APIs with the configured credentials and endpoints.
#[derive(Clone, Debug)]
pub struct RealWechatProvider {
client: Client,
app_id: Option<String>,
app_secret: Option<String>,
mini_program_app_id: Option<String>,
mini_program_app_secret: Option<String>,
authorize_endpoint: String,
access_token_endpoint: String,
user_info_endpoint: String,
js_code_session_endpoint: String,
}
/// SMS verification backend selected from `SmsAuthConfig::provider`.
#[derive(Clone, Debug)]
pub enum SmsAuthProvider {
Mock(MockSmsAuthProvider),
Aliyun(AliyunSmsAuthProvider),
}
/// Sends nothing; accepts only `config.mock_verify_code` when checking.
#[derive(Clone, Debug)]
pub struct MockSmsAuthProvider {
config: SmsAuthConfig,
}
/// Calls Aliyun `SendSmsVerifyCode` / `CheckSmsVerifyCode` over HTTP.
#[derive(Clone, Debug)]
pub struct AliyunSmsAuthProvider {
client: Client,
config: SmsAuthConfig,
}
/// Errors raised while configuring, building, signing or verifying access tokens.
#[derive(Debug, PartialEq, Eq)]
pub enum JwtError {
InvalidConfig(&'static str),
InvalidClaims(&'static str),
SignFailed(String),
VerifyFailed(String),
}
/// Error returned by `RefreshCookieConfig::new`.
#[derive(Debug, PartialEq, Eq)]
pub enum RefreshCookieError {
InvalidConfig(&'static str),
}
/// Errors from hashing or verifying a password.
#[derive(Debug, PartialEq, Eq)]
pub enum PasswordHashError {
HashFailed(String),
VerifyFailed(String),
}
/// Errors from SMS configuration and SMS providers.
#[derive(Debug, PartialEq, Eq)]
pub enum SmsProviderError {
/// Invalid or incomplete SMS configuration; the message names the setting.
InvalidConfig(String),
/// The submitted code was rejected (mock mismatch, or Aliyun `VerifyResult` other than `PASS`).
InvalidVerifyCode,
/// Transport failure or an upstream provider error.
Upstream(String),
}
/// Errors from the WeChat login providers.
#[derive(Debug, PartialEq, Eq)]
pub enum WechatProviderError {
/// The provider is `WechatProvider::Disabled`.
Disabled,
/// The callback / login `code` is missing or blank.
MissingCode,
InvalidConfig(String),
/// The callback URL could not be parsed.
InvalidCallback(String),
/// The HTTP request to WeChat failed.
RequestFailed(String),
/// The WeChat response body could not be decoded.
DeserializeFailed(String),
/// WeChat returned an error or lacked the expected token/openid.
Upstream(String),
/// The response lacked the openid.
MissingProfile(String),
}
/// Auth-platform errors are classified here first; api-server then decides the HTTP
/// status and the error envelope.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AuthPlatformErrorKind {
InvalidConfig,
InvalidClaims,
SignFailed,
VerifyFailed,
CookieConfig,
HashFailed,
InvalidVerifyCode,
Disabled,
MissingCode,
InvalidCallback,
RequestFailed,
DeserializeFailed,
MissingProfile,
Upstream,
}
/// Fields read from the WeChat OAuth access-token response.
#[derive(Debug, Deserialize)]
struct WechatAccessTokenResponse {
access_token: Option<String>,
openid: Option<String>,
unionid: Option<String>,
errmsg: Option<String>,
}
/// Fields read from the WeChat user-info response.
#[derive(Debug, Deserialize)]
struct WechatUserInfoResponse {
openid: Option<String>,
unionid: Option<String>,
nickname: Option<String>,
headimgurl: Option<String>,
errmsg: Option<String>,
}
/// Fields read from the WeChat `jscode2session` response; a non-zero `errcode` is a failure.
#[derive(Debug, Deserialize)]
struct WechatJsCodeSessionResponse {
openid: Option<String>,
unionid: Option<String>,
errcode: Option<i64>,
errmsg: Option<String>,
}
/// Aliyun `SendSmsVerifyCode` response.
#[derive(Debug, Deserialize)]
struct AliyunSendSmsVerifyCodeResponse {
// The raw Aliyun RPC JSON uses capitalized field names. They must be mapped explicitly,
// otherwise a successful response would be misread as all-empty values.
#[serde(default, rename = "Code")]
code: Option<String>,
#[serde(default, rename = "Message")]
message: Option<String>,
#[serde(default, rename = "RequestId")]
request_id: Option<String>,
#[serde(default, rename = "Success")]
success: Option<bool>,
#[serde(default, rename = "Model")]
model: Option<AliyunSendSmsVerifyCodeModel>,
}
/// `Model` payload of the send response.
#[derive(Debug, Deserialize)]
struct AliyunSendSmsVerifyCodeModel {
// Underscore prefix: deserialized but not read in the visible code.
#[serde(default, rename = "BizId")]
_biz_id: Option<String>,
#[serde(default, rename = "OutId")]
out_id: Option<String>,
#[serde(default, rename = "RequestId")]
request_id: Option<String>,
}
/// Aliyun `CheckSmsVerifyCode` response.
#[derive(Debug, Deserialize)]
struct AliyunCheckSmsVerifyCodeResponse {
// The check API also returns capitalized field names; keep the same explicit mapping
// as the send API.
#[serde(default, rename = "Code")]
code: Option<String>,
#[serde(default, rename = "Message")]
message: Option<String>,
#[serde(default, rename = "Success")]
success: Option<bool>,
#[serde(default, rename = "Model")]
model: Option<AliyunCheckSmsVerifyCodeModel>,
}
/// `Model` payload of the check response; `VerifyResult == "PASS"` means the code matched.
#[derive(Debug, Deserialize)]
struct AliyunCheckSmsVerifyCodeModel {
// Underscore prefix: deserialized but not read in the visible code.
#[serde(default, rename = "OutId")]
_out_id: Option<String>,
#[serde(default, rename = "VerifyResult")]
verify_result: Option<String>,
}
impl JwtConfig {
    /// Builds a validated JWT configuration.
    ///
    /// `issuer` and `secret` go through `normalize_required_string`; the token TTL must
    /// be strictly positive.
    ///
    /// # Errors
    /// [`JwtError::InvalidConfig`] when the issuer or secret is blank, or when
    /// `access_token_ttl_seconds` is zero (checked in that order).
    pub fn new(
        issuer: String,
        secret: String,
        access_token_ttl_seconds: u64,
    ) -> Result<Self, JwtError> {
        let Some(issuer) = normalize_required_string(&issuer) else {
            return Err(JwtError::InvalidConfig("JWT issuer 不能为空"));
        };
        let Some(secret) = normalize_required_string(&secret) else {
            return Err(JwtError::InvalidConfig("JWT secret 不能为空"));
        };
        match access_token_ttl_seconds {
            0 => Err(JwtError::InvalidConfig(
                "JWT access token 过期时间必须大于 0",
            )),
            ttl => Ok(Self {
                issuer,
                secret,
                access_token_ttl_seconds: ttl,
            }),
        }
    }

    /// The normalized issuer.
    pub fn issuer(&self) -> &str {
        self.issuer.as_str()
    }

    /// Access-token lifetime in seconds (never zero after `new`).
    pub fn access_token_ttl_seconds(&self) -> u64 {
        let Self {
            access_token_ttl_seconds,
            ..
        } = self;
        *access_token_ttl_seconds
    }
}
impl RefreshCookieSameSite {
pub fn parse(raw: &str) -> Option<Self> {
match raw.trim().to_ascii_lowercase().as_str() {
"lax" => Some(Self::Lax),
"strict" => Some(Self::Strict),
"none" => Some(Self::None),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::Lax => "Lax",
Self::Strict => "Strict",
Self::None => "None",
}
}
}
impl RefreshCookieConfig {
    /// Validates and builds the refresh-cookie configuration.
    ///
    /// `cookie_name` and `cookie_path` are normalized with `normalize_required_string`;
    /// `cookie_secure` and `cookie_same_site` are stored unchanged.
    ///
    /// # Errors
    /// [`RefreshCookieError::InvalidConfig`] when the name or path is blank, or when
    /// `refresh_session_ttl_days` is zero (checked in that order).
    pub fn new(
        cookie_name: String,
        cookie_path: String,
        cookie_secure: bool,
        cookie_same_site: RefreshCookieSameSite,
        refresh_session_ttl_days: u32,
    ) -> Result<Self, RefreshCookieError> {
        let normalized_name = match normalize_required_string(&cookie_name) {
            Some(value) => value,
            None => {
                return Err(RefreshCookieError::InvalidConfig(
                    "refresh cookie 名称不能为空",
                ));
            }
        };
        let normalized_path = match normalize_required_string(&cookie_path) {
            Some(value) => value,
            None => {
                return Err(RefreshCookieError::InvalidConfig(
                    "refresh cookie path 不能为空",
                ));
            }
        };
        if refresh_session_ttl_days < 1 {
            return Err(RefreshCookieError::InvalidConfig(
                "refresh session TTL 天数必须大于 0",
            ));
        }
        Ok(Self {
            cookie_name: normalized_name,
            cookie_path: normalized_path,
            cookie_secure,
            cookie_same_site,
            refresh_session_ttl_days,
        })
    }

    /// Normalized cookie name.
    pub fn cookie_name(&self) -> &str {
        self.cookie_name.as_str()
    }

    /// Normalized cookie path.
    pub fn cookie_path(&self) -> &str {
        self.cookie_path.as_str()
    }

    /// Configured `cookie_secure` flag.
    pub fn cookie_secure(&self) -> bool {
        let Self { cookie_secure, .. } = self;
        *cookie_secure
    }

    /// Configured SameSite attribute.
    pub fn cookie_same_site(&self) -> &RefreshCookieSameSite {
        let Self {
            cookie_same_site, ..
        } = self;
        cookie_same_site
    }

    /// Refresh-session lifetime in days (never zero after `new`).
    pub fn refresh_session_ttl_days(&self) -> u32 {
        let Self {
            refresh_session_ttl_days,
            ..
        } = self;
        *refresh_session_ttl_days
    }
}
impl SmsAuthProviderKind {
pub fn parse(raw: &str) -> Option<Self> {
match raw.trim().to_ascii_lowercase().as_str() {
"mock" => Some(Self::Mock),
"aliyun" => Some(Self::Aliyun),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::Mock => "mock",
Self::Aliyun => "aliyun",
}
}
}
impl SmsAuthConfig {
    /// Builds a validated SMS verification configuration.
    ///
    /// Behaviour:
    /// * A blank `endpoint`, `template_param_key`, `country_code` or `mock_verify_code`
    ///   falls back to [`DEFAULT_SMS_ENDPOINT`], [`DEFAULT_SMS_TEMPLATE_PARAM_KEY`],
    ///   [`DEFAULT_SMS_COUNTRY_CODE`] or [`DEFAULT_SMS_MOCK_VERIFY_CODE`].
    /// * `scheme_name` and the AccessKey pair are normalized; blank values become `None`.
    /// * `sign_name` and `template_code` are stored trimmed.
    /// * The numeric settings are range-checked.
    /// * The Aliyun provider additionally requires a sign name, a template code and both
    ///   AccessKey halves.
    ///
    /// # Errors
    /// [`SmsProviderError::InvalidConfig`] describing the first invalid setting.
    // One parameter per configuration field; kept flat to preserve the public signature
    // (same approach as `WechatAuthConfig::new`).
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        provider: SmsAuthProviderKind,
        endpoint: String,
        access_key_id: Option<String>,
        access_key_secret: Option<String>,
        sign_name: String,
        template_code: String,
        template_param_key: String,
        country_code: String,
        scheme_name: Option<String>,
        code_length: u8,
        code_type: u8,
        valid_time_seconds: u64,
        interval_seconds: u64,
        duplicate_policy: u8,
        case_auth_policy: u8,
        return_verify_code: bool,
        mock_verify_code: String,
    ) -> Result<Self, SmsProviderError> {
        let endpoint = normalize_required_string(&endpoint)
            .unwrap_or_else(|| DEFAULT_SMS_ENDPOINT.to_string());
        let template_param_key = normalize_required_string(&template_param_key)
            .unwrap_or_else(|| DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string());
        let country_code = normalize_required_string(&country_code)
            .unwrap_or_else(|| DEFAULT_SMS_COUNTRY_CODE.to_string());
        let scheme_name = normalize_optional_string(scheme_name);
        let mock_verify_code = normalize_required_string(&mock_verify_code)
            .unwrap_or_else(|| DEFAULT_SMS_MOCK_VERIFY_CODE.to_string());
        // Normalize the credentials once; validation and the stored config share the result.
        let access_key_id = access_key_id.and_then(|value| normalize_required_string(&value));
        let access_key_secret =
            access_key_secret.and_then(|value| normalize_required_string(&value));

        if !(4..=8).contains(&code_length) {
            return Err(SmsProviderError::InvalidConfig(
                "短信验证码长度必须在 4 到 8 之间".to_string(),
            ));
        }
        if !(1..=7).contains(&code_type) {
            return Err(SmsProviderError::InvalidConfig(
                "短信验证码类型取值非法".to_string(),
            ));
        }
        if interval_seconds == 0 || valid_time_seconds == 0 {
            return Err(SmsProviderError::InvalidConfig(
                "短信验证码有效期和发送间隔必须大于 0".to_string(),
            ));
        }
        if !(1..=2).contains(&duplicate_policy) {
            return Err(SmsProviderError::InvalidConfig(
                "短信验证码重复策略取值非法".to_string(),
            ));
        }
        if !(1..=2).contains(&case_auth_policy) {
            return Err(SmsProviderError::InvalidConfig(
                "短信验证码大小写校验策略取值非法".to_string(),
            ));
        }
        match provider {
            SmsAuthProviderKind::Mock => {}
            SmsAuthProviderKind::Aliyun => {
                if normalize_required_string(&sign_name).is_none() {
                    return Err(SmsProviderError::InvalidConfig(
                        "阿里云短信签名不能为空".to_string(),
                    ));
                }
                if normalize_required_string(&template_code).is_none() {
                    return Err(SmsProviderError::InvalidConfig(
                        "阿里云短信模板编码不能为空".to_string(),
                    ));
                }
                if access_key_id.is_none() || access_key_secret.is_none() {
                    return Err(SmsProviderError::InvalidConfig(
                        "阿里云短信 AccessKey 未配置".to_string(),
                    ));
                }
            }
        }
        Ok(Self {
            provider,
            endpoint,
            access_key_id,
            access_key_secret,
            sign_name: sign_name.trim().to_string(),
            template_code: template_code.trim().to_string(),
            template_param_key,
            country_code,
            scheme_name,
            code_length,
            code_type,
            valid_time_seconds,
            interval_seconds,
            duplicate_policy,
            case_auth_policy,
            return_verify_code,
            mock_verify_code,
        })
    }
}
impl SmsAuthProvider {
pub fn new(config: SmsAuthConfig) -> Result<Self, SmsProviderError> {
match config.provider {
SmsAuthProviderKind::Mock => Ok(Self::Mock(MockSmsAuthProvider { config })),
SmsAuthProviderKind::Aliyun => Ok(Self::Aliyun(AliyunSmsAuthProvider {
client: Client::new(),
config,
})),
}
}
pub fn kind(&self) -> SmsAuthProviderKind {
match self {
Self::Mock(_) => SmsAuthProviderKind::Mock,
Self::Aliyun(_) => SmsAuthProviderKind::Aliyun,
}
}
pub async fn send_code(
&self,
request: SmsSendCodeRequest,
) -> Result<SmsSendCodeResult, SmsProviderError> {
match self {
Self::Mock(provider) => provider.send_code(request).await,
Self::Aliyun(provider) => provider.send_code(request).await,
}
}
pub async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> {
match self {
Self::Mock(provider) => provider.verify_code(request).await,
Self::Aliyun(provider) => provider.verify_code(request).await,
}
}
}
impl WechatAuthConfig {
    /// Bundles the raw WeChat login settings. Values are stored exactly as given; no
    /// trimming or validation happens here.
    // One parameter per configuration field; kept flat to preserve the public signature.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        enabled: bool,
        provider: String,
        app_id: Option<String>,
        app_secret: Option<String>,
        mini_program_app_id: Option<String>,
        mini_program_app_secret: Option<String>,
        authorize_endpoint: String,
        access_token_endpoint: String,
        user_info_endpoint: String,
        js_code_session_endpoint: String,
        mock_user_id: String,
        mock_union_id: Option<String>,
        mock_display_name: String,
        mock_avatar_url: Option<String>,
    ) -> Self {
        Self {
            // Feature switch and provider selector.
            enabled,
            provider,
            // HTTP endpoints.
            authorize_endpoint,
            access_token_endpoint,
            user_info_endpoint,
            js_code_session_endpoint,
            // Web OAuth and mini-program credentials.
            app_id,
            app_secret,
            mini_program_app_id,
            mini_program_app_secret,
            // Fixed identity used by the mock provider.
            mock_user_id,
            mock_union_id,
            mock_display_name,
            mock_avatar_url,
        }
    }
}
impl WechatProvider {
    /// Chooses the WeChat login backend from `config`.
    ///
    /// * `enabled == false` gives [`WechatProvider::Disabled`].
    /// * `provider` equal to `"mock"` (trimmed, ASCII case-insensitive) gives the mock
    ///   backend built from the `mock_*` settings.
    /// * Otherwise the real backend is used, but only when at least one complete
    ///   credential pair is configured: web OAuth AppID + AppSecret, or mini-program
    ///   AppID + AppSecret. Blank or whitespace-only values do not count. With no usable
    ///   pair the provider is `Disabled`.
    pub fn new(config: WechatAuthConfig) -> Self {
        if !config.enabled {
            return Self::Disabled;
        }
        if config.provider.trim().eq_ignore_ascii_case("mock") {
            return Self::Mock(MockWechatProvider {
                mock_user_id: config.mock_user_id,
                mock_union_id: config.mock_union_id,
                mock_display_name: config.mock_display_name,
                mock_avatar_url: config.mock_avatar_url,
            });
        }
        let has_web_oauth_config =
            Self::has_credential_pair(config.app_id.as_deref(), config.app_secret.as_deref());
        let has_mini_program_config = Self::has_credential_pair(
            config.mini_program_app_id.as_deref(),
            config.mini_program_app_secret.as_deref(),
        );
        if !has_web_oauth_config && !has_mini_program_config {
            return Self::Disabled;
        }
        Self::Real(RealWechatProvider {
            client: Client::new(),
            app_id: config.app_id,
            app_secret: config.app_secret,
            mini_program_app_id: config.mini_program_app_id,
            mini_program_app_secret: config.mini_program_app_secret,
            authorize_endpoint: config.authorize_endpoint,
            access_token_endpoint: config.access_token_endpoint,
            user_info_endpoint: config.user_info_endpoint,
            js_code_session_endpoint: config.js_code_session_endpoint,
        })
    }

    /// `true` when both halves of a credential pair are present and not blank.
    fn has_credential_pair(id: Option<&str>, secret: Option<&str>) -> bool {
        let is_configured =
            |value: Option<&str>| value.is_some_and(|value| !value.trim().is_empty());
        is_configured(id) && is_configured(secret)
    }

    /// Builds the URL that starts a login for `scene`.
    ///
    /// The mock backend returns `callback_url` with a fixed mock code and `state` appended.
    ///
    /// # Errors
    /// [`WechatProviderError::Disabled`] when disabled; otherwise the backend's error.
    pub fn build_authorization_url(
        &self,
        callback_url: &str,
        state: &str,
        scene: &WechatAuthScene,
    ) -> Result<String, WechatProviderError> {
        match self {
            Self::Disabled => Err(WechatProviderError::Disabled),
            Self::Mock(_) => build_mock_wechat_authorization_url(callback_url, state),
            Self::Real(provider) => provider.build_authorization_url(callback_url, state, scene),
        }
    }

    /// Resolves a web / in-app login callback into a profile.
    ///
    /// The real backend uses `code`; the mock backend uses `mock_code`, falling back to
    /// the configured mock user id when it is missing or blank.
    pub async fn resolve_callback_profile(
        &self,
        code: Option<&str>,
        mock_code: Option<&str>,
    ) -> Result<WechatIdentityProfile, WechatProviderError> {
        match self {
            Self::Disabled => Err(WechatProviderError::Disabled),
            Self::Mock(provider) => Ok(provider.resolve_callback_profile(mock_code)),
            Self::Real(provider) => provider.resolve_callback_profile(code).await,
        }
    }

    /// Resolves a mini-program login `code` into a profile.
    ///
    /// The mock backend treats `code` as the mock user id.
    pub async fn resolve_mini_program_login_profile(
        &self,
        code: Option<&str>,
    ) -> Result<WechatIdentityProfile, WechatProviderError> {
        match self {
            Self::Disabled => Err(WechatProviderError::Disabled),
            Self::Mock(provider) => Ok(provider.resolve_callback_profile(code)),
            Self::Real(provider) => provider.resolve_mini_program_login_profile(code).await,
        }
    }
}
impl MockWechatProvider {
    /// Builds the mock profile.
    ///
    /// A non-blank `mock_code` (trimmed) becomes the provider uid; otherwise the
    /// configured mock user id is used. The other fields come from the mock settings.
    fn resolve_callback_profile(&self, mock_code: Option<&str>) -> WechatIdentityProfile {
        let provider_uid = match mock_code.map(str::trim) {
            Some(code) if !code.is_empty() => code.to_string(),
            _ => self.mock_user_id.clone(),
        };
        let display_name = Some(self.mock_display_name.clone());
        WechatIdentityProfile {
            provider_uid,
            provider_union_id: self.mock_union_id.clone(),
            display_name,
            avatar_url: self.mock_avatar_url.clone(),
        }
    }
}
impl RealWechatProvider {
fn build_authorization_url(
&self,
callback_url: &str,
state: &str,
scene: &WechatAuthScene,
) -> Result<String, WechatProviderError> {
let endpoint = match scene {
WechatAuthScene::Desktop => &self.authorize_endpoint,
WechatAuthScene::WechatInApp => DEFAULT_WECHAT_IN_APP_AUTHORIZE_ENDPOINT,
};
let mut url = Url::parse(endpoint).map_err(|error| {
WechatProviderError::InvalidConfig(format!("微信授权地址非法:{error}"))
})?;
let app_id = self.app_id.as_ref().ok_or_else(|| {
WechatProviderError::InvalidConfig("微信开放平台 AppID 未配置".to_string())
})?;
url.query_pairs_mut()
.append_pair("appid", app_id)
.append_pair("redirect_uri", callback_url)
.append_pair("response_type", "code")
.append_pair(
"scope",
match scene {
WechatAuthScene::Desktop => "snsapi_login",
WechatAuthScene::WechatInApp => "snsapi_userinfo",
},
)
.append_pair("state", state);
Ok(format!("{url}#wechat_redirect"))
}
async fn resolve_callback_profile(
&self,
code: Option<&str>,
) -> Result<WechatIdentityProfile, WechatProviderError> {
let code = code
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or(WechatProviderError::MissingCode)?;
let app_id = self.app_id.as_ref().ok_or_else(|| {
WechatProviderError::InvalidConfig("微信开放平台 AppID 未配置".to_string())
})?;
let app_secret = self.app_secret.as_ref().ok_or_else(|| {
WechatProviderError::InvalidConfig("微信开放平台 AppSecret 未配置".to_string())
})?;
let mut access_token_url = Url::parse(&self.access_token_endpoint).map_err(|error| {
WechatProviderError::InvalidConfig(format!("微信 access_token 地址非法:{error}"))
})?;
access_token_url
.query_pairs_mut()
.append_pair("appid", app_id)
.append_pair("secret", app_secret)
.append_pair("code", code)
.append_pair("grant_type", "authorization_code");
let access_token_payload = self
.client
.get(access_token_url.as_str())
.send()
.await
.map_err(|error| {
warn!(error = %error, "微信 access_token 请求失败");
WechatProviderError::RequestFailed(
"微信登录失败access_token 请求失败".to_string(),
)
})?
.json::<WechatAccessTokenResponse>()
.await
.map_err(|error| {
warn!(error = %error, "微信 access_token 响应解析失败");
WechatProviderError::DeserializeFailed(
"微信登录失败access_token 响应非法".to_string(),
)
})?;
let access_token = access_token_payload
.access_token
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| {
WechatProviderError::Upstream(format!(
"微信登录失败:{}",
access_token_payload
.errmsg
.unwrap_or_else(|| "缺少 access_token".to_string())
))
})?;
let openid = access_token_payload
.openid
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| {
WechatProviderError::MissingProfile("微信登录失败:缺少 openid".to_string())
})?;
let mut user_info_url = Url::parse(&self.user_info_endpoint).map_err(|error| {
WechatProviderError::InvalidConfig(format!("微信用户信息地址非法:{error}"))
})?;
user_info_url
.query_pairs_mut()
.append_pair("access_token", &access_token)
.append_pair("openid", &openid)
.append_pair("lang", "zh_CN");
let user_info_payload = self
.client
.get(user_info_url.as_str())
.send()
.await
.map_err(|error| {
warn!(error = %error, "微信用户信息请求失败");
WechatProviderError::RequestFailed("微信登录失败:用户信息请求失败".to_string())
})?
.json::<WechatUserInfoResponse>()
.await
.map_err(|error| {
warn!(error = %error, "微信用户信息响应解析失败");
WechatProviderError::DeserializeFailed("微信登录失败:用户信息响应非法".to_string())
})?;
let provider_uid = user_info_payload
.openid
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| {
WechatProviderError::Upstream(format!(
"微信登录失败:{}",
user_info_payload
.errmsg
.unwrap_or_else(|| "缺少 openid".to_string())
))
})?;
Ok(WechatIdentityProfile {
provider_uid,
provider_union_id: user_info_payload.unionid.or(access_token_payload.unionid),
display_name: user_info_payload.nickname,
avatar_url: user_info_payload.headimgurl,
})
}
async fn resolve_mini_program_login_profile(
&self,
code: Option<&str>,
) -> Result<WechatIdentityProfile, WechatProviderError> {
let code = code
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or(WechatProviderError::MissingCode)?;
let app_id = self
.mini_program_app_id
.as_ref()
.or(self.app_id.as_ref())
.ok_or_else(|| {
WechatProviderError::InvalidConfig("微信小程序 AppID 未配置".to_string())
})?;
let app_secret = self
.mini_program_app_secret
.as_ref()
.or(self.app_secret.as_ref())
.ok_or_else(|| {
WechatProviderError::InvalidConfig("微信小程序 AppSecret 未配置".to_string())
})?;
let mut js_code_session_url =
Url::parse(&self.js_code_session_endpoint).map_err(|error| {
WechatProviderError::InvalidConfig(format!("微信 jscode2session 地址非法:{error}"))
})?;
js_code_session_url
.query_pairs_mut()
.append_pair("appid", app_id)
.append_pair("secret", app_secret)
.append_pair("js_code", code)
.append_pair("grant_type", "authorization_code");
let payload = self
.client
.get(js_code_session_url.as_str())
.send()
.await
.map_err(|error| {
warn!(error = %error, "微信小程序 jscode2session 请求失败");
WechatProviderError::RequestFailed(
"微信小程序登录失败jscode2session 请求失败".to_string(),
)
})?
.json::<WechatJsCodeSessionResponse>()
.await
.map_err(|error| {
warn!(error = %error, "微信小程序 jscode2session 响应解析失败");
WechatProviderError::DeserializeFailed(
"微信小程序登录失败jscode2session 响应非法".to_string(),
)
})?;
if let Some(errcode) = payload.errcode.filter(|value| *value != 0) {
return Err(WechatProviderError::Upstream(format!(
"微信小程序登录失败:{}",
payload
.errmsg
.unwrap_or_else(|| format!("jscode2session 返回错误 {errcode}"))
)));
}
let provider_uid = payload
.openid
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| {
WechatProviderError::MissingProfile("微信小程序登录失败:缺少 openid".to_string())
})?;
Ok(WechatIdentityProfile {
provider_uid,
provider_union_id: payload.unionid,
display_name: None,
avatar_url: None,
})
}
}
/// Builds the mock "authorization" URL. No WeChat page is involved: the result is the
/// callback URL itself with `mock_code=wx-mock-code` and `state` appended to its query.
///
/// # Errors
/// [`WechatProviderError::InvalidCallback`] when `callback_url` does not parse.
fn build_mock_wechat_authorization_url(
    callback_url: &str,
    state: &str,
) -> Result<String, WechatProviderError> {
    let mut redirect = match Url::parse(callback_url) {
        Ok(url) => url,
        Err(error) => {
            return Err(WechatProviderError::InvalidCallback(format!(
                "微信回调地址非法:{error}"
            )));
        }
    };
    {
        // The serializer writes the query back into `redirect` when it is dropped at the
        // end of this scope.
        let mut pairs = redirect.query_pairs_mut();
        pairs.append_pair("mock_code", "wx-mock-code");
        pairs.append_pair("state", state);
    }
    Ok(redirect.as_str().to_owned())
}
impl MockSmsAuthProvider {
async fn send_code(
&self,
request: SmsSendCodeRequest,
) -> Result<SmsSendCodeResult, SmsProviderError> {
let provider_out_id =
build_sms_provider_out_id(&request.scene, &request.national_phone_number);
Ok(SmsSendCodeResult {
cooldown_seconds: self.config.interval_seconds,
expires_in_seconds: self.config.valid_time_seconds,
provider_request_id: Some("mock-request-id".to_string()),
provider_out_id: Some(provider_out_id),
})
}
async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> {
if request.verify_code.trim() != self.config.mock_verify_code {
return Err(SmsProviderError::InvalidVerifyCode);
}
Ok(())
}
}
impl AliyunSmsAuthProvider {
async fn send_code(
&self,
request: SmsSendCodeRequest,
) -> Result<SmsSendCodeResult, SmsProviderError> {
let provider_out_id =
build_sms_provider_out_id(&request.scene, &request.national_phone_number);
let phone_masked = mask_phone_number(&request.national_phone_number);
let template_param = serde_json::json!({
self.config.template_param_key.clone(): "##code##",
"min": self.config.valid_time_seconds,
})
.to_string();
info!(
provider = "aliyun",
scene = request.scene.as_str(),
phone_masked = phone_masked.as_str(),
endpoint = self.config.endpoint.as_str(),
sign_name = self.config.sign_name.as_str(),
template_code = self.config.template_code.as_str(),
code_length = self.config.code_length,
valid_time_seconds = self.config.valid_time_seconds,
interval_seconds = self.config.interval_seconds,
provider_out_id = provider_out_id.as_str(),
"准备调用阿里云短信发送接口"
);
let mut query = BTreeMap::new();
query.insert("Action".to_string(), "SendSmsVerifyCode".to_string());
query.insert("Format".to_string(), "json".to_string());
query.insert("Version".to_string(), "2017-05-25".to_string());
query.insert("Timestamp".to_string(), current_aliyun_timestamp());
query.insert("SignatureNonce".to_string(), new_uuid_simple_string());
query.insert("SignatureMethod".to_string(), "HMAC-SHA1".to_string());
query.insert("SignatureVersion".to_string(), "1.0".to_string());
query.insert(
"AccessKeyId".to_string(),
self.config.access_key_id.clone().unwrap_or_default(),
);
query.insert(
"PhoneNumber".to_string(),
request.national_phone_number.trim().to_string(),
);
query.insert("CountryCode".to_string(), self.config.country_code.clone());
query.insert("SignName".to_string(), self.config.sign_name.clone());
query.insert(
"TemplateCode".to_string(),
self.config.template_code.clone(),
);
query.insert("TemplateParam".to_string(), template_param);
query.insert(
"CodeLength".to_string(),
self.config.code_length.to_string(),
);
query.insert("CodeType".to_string(), self.config.code_type.to_string());
query.insert(
"ValidTime".to_string(),
self.config.valid_time_seconds.to_string(),
);
query.insert(
"Interval".to_string(),
self.config.interval_seconds.to_string(),
);
query.insert(
"DuplicatePolicy".to_string(),
self.config.duplicate_policy.to_string(),
);
query.insert(
"ReturnVerifyCode".to_string(),
self.config.return_verify_code.to_string(),
);
query.insert("OutId".to_string(), provider_out_id.clone());
if let Some(scheme_name) = self.config.scheme_name.clone() {
query.insert("SchemeName".to_string(), scheme_name);
}
self.sign_query(&mut query)?;
let payload = self
.client
.post(build_aliyun_sms_url(&self.config.endpoint)?)
.form(&query)
.send()
.await
.map_err(|error| SmsProviderError::Upstream(format!("短信验证码发送失败:{error}")))?;
let http_status = payload.status();
let body = parse_aliyun_json_response(payload, "短信验证码发送失败").await?;
info!(
provider = "aliyun",
scene = request.scene.as_str(),
phone_masked = phone_masked.as_str(),
http_status = http_status.as_u16(),
provider_code = body.code.as_deref().unwrap_or("unknown"),
provider_message = body.message.as_deref().unwrap_or("unknown"),
provider_request_id = body
.request_id
.as_deref()
.or_else(|| body
.model
.as_ref()
.and_then(|model| model.request_id.as_deref()))
.unwrap_or("unknown"),
provider_out_id = body
.model
.as_ref()
.and_then(|model| model.out_id.as_deref())
.unwrap_or("unknown"),
success = body.success.unwrap_or(false),
"阿里云短信发送接口返回响应"
);
if !body.success.unwrap_or(false) || body.code.as_deref() != Some("OK") {
warn!(
provider = "aliyun",
scene = request.scene.as_str(),
phone_masked = phone_masked.as_str(),
http_status = http_status.as_u16(),
provider_code = body.code.as_deref().unwrap_or("unknown"),
provider_message = body.message.as_deref().unwrap_or("unknown"),
provider_request_id = body
.request_id
.as_deref()
.or_else(|| body
.model
.as_ref()
.and_then(|model| model.request_id.as_deref()))
.unwrap_or("unknown"),
provider_out_id = body
.model
.as_ref()
.and_then(|model| model.out_id.as_deref())
.unwrap_or("unknown"),
"阿里云短信发送接口返回业务失败"
);
return Err(map_aliyun_provider_error(
"短信验证码发送失败",
body.message,
body.code,
));
}
Ok(SmsSendCodeResult {
cooldown_seconds: self.config.interval_seconds,
expires_in_seconds: self.config.valid_time_seconds,
provider_request_id: body.request_id.or_else(|| {
body.model
.as_ref()
.and_then(|model| model.request_id.clone())
}),
provider_out_id: body.model.and_then(|model| model.out_id),
})
}
/// Checks a user-supplied SMS code through Aliyun's `CheckSmsVerifyCode` RPC action.
///
/// The request carries the trimmed national phone number and code, the configured country
/// code and case policy, plus the optional scheme name and the `OutId` returned by the send step.
///
/// # Errors
/// - `SmsProviderError::InvalidConfig` when signing fails or the endpoint is empty.
/// - `SmsProviderError::Upstream` when the HTTP round-trip or response handling fails.
/// - The result of `map_aliyun_provider_error` when Aliyun reports a business failure.
/// - `SmsProviderError::InvalidVerifyCode` when Aliyun answers OK but the verdict is not `PASS`.
async fn verify_code(&self, request: SmsVerifyCodeRequest) -> Result<(), SmsProviderError> {
    let config = &self.config;
    // RPC envelope first, then the request-specific parameters.
    let mut query: BTreeMap<String, String> = [
        ("Action", "CheckSmsVerifyCode".to_string()),
        ("Format", "json".to_string()),
        ("Version", "2017-05-25".to_string()),
        ("Timestamp", current_aliyun_timestamp()),
        ("SignatureNonce", new_uuid_simple_string()),
        ("SignatureMethod", "HMAC-SHA1".to_string()),
        ("SignatureVersion", "1.0".to_string()),
        (
            "AccessKeyId",
            config.access_key_id.clone().unwrap_or_default(),
        ),
        (
            "PhoneNumber",
            request.national_phone_number.trim().to_string(),
        ),
        ("CountryCode", config.country_code.clone()),
        ("VerifyCode", request.verify_code.trim().to_string()),
        ("CaseAuthPolicy", config.case_auth_policy.to_string()),
    ]
    .into_iter()
    .map(|(name, value)| (name.to_string(), value))
    .collect();
    if let Some(scheme_name) = &config.scheme_name {
        query.insert("SchemeName".to_string(), scheme_name.clone());
    }
    if let Some(out_id) = request.provider_out_id {
        query.insert("OutId".to_string(), out_id);
    }

    // Signing must happen after every parameter is in place.
    self.sign_query(&mut query)?;
    let url = build_aliyun_sms_url(&config.endpoint)?;
    let response = self
        .client
        .post(url)
        .form(&query)
        .send()
        .await
        .map_err(|error| SmsProviderError::Upstream(format!("验证码校验失败:{error}")))?;
    let body = parse_aliyun_json_response_for_verify(response).await?;

    let accepted = body.success.unwrap_or(false) && body.code.as_deref() == Some("OK");
    if !accepted {
        return Err(map_aliyun_provider_error(
            "验证码校验失败",
            body.message,
            body.code,
        ));
    }
    let verdict = body.model.and_then(|model| model.verify_result);
    match verdict.as_deref() {
        Some("PASS") => Ok(()),
        _ => Err(SmsProviderError::InvalidVerifyCode),
    }
}
/// Adds an Aliyun RPC (signature version 1.0, HMAC-SHA1) `Signature` entry to `query`.
///
/// The string to sign is `POST&%2F&<percent-encoded canonical query>`. The HMAC key is the
/// configured AccessKeySecret followed by `&`, and the digest is Base64-encoded.
///
/// # Errors
/// Returns `SmsProviderError::InvalidConfig` when no AccessKeySecret is configured or the
/// HMAC signer cannot be initialised.
fn sign_query(&self, query: &mut BTreeMap<String, String>) -> Result<(), SmsProviderError> {
    let Some(secret) = self.config.access_key_secret.as_deref() else {
        return Err(SmsProviderError::InvalidConfig(
            "阿里云短信 AccessKeySecret 未配置".to_string(),
        ));
    };
    let encoded_query = aliyun_percent_encode(&canonicalize_aliyun_rpc_params(query));
    let string_to_sign = format!("POST&{}&{}", aliyun_percent_encode("/"), encoded_query);

    let signing_key = format!("{secret}&");
    let mut mac = match HmacSha1::new_from_slice(signing_key.as_bytes()) {
        Ok(mac) => mac,
        Err(error) => {
            return Err(SmsProviderError::InvalidConfig(format!(
                "初始化短信签名器失败:{error}"
            )));
        }
    };
    mac.update(string_to_sign.as_bytes());
    let digest = mac.finalize().into_bytes();
    query.insert("Signature".to_string(), BASE64_STANDARD.encode(digest));
    Ok(())
}
}
impl AccessTokenClaims {
    /// Builds access-token claims from `input`, stamping `iss`, `iat` and `exp` from `config`.
    ///
    /// `sub`/`sid` must be non-blank and at least one non-blank role is required.
    /// `display_name` is normalised. `exp` is `issued_at + config.access_token_ttl_seconds()`.
    ///
    /// # Errors
    /// - `JwtError::InvalidClaims` for blank ids, no roles, a pre-epoch `iat`, or `exp <= iat`.
    /// - `JwtError::InvalidConfig` when the TTL does not fit in `i64` or the addition overflows.
    pub fn from_input(
        input: AccessTokenClaimsInput,
        config: &JwtConfig,
        issued_at: OffsetDateTime,
    ) -> Result<Self, JwtError> {
        let sub = normalize_required_field(input.user_id, "JWT sub 不能为空")?;
        let sid = normalize_required_field(input.session_id, "JWT sid 不能为空")?;
        let roles = normalize_roles(input.roles)?;
        let display_name = normalize_optional_field(input.display_name);

        let iat_seconds = issued_at.unix_timestamp();
        // A negative timestamp is exactly the case where the u64 conversion fails.
        let iat = u64::try_from(iat_seconds)
            .map_err(|_| JwtError::InvalidClaims("JWT iat 不能早于 Unix epoch"))?;

        let ttl_seconds = i64::try_from(config.access_token_ttl_seconds())
            .map_err(|_| JwtError::InvalidConfig("JWT access token 过期时间超出 i64 上限"))?;
        let expires_at = issued_at
            .checked_add(Duration::seconds(ttl_seconds))
            .ok_or(JwtError::InvalidConfig("JWT 过期时间计算溢出"))?;
        let exp_seconds = expires_at.unix_timestamp();
        if exp_seconds <= iat_seconds {
            return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"));
        }

        let claims = Self {
            iss: config.issuer().to_string(),
            sub,
            sid,
            provider: input.provider,
            roles,
            ver: input.token_version,
            phone_verified: input.phone_verified,
            binding_status: input.binding_status,
            display_name,
            iat,
            // exp_seconds > iat_seconds >= 0, so this cast cannot wrap.
            exp: exp_seconds as u64,
        };
        claims.validate_for_config(config)?;
        Ok(claims)
    }

    /// The subject (user id) claim.
    pub fn user_id(&self) -> &str {
        self.sub.as_str()
    }

    /// The session id claim.
    pub fn session_id(&self) -> &str {
        self.sid.as_str()
    }

    /// The token version claim.
    pub fn token_version(&self) -> u64 {
        self.ver
    }

    /// Re-checks structural invariants of the claims against `config`.
    ///
    /// # Errors
    /// Returns `JwtError::InvalidClaims` when the trimmed issuer differs from
    /// `config.issuer()`, `sub`/`sid` are blank, no non-blank role exists, or `exp <= iat`.
    pub fn validate_for_config(&self, config: &JwtConfig) -> Result<(), JwtError> {
        if self.iss.trim() != config.issuer() {
            return Err(JwtError::InvalidClaims("JWT iss 与当前配置不一致"));
        }
        for (value, message) in [
            (&self.sub, "JWT sub 不能为空"),
            (&self.sid, "JWT sid 不能为空"),
        ] {
            normalize_required_field(value.clone(), message)?;
        }
        normalize_roles(self.roles.clone())?;
        if self.exp > self.iat {
            Ok(())
        } else {
            Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"))
        }
    }
}
/// Signs `claims` as an HS256 JWT using the shared secret from `config`.
///
/// The claims are re-validated against `config` before signing.
///
/// # Errors
/// Propagates `validate_for_config` failures. Returns `JwtError::SignFailed` if encoding fails.
pub fn sign_access_token(
    claims: &AccessTokenClaims,
    config: &JwtConfig,
) -> Result<String, JwtError> {
    claims.validate_for_config(config)?;
    let mut header = Header::new(ACCESS_TOKEN_ALGORITHM);
    header.typ = Some("JWT".to_string());
    let signing_key = EncodingKey::from_secret(config.secret.as_bytes());
    encode(&header, claims, &signing_key)
        .map_err(|error| JwtError::SignFailed(format!("JWT 签发失败:{error}")))
}
/// Decodes and verifies an HS256 access token issued for `config`'s issuer.
///
/// The token is trimmed first. The `exp`, `iat`, `iss` and `sub` claims are required, and the
/// decoded claims are re-validated with `validate_for_config`.
///
/// # Errors
/// Returns `JwtError::VerifyFailed` for an empty token or any decode/validation failure
/// (mapped by `map_verify_error`). Structural problems surface as `JwtError::InvalidClaims`.
pub fn verify_access_token(token: &str, config: &JwtConfig) -> Result<AccessTokenClaims, JwtError> {
    let token = token.trim();
    if token.is_empty() {
        return Err(JwtError::VerifyFailed("JWT 不能为空".to_string()));
    }

    let mut validation = Validation::new(ACCESS_TOKEN_ALGORITHM);
    validation.required_spec_claims = ["exp", "iat", "iss", "sub"]
        .into_iter()
        .map(String::from)
        .collect();
    validation.set_issuer(&[config.issuer()]);

    let verifying_key = DecodingKey::from_secret(config.secret.as_bytes());
    let claims = decode::<AccessTokenClaims>(token, &verifying_key, &validation)
        .map_err(map_verify_error)?
        .claims;
    claims.validate_for_config(config)?;
    Ok(claims)
}
/// Extracts and URL-decodes the refresh-session cookie value from a `Cookie` request header.
///
/// Entries are `;`-separated `name=value` pairs. Blank entries and entries without an `=`
/// are skipped, so one malformed pair cannot hide the refresh cookie that follows it.
/// The first entry whose trimmed name equals `config.cookie_name()` wins.
///
/// Returns `None` when the cookie is absent, when its value is empty, or when the value is not
/// valid percent-encoded UTF-8.
pub fn read_refresh_session_token(
    cookie_header: &str,
    config: &RefreshCookieConfig,
) -> Option<String> {
    let cookie_name = config.cookie_name();
    for entry in cookie_header
        .split(';')
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
    {
        // Previously `split_once('=')?` aborted the whole search on a value-less entry.
        let Some((raw_name, raw_value)) = entry.split_once('=') else {
            continue;
        };
        if raw_name.trim() != cookie_name {
            continue;
        }
        let raw_value = raw_value.trim();
        if raw_value.is_empty() {
            return None;
        }
        return urlencoding::decode(raw_value)
            .ok()
            .map(|decoded| decoded.into_owned());
    }
    None
}
pub async fn hash_password(password: &str) -> Result<String, PasswordHashError> {
let salt = SaltString::generate(&mut OsRng);
Argon2::default()
.hash_password(password.as_bytes(), &salt)
.map(|hash| hash.to_string())
.map_err(|error| PasswordHashError::HashFailed(format!("密码哈希失败:{error}")))
}
/// Checks `password` against a stored PHC-format Argon2 hash.
///
/// Returns `Ok(true)` on a match and `Ok(false)` when verification fails for any reason.
///
/// # Errors
/// Returns `PasswordHashError::VerifyFailed` only when `password_hash` cannot be parsed.
pub async fn verify_password(
    password_hash: &str,
    password: &str,
) -> Result<bool, PasswordHashError> {
    let stored = match PasswordHash::new(password_hash) {
        Ok(stored) => stored,
        Err(error) => {
            return Err(PasswordHashError::VerifyFailed(format!(
                "密码哈希格式非法:{error}"
            )));
        }
    };
    let matches = Argon2::default()
        .verify_password(password.as_bytes(), &stored)
        .is_ok();
    Ok(matches)
}
pub fn create_refresh_session_token() -> String {
new_uuid_simple_string()
}
/// Returns the lowercase hex SHA-256 digest of `token` (64 characters).
pub fn hash_refresh_session_token(token: &str) -> String {
    let digest = Sha256::digest(token.as_bytes());
    format!("{digest:x}")
}
/// Builds the `Set-Cookie` header value that stores `token` as the refresh-session cookie.
///
/// The value is URL-encoded. Path, HttpOnly, SameSite and Max-Age (TTL days converted to
/// seconds) come from `config`, and `Secure` is appended when `config.cookie_secure()` is set.
pub fn build_refresh_session_set_cookie(token: &str, config: &RefreshCookieConfig) -> String {
    const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
    let max_age = u64::from(config.refresh_session_ttl_days()) * SECONDS_PER_DAY;
    let mut cookie = format!(
        "{}={}; Path={}; HttpOnly; SameSite={}; Max-Age={}",
        config.cookie_name(),
        urlencoding::encode(token),
        config.cookie_path(),
        config.cookie_same_site().as_str(),
        max_age,
    );
    if config.cookie_secure() {
        cookie.push_str("; Secure");
    }
    cookie
}
/// Builds the `Set-Cookie` header value that clears the refresh-session cookie.
///
/// It sends an empty value with `Max-Age=0`, using the same Path/SameSite/Secure attributes
/// as the setter so that the browser matches and removes the existing cookie.
pub fn build_refresh_session_clear_cookie(config: &RefreshCookieConfig) -> String {
    let mut cookie = format!(
        "{}=; Path={}; HttpOnly; SameSite={}; Max-Age=0",
        config.cookie_name(),
        config.cookie_path(),
        config.cookie_same_site().as_str(),
    );
    if config.cookie_secure() {
        cookie.push_str("; Secure");
    }
    cookie
}
impl fmt::Display for JwtError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidConfig(message) | Self::InvalidClaims(message) => f.write_str(message),
Self::SignFailed(message) | Self::VerifyFailed(message) => f.write_str(message),
}
}
}
impl Error for JwtError {}
/// Renders the configuration error message verbatim.
impl fmt::Display for RefreshCookieError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum: the destructuring is irrefutable.
        let Self::InvalidConfig(message) = self;
        f.write_str(message)
    }
}
impl Error for RefreshCookieError {}
/// Renders the carried hashing/verification message verbatim.
impl fmt::Display for PasswordHashError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Both variants carry a message, so this or-pattern covers every case.
        let (Self::HashFailed(message) | Self::VerifyFailed(message)) = self;
        f.write_str(message)
    }
}
impl Error for PasswordHashError {}
/// Normalises a required claim value with `normalize_required_string`.
///
/// # Errors
/// Returns `JwtError::InvalidClaims(error_message)` when normalisation yields nothing.
fn normalize_required_field(
    value: String,
    error_message: &'static str,
) -> Result<String, JwtError> {
    match normalize_required_string(&value) {
        Some(normalized) => Ok(normalized),
        None => Err(JwtError::InvalidClaims(error_message)),
    }
}
fn normalize_optional_field(value: Option<String>) -> Option<String> {
normalize_optional_string(value)
}
/// Trims every role and drops blank ones, preserving order (duplicates are kept).
///
/// # Errors
/// Returns `JwtError::InvalidClaims` when no non-blank role remains.
fn normalize_roles(roles: Vec<String>) -> Result<Vec<String>, JwtError> {
    let mut kept = Vec::with_capacity(roles.len());
    for role in &roles {
        let trimmed = role.trim();
        if !trimmed.is_empty() {
            kept.push(trimmed.to_string());
        }
    }
    if kept.is_empty() {
        Err(JwtError::InvalidClaims("JWT roles 至少包含一个角色"))
    } else {
        Ok(kept)
    }
}
fn map_verify_error(error: jsonwebtoken::errors::Error) -> JwtError {
let message = match error.kind() {
ErrorKind::ExpiredSignature => "JWT 已过期".to_string(),
ErrorKind::InvalidIssuer => "JWT 发行者不匹配".to_string(),
ErrorKind::InvalidSignature => "JWT 签名无效".to_string(),
ErrorKind::InvalidAlgorithm => "JWT 算法不匹配".to_string(),
ErrorKind::InvalidToken => "JWT 非法".to_string(),
ErrorKind::ImmatureSignature => "JWT 尚未生效".to_string(),
ErrorKind::MissingRequiredClaim(claim) => format!("JWT 缺少必填字段:{claim}"),
_ => format!("JWT 校验失败:{error}"),
};
JwtError::VerifyFailed(message)
}
/// Builds an SMS `OutId` of the form `{scene}_{last ≤4 chars of the phone}_{uuid}`.
fn build_sms_provider_out_id(scene: &str, national_phone_number: &str) -> String {
    let phone_chars: Vec<char> = national_phone_number.chars().collect();
    let tail_start = phone_chars.len().saturating_sub(4);
    let phone_suffix: String = phone_chars[tail_start..].iter().collect();
    format!("{scene}_{phone_suffix}_{}", new_uuid_simple_string())
}
/// Normalises a configured SMS endpoint into an `https://{host}/` URL.
///
/// Leading `https://` then `http://` prefixes and surrounding slashes are removed first.
///
/// # Errors
/// Returns `SmsProviderError::InvalidConfig` when nothing remains after normalisation.
fn build_aliyun_sms_url(endpoint: &str) -> Result<String, SmsProviderError> {
    let without_scheme = ["https://", "http://"]
        .iter()
        .fold(endpoint.trim(), |rest, scheme| rest.trim_start_matches(*scheme));
    let host = without_scheme.trim_matches('/');
    if host.is_empty() {
        Err(SmsProviderError::InvalidConfig(
            "阿里云短信 endpoint 不能为空".to_string(),
        ))
    } else {
        Ok(format!("https://{host}/"))
    }
}
fn current_aliyun_timestamp() -> String {
OffsetDateTime::now_utc()
.format(&time::format_description::well_known::Rfc3339)
.unwrap_or_else(|_| "1970-01-01T00:00:00Z".to_string())
}
/// Builds Aliyun's canonicalized query string: `key=value` pairs joined by `&`.
///
/// Keys are percent-encoded and ordered by the `BTreeMap`'s key order. Values are
/// percent-encoded too, and any existing `Signature` entry is skipped.
fn canonicalize_aliyun_rpc_params(params: &BTreeMap<String, String>) -> String {
    let mut canonical = String::new();
    for (key, value) in params {
        if key == "Signature" {
            continue;
        }
        // Every emitted pair is non-empty, so a non-empty buffer means a pair precedes us.
        if !canonical.is_empty() {
            canonical.push('&');
        }
        canonical.push_str(&aliyun_percent_encode(key));
        canonical.push('=');
        canonical.push_str(&aliyun_percent_encode(value));
    }
    canonical
}
/// Percent-encodes `value` per Aliyun's POP signing rules (space → `%20`, `*` → `%2A`,
/// `~` left literal).
///
/// NOTE(review): `urlencoding::encode` appears to already produce these forms. The
/// replacements are kept defensively to mirror the documented rule set.
fn aliyun_percent_encode(value: &str) -> String {
    let encoded = urlencoding::encode(value);
    encoded
        .replace('+', "%20")
        .replace('*', "%2A")
        .replace("%7E", "~")
}
async fn parse_aliyun_json_response(
response: reqwest::Response,
fallback_message: &str,
) -> Result<AliyunSendSmsVerifyCodeResponse, SmsProviderError> {
let status = response.status();
let body = response
.text()
.await
.map_err(|error| SmsProviderError::Upstream(format!("{fallback_message}{error}")))?;
let payload =
serde_json::from_str::<AliyunSendSmsVerifyCodeResponse>(&body).map_err(|error| {
SmsProviderError::Upstream(format!("{fallback_message}:响应解析失败:{error}"))
})?;
if status.is_client_error() || status.is_server_error() {
return Err(map_http_status_to_sms_provider_error(
fallback_message,
status,
serde_json::from_str::<Value>(&body).ok(),
));
}
Ok(payload)
}
/// Reads an Aliyun `CheckSmsVerifyCode` HTTP response and decodes its JSON body.
///
/// The HTTP status is checked before the typed decode, so 4xx/5xx responses with non-JSON
/// bodies reach `map_http_status_to_sms_provider_error` instead of being reported as parse
/// failures.
///
/// # Errors
/// - `SmsProviderError::Upstream` when the body cannot be read or a 2xx body cannot be decoded.
/// - The status-mapped error for 4xx/5xx responses.
async fn parse_aliyun_json_response_for_verify(
    response: reqwest::Response,
) -> Result<AliyunCheckSmsVerifyCodeResponse, SmsProviderError> {
    let status = response.status();
    let body = response
        .text()
        .await
        .map_err(|error| SmsProviderError::Upstream(format!("验证码校验失败:{error}")))?;
    if status.is_client_error() || status.is_server_error() {
        return Err(map_http_status_to_sms_provider_error(
            "验证码校验失败",
            status,
            serde_json::from_str::<Value>(&body).ok(),
        ));
    }
    serde_json::from_str::<AliyunCheckSmsVerifyCodeResponse>(&body).map_err(|error| {
        SmsProviderError::Upstream(format!("验证码校验失败:响应解析失败:{error}"))
    })
}
/// Maps a non-success HTTP status (plus the optional JSON body) to an `SmsProviderError`.
///
/// 4xx responses are classified by the provider `Code` via `map_aliyun_provider_error`. Any
/// other status becomes `Upstream`, carrying the provider `Message` if one is present.
fn map_http_status_to_sms_provider_error(
    fallback_message: &str,
    status: StatusCode,
    payload: Option<Value>,
) -> SmsProviderError {
    // Reads a top-level string field, defaulting to "" when absent or not a string.
    fn string_field<'a>(payload: Option<&'a Value>, name: &str) -> &'a str {
        payload
            .and_then(|value| value.get(name))
            .and_then(Value::as_str)
            .unwrap_or_default()
    }

    let provider_message = string_field(payload.as_ref(), "Message");
    let provider_code = string_field(payload.as_ref(), "Code");
    if !status.is_client_error() {
        return SmsProviderError::Upstream(build_provider_error_message(
            fallback_message,
            provider_message,
        ));
    }
    map_aliyun_provider_error(
        fallback_message,
        Some(provider_message.to_string()),
        Some(provider_code.to_string()),
    )
}
/// Classifies an Aliyun business error by keywords in its upper-cased, trimmed `Code`.
///
/// Checks run in order:
/// 1. `VERIFY`/`CODE`/`CHECK` → `InvalidVerifyCode`.
/// 2. `MOBILE`/`PHONE`/`SIGN`/`TEMPLATE`/`ACCESSKEY` → `InvalidConfig`.
/// 3. Anything else → `Upstream`.
///
/// The last two carry `fallback_message` combined with the provider message.
fn map_aliyun_provider_error(
    fallback_message: &str,
    provider_message: Option<String>,
    provider_code: Option<String>,
) -> SmsProviderError {
    const VERIFY_CODE_MARKERS: [&str; 3] = ["VERIFY", "CODE", "CHECK"];
    const CONFIG_MARKERS: [&str; 5] = ["MOBILE", "PHONE", "SIGN", "TEMPLATE", "ACCESSKEY"];

    let code = provider_code
        .unwrap_or_default()
        .trim()
        .to_ascii_uppercase();
    let message = provider_message.unwrap_or_default();
    let mentions = |markers: &[&str]| markers.iter().any(|marker| code.contains(marker));

    if mentions(&VERIFY_CODE_MARKERS) {
        SmsProviderError::InvalidVerifyCode
    } else if mentions(&CONFIG_MARKERS) {
        SmsProviderError::InvalidConfig(build_provider_error_message(fallback_message, &message))
    } else {
        SmsProviderError::Upstream(build_provider_error_message(fallback_message, &message))
    }
}
/// Combines a local error prefix with the provider's message as `"{prefix}:{message}"`.
///
/// The provider message is trimmed, and `prefix` is returned alone when it is empty. The
/// previous version glued the two strings together with no separator (e.g.
/// `"验证码校验失败非法手机号"`), unlike the `prefix:detail` format used by every other error
/// message in this module.
fn build_provider_error_message(prefix: &str, provider_message: &str) -> String {
    match provider_message.trim() {
        "" => prefix.to_string(),
        detail => format!("{prefix}:{detail}"),
    }
}
/// Masks a phone number for logging, keeping up to 3 leading and 4 trailing characters.
///
/// - Inputs of at most 4 characters are fully masked (an empty input yields `"*"`).
/// - Longer inputs always keep at least one masked character, and the output has the same
///   character count as the input. For example, `13800138000` becomes `138****8000`.
///
/// The previous version let the 3-char prefix and 4-char suffix cover the whole input for
/// 5–7 character numbers. That revealed every digit and inserted an extra `*`
/// (`"12345"` → `"123*45"`).
fn mask_phone_number(phone_number: &str) -> String {
    let chars: Vec<char> = phone_number.chars().collect();
    let len = chars.len();
    if len <= 4 {
        return "*".repeat(len.max(1));
    }
    let prefix_len = 3;
    // len >= 5 here, so at least one character is always reserved for the mask.
    let suffix_len = 4.min(len - prefix_len - 1);
    let mask_len = len - prefix_len - suffix_len;

    let mut masked = String::with_capacity(phone_number.len());
    masked.extend(&chars[..prefix_len]);
    masked.push_str(&"*".repeat(mask_len));
    masked.extend(&chars[len - suffix_len..]);
    masked
}
impl fmt::Display for SmsProviderError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidConfig(message) | Self::Upstream(message) => f.write_str(message),
Self::InvalidVerifyCode => f.write_str("验证码错误"),
}
}
}
impl Error for SmsProviderError {}
impl fmt::Display for WechatProviderError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Disabled => f.write_str("微信登录暂未启用"),
Self::MissingCode => f.write_str("缺少微信授权 code"),
Self::InvalidConfig(message)
| Self::InvalidCallback(message)
| Self::RequestFailed(message)
| Self::DeserializeFailed(message)
| Self::Upstream(message)
| Self::MissingProfile(message) => f.write_str(message),
}
}
}
impl Error for WechatProviderError {}
impl JwtError {
pub fn kind(&self) -> AuthPlatformErrorKind {
match self {
Self::InvalidConfig(_) => AuthPlatformErrorKind::InvalidConfig,
Self::InvalidClaims(_) => AuthPlatformErrorKind::InvalidClaims,
Self::SignFailed(_) => AuthPlatformErrorKind::SignFailed,
Self::VerifyFailed(_) => AuthPlatformErrorKind::VerifyFailed,
}
}
}
impl RefreshCookieError {
pub fn kind(&self) -> AuthPlatformErrorKind {
match self {
Self::InvalidConfig(_) => AuthPlatformErrorKind::CookieConfig,
}
}
}
impl PasswordHashError {
pub fn kind(&self) -> AuthPlatformErrorKind {
match self {
Self::HashFailed(_) => AuthPlatformErrorKind::HashFailed,
Self::VerifyFailed(_) => AuthPlatformErrorKind::VerifyFailed,
}
}
}
impl SmsProviderError {
pub fn kind(&self) -> AuthPlatformErrorKind {
match self {
Self::InvalidConfig(_) => AuthPlatformErrorKind::InvalidConfig,
Self::InvalidVerifyCode => AuthPlatformErrorKind::InvalidVerifyCode,
Self::Upstream(_) => AuthPlatformErrorKind::Upstream,
}
}
}
impl WechatProviderError {
pub fn kind(&self) -> AuthPlatformErrorKind {
match self {
Self::Disabled => AuthPlatformErrorKind::Disabled,
Self::MissingCode => AuthPlatformErrorKind::MissingCode,
Self::InvalidConfig(_) => AuthPlatformErrorKind::InvalidConfig,
Self::InvalidCallback(_) => AuthPlatformErrorKind::InvalidCallback,
Self::RequestFailed(_) => AuthPlatformErrorKind::RequestFailed,
Self::DeserializeFailed(_) => AuthPlatformErrorKind::DeserializeFailed,
Self::Upstream(_) => AuthPlatformErrorKind::Upstream,
Self::MissingProfile(_) => AuthPlatformErrorKind::MissingProfile,
}
}
}
// Unit tests for the auth platform helpers: error-kind mapping, mock WeChat/SMS providers,
// JWT signing/verification, refresh-cookie handling, password hashing and Aliyun
// request/response helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Each error type's `kind()` must keep mapping to the same adapter-facing variant.
    #[test]
    fn auth_platform_error_kind_is_stable_for_adapter_mapping() {
        assert_eq!(
            JwtError::InvalidClaims("JWT roles 至少包含一个角色").kind(),
            AuthPlatformErrorKind::InvalidClaims
        );
        assert_eq!(
            PasswordHashError::VerifyFailed("密码校验失败".to_string()).kind(),
            AuthPlatformErrorKind::VerifyFailed
        );
        assert_eq!(
            SmsProviderError::InvalidVerifyCode.kind(),
            AuthPlatformErrorKind::InvalidVerifyCode
        );
        assert_eq!(
            WechatProviderError::MissingCode.kind(),
            AuthPlatformErrorKind::MissingCode
        );
    }

    // In "mock" mode the desktop authorization URL carries a mock code and the caller's state.
    #[test]
    fn mock_wechat_provider_builds_callback_authorization_url() {
        let provider = WechatProvider::new(WechatAuthConfig::new(
            true,
            "mock".to_string(),
            None,
            None,
            None,
            None,
            DEFAULT_WECHAT_AUTHORIZE_ENDPOINT.to_string(),
            DEFAULT_WECHAT_ACCESS_TOKEN_ENDPOINT.to_string(),
            DEFAULT_WECHAT_USER_INFO_ENDPOINT.to_string(),
            DEFAULT_WECHAT_JS_CODE_SESSION_ENDPOINT.to_string(),
            "wx-user-001".to_string(),
            Some("wx-union-001".to_string()),
            "微信测试用户".to_string(),
            Some("https://example.test/avatar.png".to_string()),
        ));
        let authorization_url = provider
            .build_authorization_url(
                "http://127.0.0.1:3000/api/auth/wechat/callback",
                "state_001",
                &WechatAuthScene::Desktop,
            )
            .expect("mock authorization url should build");
        assert!(authorization_url.contains("mock_code=wx-mock-code"));
        assert!(authorization_url.contains("state=state_001"));
    }

    // In "mock" mode the resolved profile uses the callback code as provider_uid and the
    // configured union id / display name.
    #[tokio::test]
    async fn mock_wechat_provider_resolves_identity_profile() {
        let provider = WechatProvider::new(WechatAuthConfig::new(
            true,
            "mock".to_string(),
            None,
            None,
            None,
            None,
            DEFAULT_WECHAT_AUTHORIZE_ENDPOINT.to_string(),
            DEFAULT_WECHAT_ACCESS_TOKEN_ENDPOINT.to_string(),
            DEFAULT_WECHAT_USER_INFO_ENDPOINT.to_string(),
            DEFAULT_WECHAT_JS_CODE_SESSION_ENDPOINT.to_string(),
            "wx-user-001".to_string(),
            Some("wx-union-001".to_string()),
            "微信测试用户".to_string(),
            None,
        ));
        let profile = provider
            .resolve_callback_profile(None, Some("wx-code-001"))
            .await
            .expect("mock profile should resolve");
        assert_eq!(profile.provider_uid, "wx-code-001");
        assert_eq!(profile.provider_union_id.as_deref(), Some("wx-union-001"));
        assert_eq!(profile.display_name.as_deref(), Some("微信测试用户"));
    }

    // Fixture: JWT config with a dev issuer/secret and the default access-token TTL.
    fn build_jwt_config() -> JwtConfig {
        JwtConfig::new(
            "https://auth.genarrative.local".to_string(),
            "genarrative-dev-secret".to_string(),
            DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
        )
        .expect("jwt config should be valid")
    }

    // Fixture: a valid WeChat-user claims input (single "user" role, token version 3).
    fn build_claims_input() -> AccessTokenClaimsInput {
        AccessTokenClaimsInput {
            user_id: "usr_123".to_string(),
            session_id: "sess_456".to_string(),
            provider: AuthProvider::Wechat,
            roles: vec!["user".to_string()],
            token_version: 3,
            phone_verified: false,
            binding_status: BindingStatus::PendingBindPhone,
            display_name: Some("微信旅人".to_string()),
        }
    }

    // Fixture: default refresh-cookie name/path/TTL, non-secure, SameSite=Lax.
    fn build_refresh_cookie_config() -> RefreshCookieConfig {
        RefreshCookieConfig::new(
            DEFAULT_REFRESH_COOKIE_NAME.to_string(),
            DEFAULT_REFRESH_COOKIE_PATH.to_string(),
            false,
            RefreshCookieSameSite::Lax,
            DEFAULT_REFRESH_SESSION_TTL_DAYS,
        )
        .expect("refresh cookie config should be valid")
    }

    // Fixture: mock SMS provider config with default timings and the default mock code.
    fn build_mock_sms_config() -> SmsAuthConfig {
        SmsAuthConfig::new(
            SmsAuthProviderKind::Mock,
            DEFAULT_SMS_ENDPOINT.to_string(),
            None,
            None,
            String::new(),
            String::new(),
            DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string(),
            DEFAULT_SMS_COUNTRY_CODE.to_string(),
            None,
            DEFAULT_SMS_CODE_LENGTH,
            DEFAULT_SMS_CODE_TYPE,
            DEFAULT_SMS_VALID_TIME_SECONDS,
            DEFAULT_SMS_INTERVAL_SECONDS,
            DEFAULT_SMS_DUPLICATE_POLICY,
            DEFAULT_SMS_CASE_AUTH_POLICY,
            false,
            DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(),
        )
        .expect("mock sms config should be valid")
    }

    // A signed token verifies back to exactly the same claims.
    #[test]
    fn round_trip_sign_and_verify_access_token() {
        let config = build_jwt_config();
        let claims =
            AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
                .expect("claims should build");
        let token = sign_access_token(&claims, &config).expect("token should sign");
        let verified = verify_access_token(&token, &config).expect("token should verify");
        assert_eq!(verified, claims);
        assert_eq!(verified.user_id(), "usr_123");
        assert_eq!(verified.session_id(), "sess_456");
        assert_eq!(verified.token_version(), 3);
    }

    // A token is rejected by a config with the same secret but a different issuer.
    #[test]
    fn verify_rejects_invalid_issuer() {
        let config = build_jwt_config();
        let claims =
            AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
                .expect("claims should build");
        let token = sign_access_token(&claims, &config).expect("token should sign");
        let other_config = JwtConfig::new(
            "https://auth.other.local".to_string(),
            "genarrative-dev-secret".to_string(),
            DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
        )
        .expect("other config should be valid");
        let error = verify_access_token(&token, &other_config).expect_err("issuer should mismatch");
        assert_eq!(
            error,
            JwtError::VerifyFailed("JWT 发行者不匹配".to_string())
        );
    }

    // Claims without any role cannot be built.
    #[test]
    fn build_claims_rejects_empty_roles() {
        let error = AccessTokenClaims::from_input(
            AccessTokenClaimsInput {
                roles: Vec::new(),
                ..build_claims_input()
            },
            &build_jwt_config(),
            OffsetDateTime::now_utc(),
        )
        .expect_err("empty roles should be rejected");
        assert_eq!(error, JwtError::InvalidClaims("JWT roles 至少包含一个角色"));
    }

    // The refresh cookie is found among other cookies in the header.
    #[test]
    fn read_refresh_session_token_returns_matching_cookie() {
        let token = read_refresh_session_token(
            "theme=dark; genarrative_refresh_session=refresh-token-01; locale=zh-CN",
            &build_refresh_cookie_config(),
        );
        assert_eq!(token.as_deref(), Some("refresh-token-01"));
    }

    // Percent-encoded cookie values are decoded.
    #[test]
    fn read_refresh_session_token_decodes_urlencoded_value() {
        let token = read_refresh_session_token(
            "genarrative_refresh_session=refresh%2Ftoken%3D01",
            &build_refresh_cookie_config(),
        );
        assert_eq!(token.as_deref(), Some("refresh/token=01"));
    }

    // No refresh cookie in the header yields None.
    #[test]
    fn read_refresh_session_token_returns_none_when_missing() {
        let token =
            read_refresh_session_token("theme=dark; locale=zh-CN", &build_refresh_cookie_config());
        assert!(token.is_none());
    }

    // A freshly hashed password verifies against the same plaintext.
    #[tokio::test]
    async fn hash_and_verify_password_round_trip() {
        let password_hash = hash_password("secret123")
            .await
            .expect("password hash should build");
        let is_valid = verify_password(&password_hash, "secret123")
            .await
            .expect("password hash should verify");
        assert!(is_valid);
    }

    // The Set-Cookie value encodes the token and includes the configured attributes
    // (30 days == 2592000 seconds).
    #[test]
    fn build_refresh_session_cookie_respects_config() {
        let cookie =
            build_refresh_session_set_cookie("refresh/token=01", &build_refresh_cookie_config());
        assert!(cookie.contains("genarrative_refresh_session=refresh%2Ftoken%3D01"));
        assert!(cookie.contains("Path=/api/auth"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Lax"));
        assert!(cookie.contains("Max-Age=2592000"));
    }

    // The stored token hash is the lowercase hex SHA-256 digest.
    #[test]
    fn hash_refresh_session_token_matches_sha256_hex() {
        let hash = hash_refresh_session_token("refresh-token-01");
        assert_eq!(
            hash,
            "9fab76f9100ec6c151c8caa0c42ab10e10fbc7618f15e24cf3dffc93e19c4c4e"
        );
    }

    // The clearing cookie has an empty value, Max-Age=0 and the same attributes as the setter.
    #[test]
    fn build_refresh_session_clear_cookie_respects_config() {
        let cookie = build_refresh_session_clear_cookie(&build_refresh_cookie_config());
        assert!(cookie.contains("genarrative_refresh_session="));
        assert!(cookie.contains("Path=/api/auth"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Lax"));
        assert!(cookie.contains("Max-Age=0"));
    }

    // Only "mock" and "aliyun" parse into provider kinds.
    #[test]
    fn sms_auth_provider_kind_parses_supported_values() {
        assert_eq!(
            SmsAuthProviderKind::parse("mock"),
            Some(SmsAuthProviderKind::Mock)
        );
        assert_eq!(
            SmsAuthProviderKind::parse("aliyun"),
            Some(SmsAuthProviderKind::Aliyun)
        );
        assert_eq!(SmsAuthProviderKind::parse("other"), None);
    }

    // The mock provider reports configured timings on send and accepts the mock code on verify.
    #[tokio::test]
    async fn mock_sms_provider_sends_and_verifies_code() {
        let provider =
            SmsAuthProvider::new(build_mock_sms_config()).expect("provider should build");
        let send_result = provider
            .send_code(SmsSendCodeRequest {
                national_phone_number: "13800138000".to_string(),
                scene: "login".to_string(),
            })
            .await
            .expect("send code should succeed");
        assert_eq!(send_result.cooldown_seconds, DEFAULT_SMS_INTERVAL_SECONDS);
        assert_eq!(
            send_result.expires_in_seconds,
            DEFAULT_SMS_VALID_TIME_SECONDS
        );
        assert_eq!(
            send_result.provider_request_id.as_deref(),
            Some("mock-request-id")
        );
        assert!(send_result.provider_out_id.is_some());
        provider
            .verify_code(SmsVerifyCodeRequest {
                national_phone_number: "13800138000".to_string(),
                verify_code: DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(),
                provider_out_id: send_result.provider_out_id,
            })
            .await
            .expect("verify code should succeed");
    }

    // Any code other than the configured mock code is rejected as InvalidVerifyCode.
    #[tokio::test]
    async fn mock_sms_provider_rejects_wrong_code() {
        let provider =
            SmsAuthProvider::new(build_mock_sms_config()).expect("provider should build");
        let error = provider
            .verify_code(SmsVerifyCodeRequest {
                national_phone_number: "13800138000".to_string(),
                verify_code: "000000".to_string(),
                provider_out_id: None,
            })
            .await
            .expect_err("wrong verify code should fail");
        assert_eq!(error, SmsProviderError::InvalidVerifyCode);
    }

    // An Aliyun config without AccessKey id/secret is rejected at construction time.
    #[test]
    fn aliyun_sms_config_requires_access_key() {
        let error = SmsAuthConfig::new(
            SmsAuthProviderKind::Aliyun,
            DEFAULT_SMS_ENDPOINT.to_string(),
            None,
            None,
            "测试签名".to_string(),
            "SMS_001".to_string(),
            DEFAULT_SMS_TEMPLATE_PARAM_KEY.to_string(),
            DEFAULT_SMS_COUNTRY_CODE.to_string(),
            None,
            DEFAULT_SMS_CODE_LENGTH,
            DEFAULT_SMS_CODE_TYPE,
            DEFAULT_SMS_VALID_TIME_SECONDS,
            DEFAULT_SMS_INTERVAL_SECONDS,
            DEFAULT_SMS_DUPLICATE_POLICY,
            DEFAULT_SMS_CASE_AUTH_POLICY,
            false,
            DEFAULT_SMS_MOCK_VERIFY_CODE.to_string(),
        )
        .expect_err("aliyun config without access key should fail");
        assert_eq!(
            error,
            SmsProviderError::InvalidConfig("阿里云短信 AccessKey 未配置".to_string())
        );
    }

    // Canonicalization sorts by key and percent-encodes JSON punctuation in values.
    #[test]
    fn canonicalize_aliyun_rpc_params_keeps_sorted_percent_encoded_order() {
        let mut params = BTreeMap::new();
        params.insert(
            "TemplateParam".to_string(),
            "{\"code\":\"##code##\"}".to_string(),
        );
        params.insert("Action".to_string(), "SendSmsVerifyCode".to_string());
        params.insert("PhoneNumber".to_string(), "13800138000".to_string());
        assert_eq!(
            canonicalize_aliyun_rpc_params(&params),
            "Action=SendSmsVerifyCode&PhoneNumber=13800138000&TemplateParam=%7B%22code%22%3A%22%23%23code%23%23%22%7D"
        );
    }

    // The send response struct maps Aliyun's PascalCase JSON fields, including nested Model.
    #[test]
    fn aliyun_send_response_deserializes_pascal_case_fields() {
        let payload = serde_json::from_str::<AliyunSendSmsVerifyCodeResponse>(
            r#"{
"Code": "OK",
"Message": "成功",
"RequestId": "req_123",
"Success": true,
"Model": {
"BizId": "biz_456",
"OutId": "out_789",
"RequestId": "req_model_001"
}
}"#,
        )
        .expect("aliyun send response should deserialize");
        assert_eq!(payload.code.as_deref(), Some("OK"));
        assert_eq!(payload.message.as_deref(), Some("成功"));
        assert_eq!(payload.request_id.as_deref(), Some("req_123"));
        assert_eq!(payload.success, Some(true));
        assert_eq!(
            payload
                .model
                .as_ref()
                .and_then(|model| model.out_id.as_deref()),
            Some("out_789")
        );
        assert_eq!(
            payload
                .model
                .as_ref()
                .and_then(|model| model.request_id.as_deref()),
            Some("req_model_001")
        );
    }

    // The verify response struct maps PascalCase fields, including Model.VerifyResult.
    #[test]
    fn aliyun_verify_response_deserializes_pascal_case_fields() {
        let payload = serde_json::from_str::<AliyunCheckSmsVerifyCodeResponse>(
            r#"{
"Code": "OK",
"Message": "成功",
"Success": true,
"Model": {
"OutId": "out_789",
"VerifyResult": "PASS"
}
}"#,
        )
        .expect("aliyun verify response should deserialize");
        assert_eq!(payload.code.as_deref(), Some("OK"));
        assert_eq!(payload.message.as_deref(), Some("成功"));
        assert_eq!(payload.success, Some(true));
        assert_eq!(
            payload
                .model
                .as_ref()
                .and_then(|model| model.verify_result.as_deref()),
            Some("PASS")
        );
    }
}