feat: add platform auth jwt adapter

This commit is contained in:
2026-04-21 13:02:44 +08:00
parent e37163d4d3
commit adaf514a1a
20 changed files with 1220 additions and 44 deletions

View File

@@ -0,0 +1,377 @@
use std::{collections::HashSet, error::Error, fmt};
use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::ErrorKind,
};
use serde::{Deserialize, Serialize};
use time::{Duration, OffsetDateTime};
pub const ACCESS_TOKEN_ALGORITHM: Algorithm = Algorithm::HS256;
pub const DEFAULT_ACCESS_TOKEN_TTL_SECONDS: u64 = 2 * 60 * 60;
// 鉴权 provider 直接冻结成文档中约定的枚举,避免后续在多个 crate 内重复发明字符串字面量。
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthProvider {
Password,
Phone,
Wechat,
}
// 绑定状态只保留当前 JWT 需要透传的最小快照,不把完整账号状态枚举直接泄漏到 token 中。
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BindingStatus {
Active,
PendingBindPhone,
}
// 用于签发 access token 的领域输入,和最终 JWT claims 解耦,避免业务层手动拼 iat/exp/iss。
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AccessTokenClaimsInput {
pub user_id: String,
pub session_id: String,
pub provider: AuthProvider,
pub roles: Vec<String>,
pub token_version: u64,
pub phone_verified: bool,
pub binding_status: BindingStatus,
pub display_name: Option<String>,
}
// 直接映射最终 JWT payload字段名与文档冻结口径保持一致。
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AccessTokenClaims {
pub iss: String,
pub sub: String,
pub sid: String,
pub provider: AuthProvider,
pub roles: Vec<String>,
pub ver: u64,
pub phone_verified: bool,
pub binding_status: BindingStatus,
#[serde(skip_serializing_if = "Option::is_none")]
pub display_name: Option<String>,
pub iat: u64,
pub exp: u64,
}
// 统一承载 JWT 配置,避免 secret、issuer、ttl 在 api-server 与后续模块里散落。
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct JwtConfig {
issuer: String,
secret: String,
access_token_ttl_seconds: u64,
}
#[derive(Debug, PartialEq, Eq)]
pub enum JwtError {
InvalidConfig(&'static str),
InvalidClaims(&'static str),
SignFailed(String),
VerifyFailed(String),
}
impl JwtConfig {
pub fn new(
issuer: String,
secret: String,
access_token_ttl_seconds: u64,
) -> Result<Self, JwtError> {
let issuer = issuer.trim().to_string();
let secret = secret.trim().to_string();
if issuer.is_empty() {
return Err(JwtError::InvalidConfig("JWT issuer 不能为空"));
}
if secret.is_empty() {
return Err(JwtError::InvalidConfig("JWT secret 不能为空"));
}
if access_token_ttl_seconds == 0 {
return Err(JwtError::InvalidConfig(
"JWT access token 过期时间必须大于 0",
));
}
Ok(Self {
issuer,
secret,
access_token_ttl_seconds,
})
}
pub fn issuer(&self) -> &str {
&self.issuer
}
pub fn access_token_ttl_seconds(&self) -> u64 {
self.access_token_ttl_seconds
}
}
impl AccessTokenClaims {
pub fn from_input(
input: AccessTokenClaimsInput,
config: &JwtConfig,
issued_at: OffsetDateTime,
) -> Result<Self, JwtError> {
let user_id = normalize_required_field(input.user_id, "JWT sub 不能为空")?;
let session_id = normalize_required_field(input.session_id, "JWT sid 不能为空")?;
let roles = normalize_roles(input.roles)?;
let display_name = normalize_optional_field(input.display_name);
let issued_at_unix = issued_at.unix_timestamp();
if issued_at_unix < 0 {
return Err(JwtError::InvalidClaims("JWT iat 不能早于 Unix epoch"));
}
let expires_at = issued_at
.checked_add(Duration::seconds(
i64::try_from(config.access_token_ttl_seconds()).map_err(|_| {
JwtError::InvalidConfig("JWT access token 过期时间超出 i64 上限")
})?,
))
.ok_or(JwtError::InvalidConfig("JWT 过期时间计算溢出"))?;
let expires_at_unix = expires_at.unix_timestamp();
if expires_at_unix <= issued_at_unix {
return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"));
}
let claims = Self {
iss: config.issuer().to_string(),
sub: user_id,
sid: session_id,
provider: input.provider,
roles,
ver: input.token_version,
phone_verified: input.phone_verified,
binding_status: input.binding_status,
display_name,
iat: issued_at_unix as u64,
exp: expires_at_unix as u64,
};
claims.validate_for_config(config)?;
Ok(claims)
}
pub fn user_id(&self) -> &str {
&self.sub
}
pub fn session_id(&self) -> &str {
&self.sid
}
pub fn token_version(&self) -> u64 {
self.ver
}
pub fn validate_for_config(&self, config: &JwtConfig) -> Result<(), JwtError> {
if self.iss.trim() != config.issuer() {
return Err(JwtError::InvalidClaims("JWT iss 与当前配置不一致"));
}
normalize_required_field(self.sub.clone(), "JWT sub 不能为空")?;
normalize_required_field(self.sid.clone(), "JWT sid 不能为空")?;
normalize_roles(self.roles.clone())?;
if self.exp <= self.iat {
return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"));
}
Ok(())
}
}
pub fn sign_access_token(
claims: &AccessTokenClaims,
config: &JwtConfig,
) -> Result<String, JwtError> {
claims.validate_for_config(config)?;
let header = Header {
alg: ACCESS_TOKEN_ALGORITHM,
typ: Some("JWT".to_string()),
..Header::default()
};
encode(
&header,
claims,
&EncodingKey::from_secret(config.secret.as_bytes()),
)
.map_err(|error| JwtError::SignFailed(format!("JWT 签发失败:{error}")))
}
pub fn verify_access_token(token: &str, config: &JwtConfig) -> Result<AccessTokenClaims, JwtError> {
let token = token.trim();
if token.is_empty() {
return Err(JwtError::VerifyFailed("JWT 不能为空".to_string()));
}
let mut validation = Validation::new(ACCESS_TOKEN_ALGORITHM);
validation.required_spec_claims = HashSet::from([
"exp".to_string(),
"iat".to_string(),
"iss".to_string(),
"sub".to_string(),
]);
validation.set_issuer(&[config.issuer()]);
let decoded = decode::<AccessTokenClaims>(
token,
&DecodingKey::from_secret(config.secret.as_bytes()),
&validation,
)
.map_err(map_verify_error)?;
decoded.claims.validate_for_config(config)?;
Ok(decoded.claims)
}
impl fmt::Display for JwtError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidConfig(message) | Self::InvalidClaims(message) => f.write_str(message),
Self::SignFailed(message) | Self::VerifyFailed(message) => f.write_str(message),
}
}
}
impl Error for JwtError {}
fn normalize_required_field(
value: String,
error_message: &'static str,
) -> Result<String, JwtError> {
let value = value.trim().to_string();
if value.is_empty() {
return Err(JwtError::InvalidClaims(error_message));
}
Ok(value)
}
fn normalize_optional_field(value: Option<String>) -> Option<String> {
value.and_then(|field| {
let field = field.trim().to_string();
if field.is_empty() {
return None;
}
Some(field)
})
}
fn normalize_roles(roles: Vec<String>) -> Result<Vec<String>, JwtError> {
let roles = roles
.into_iter()
.map(|role| role.trim().to_string())
.filter(|role| !role.is_empty())
.collect::<Vec<_>>();
if roles.is_empty() {
return Err(JwtError::InvalidClaims("JWT roles 至少包含一个角色"));
}
Ok(roles)
}
fn map_verify_error(error: jsonwebtoken::errors::Error) -> JwtError {
let message = match error.kind() {
ErrorKind::ExpiredSignature => "JWT 已过期".to_string(),
ErrorKind::InvalidIssuer => "JWT 发行者不匹配".to_string(),
ErrorKind::InvalidSignature => "JWT 签名无效".to_string(),
ErrorKind::InvalidAlgorithm => "JWT 算法不匹配".to_string(),
ErrorKind::InvalidToken => "JWT 非法".to_string(),
ErrorKind::ImmatureSignature => "JWT 尚未生效".to_string(),
ErrorKind::MissingRequiredClaim(claim) => format!("JWT 缺少必填字段:{claim}"),
_ => format!("JWT 校验失败:{error}"),
};
JwtError::VerifyFailed(message)
}
#[cfg(test)]
mod tests {
use super::*;
fn build_jwt_config() -> JwtConfig {
JwtConfig::new(
"https://auth.genarrative.local".to_string(),
"genarrative-dev-secret".to_string(),
DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
)
.expect("jwt config should be valid")
}
fn build_claims_input() -> AccessTokenClaimsInput {
AccessTokenClaimsInput {
user_id: "usr_123".to_string(),
session_id: "sess_456".to_string(),
provider: AuthProvider::Wechat,
roles: vec!["user".to_string()],
token_version: 3,
phone_verified: false,
binding_status: BindingStatus::PendingBindPhone,
display_name: Some("微信旅人".to_string()),
}
}
#[test]
fn round_trip_sign_and_verify_access_token() {
let config = build_jwt_config();
let claims =
AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
.expect("claims should build");
let token = sign_access_token(&claims, &config).expect("token should sign");
let verified = verify_access_token(&token, &config).expect("token should verify");
assert_eq!(verified, claims);
assert_eq!(verified.user_id(), "usr_123");
assert_eq!(verified.session_id(), "sess_456");
assert_eq!(verified.token_version(), 3);
}
#[test]
fn verify_rejects_invalid_issuer() {
let config = build_jwt_config();
let claims =
AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
.expect("claims should build");
let token = sign_access_token(&claims, &config).expect("token should sign");
let other_config = JwtConfig::new(
"https://auth.other.local".to_string(),
"genarrative-dev-secret".to_string(),
DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
)
.expect("other config should be valid");
let error = verify_access_token(&token, &other_config).expect_err("issuer should mismatch");
assert_eq!(
error,
JwtError::VerifyFailed("JWT 发行者不匹配".to_string())
);
}
#[test]
fn build_claims_rejects_empty_roles() {
let error = AccessTokenClaims::from_input(
AccessTokenClaimsInput {
roles: Vec::new(),
..build_claims_input()
},
&build_jwt_config(),
OffsetDateTime::now_utc(),
)
.expect_err("empty roles should be rejected");
assert_eq!(error, JwtError::InvalidClaims("JWT roles 至少包含一个角色"));
}
}