feat: add platform auth jwt adapter
This commit is contained in:
377
server-rs/crates/platform-auth/src/lib.rs
Normal file
377
server-rs/crates/platform-auth/src/lib.rs
Normal file
@@ -0,0 +1,377 @@
|
||||
use std::{collections::HashSet, error::Error, fmt};
|
||||
|
||||
use jsonwebtoken::{
|
||||
Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::ErrorKind,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use time::{Duration, OffsetDateTime};
|
||||
|
||||
pub const ACCESS_TOKEN_ALGORITHM: Algorithm = Algorithm::HS256;
|
||||
pub const DEFAULT_ACCESS_TOKEN_TTL_SECONDS: u64 = 2 * 60 * 60;
|
||||
|
||||
// 鉴权 provider 直接冻结成文档中约定的枚举,避免后续在多个 crate 内重复发明字符串字面量。
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AuthProvider {
|
||||
Password,
|
||||
Phone,
|
||||
Wechat,
|
||||
}
|
||||
|
||||
// 绑定状态只保留当前 JWT 需要透传的最小快照,不把完整账号状态枚举直接泄漏到 token 中。
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum BindingStatus {
|
||||
Active,
|
||||
PendingBindPhone,
|
||||
}
|
||||
|
||||
// 用于签发 access token 的领域输入,和最终 JWT claims 解耦,避免业务层手动拼 iat/exp/iss。
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct AccessTokenClaimsInput {
|
||||
pub user_id: String,
|
||||
pub session_id: String,
|
||||
pub provider: AuthProvider,
|
||||
pub roles: Vec<String>,
|
||||
pub token_version: u64,
|
||||
pub phone_verified: bool,
|
||||
pub binding_status: BindingStatus,
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
// 直接映射最终 JWT payload,字段名与文档冻结口径保持一致。
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct AccessTokenClaims {
|
||||
pub iss: String,
|
||||
pub sub: String,
|
||||
pub sid: String,
|
||||
pub provider: AuthProvider,
|
||||
pub roles: Vec<String>,
|
||||
pub ver: u64,
|
||||
pub phone_verified: bool,
|
||||
pub binding_status: BindingStatus,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
pub iat: u64,
|
||||
pub exp: u64,
|
||||
}
|
||||
|
||||
// 统一承载 JWT 配置,避免 secret、issuer、ttl 在 api-server 与后续模块里散落。
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct JwtConfig {
|
||||
issuer: String,
|
||||
secret: String,
|
||||
access_token_ttl_seconds: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum JwtError {
|
||||
InvalidConfig(&'static str),
|
||||
InvalidClaims(&'static str),
|
||||
SignFailed(String),
|
||||
VerifyFailed(String),
|
||||
}
|
||||
|
||||
impl JwtConfig {
|
||||
pub fn new(
|
||||
issuer: String,
|
||||
secret: String,
|
||||
access_token_ttl_seconds: u64,
|
||||
) -> Result<Self, JwtError> {
|
||||
let issuer = issuer.trim().to_string();
|
||||
let secret = secret.trim().to_string();
|
||||
|
||||
if issuer.is_empty() {
|
||||
return Err(JwtError::InvalidConfig("JWT issuer 不能为空"));
|
||||
}
|
||||
|
||||
if secret.is_empty() {
|
||||
return Err(JwtError::InvalidConfig("JWT secret 不能为空"));
|
||||
}
|
||||
|
||||
if access_token_ttl_seconds == 0 {
|
||||
return Err(JwtError::InvalidConfig(
|
||||
"JWT access token 过期时间必须大于 0",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
issuer,
|
||||
secret,
|
||||
access_token_ttl_seconds,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn issuer(&self) -> &str {
|
||||
&self.issuer
|
||||
}
|
||||
|
||||
pub fn access_token_ttl_seconds(&self) -> u64 {
|
||||
self.access_token_ttl_seconds
|
||||
}
|
||||
}
|
||||
|
||||
impl AccessTokenClaims {
|
||||
pub fn from_input(
|
||||
input: AccessTokenClaimsInput,
|
||||
config: &JwtConfig,
|
||||
issued_at: OffsetDateTime,
|
||||
) -> Result<Self, JwtError> {
|
||||
let user_id = normalize_required_field(input.user_id, "JWT sub 不能为空")?;
|
||||
let session_id = normalize_required_field(input.session_id, "JWT sid 不能为空")?;
|
||||
let roles = normalize_roles(input.roles)?;
|
||||
let display_name = normalize_optional_field(input.display_name);
|
||||
|
||||
let issued_at_unix = issued_at.unix_timestamp();
|
||||
if issued_at_unix < 0 {
|
||||
return Err(JwtError::InvalidClaims("JWT iat 不能早于 Unix epoch"));
|
||||
}
|
||||
|
||||
let expires_at = issued_at
|
||||
.checked_add(Duration::seconds(
|
||||
i64::try_from(config.access_token_ttl_seconds()).map_err(|_| {
|
||||
JwtError::InvalidConfig("JWT access token 过期时间超出 i64 上限")
|
||||
})?,
|
||||
))
|
||||
.ok_or(JwtError::InvalidConfig("JWT 过期时间计算溢出"))?;
|
||||
|
||||
let expires_at_unix = expires_at.unix_timestamp();
|
||||
if expires_at_unix <= issued_at_unix {
|
||||
return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"));
|
||||
}
|
||||
|
||||
let claims = Self {
|
||||
iss: config.issuer().to_string(),
|
||||
sub: user_id,
|
||||
sid: session_id,
|
||||
provider: input.provider,
|
||||
roles,
|
||||
ver: input.token_version,
|
||||
phone_verified: input.phone_verified,
|
||||
binding_status: input.binding_status,
|
||||
display_name,
|
||||
iat: issued_at_unix as u64,
|
||||
exp: expires_at_unix as u64,
|
||||
};
|
||||
|
||||
claims.validate_for_config(config)?;
|
||||
Ok(claims)
|
||||
}
|
||||
|
||||
pub fn user_id(&self) -> &str {
|
||||
&self.sub
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> &str {
|
||||
&self.sid
|
||||
}
|
||||
|
||||
pub fn token_version(&self) -> u64 {
|
||||
self.ver
|
||||
}
|
||||
|
||||
pub fn validate_for_config(&self, config: &JwtConfig) -> Result<(), JwtError> {
|
||||
if self.iss.trim() != config.issuer() {
|
||||
return Err(JwtError::InvalidClaims("JWT iss 与当前配置不一致"));
|
||||
}
|
||||
|
||||
normalize_required_field(self.sub.clone(), "JWT sub 不能为空")?;
|
||||
normalize_required_field(self.sid.clone(), "JWT sid 不能为空")?;
|
||||
normalize_roles(self.roles.clone())?;
|
||||
|
||||
if self.exp <= self.iat {
|
||||
return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sign_access_token(
|
||||
claims: &AccessTokenClaims,
|
||||
config: &JwtConfig,
|
||||
) -> Result<String, JwtError> {
|
||||
claims.validate_for_config(config)?;
|
||||
|
||||
let header = Header {
|
||||
alg: ACCESS_TOKEN_ALGORITHM,
|
||||
typ: Some("JWT".to_string()),
|
||||
..Header::default()
|
||||
};
|
||||
|
||||
encode(
|
||||
&header,
|
||||
claims,
|
||||
&EncodingKey::from_secret(config.secret.as_bytes()),
|
||||
)
|
||||
.map_err(|error| JwtError::SignFailed(format!("JWT 签发失败:{error}")))
|
||||
}
|
||||
|
||||
pub fn verify_access_token(token: &str, config: &JwtConfig) -> Result<AccessTokenClaims, JwtError> {
|
||||
let token = token.trim();
|
||||
if token.is_empty() {
|
||||
return Err(JwtError::VerifyFailed("JWT 不能为空".to_string()));
|
||||
}
|
||||
|
||||
let mut validation = Validation::new(ACCESS_TOKEN_ALGORITHM);
|
||||
validation.required_spec_claims = HashSet::from([
|
||||
"exp".to_string(),
|
||||
"iat".to_string(),
|
||||
"iss".to_string(),
|
||||
"sub".to_string(),
|
||||
]);
|
||||
validation.set_issuer(&[config.issuer()]);
|
||||
|
||||
let decoded = decode::<AccessTokenClaims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(config.secret.as_bytes()),
|
||||
&validation,
|
||||
)
|
||||
.map_err(map_verify_error)?;
|
||||
|
||||
decoded.claims.validate_for_config(config)?;
|
||||
Ok(decoded.claims)
|
||||
}
|
||||
|
||||
impl fmt::Display for JwtError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::InvalidConfig(message) | Self::InvalidClaims(message) => f.write_str(message),
|
||||
Self::SignFailed(message) | Self::VerifyFailed(message) => f.write_str(message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for JwtError {}
|
||||
|
||||
fn normalize_required_field(
|
||||
value: String,
|
||||
error_message: &'static str,
|
||||
) -> Result<String, JwtError> {
|
||||
let value = value.trim().to_string();
|
||||
if value.is_empty() {
|
||||
return Err(JwtError::InvalidClaims(error_message));
|
||||
}
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn normalize_optional_field(value: Option<String>) -> Option<String> {
|
||||
value.and_then(|field| {
|
||||
let field = field.trim().to_string();
|
||||
if field.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(field)
|
||||
})
|
||||
}
|
||||
|
||||
fn normalize_roles(roles: Vec<String>) -> Result<Vec<String>, JwtError> {
|
||||
let roles = roles
|
||||
.into_iter()
|
||||
.map(|role| role.trim().to_string())
|
||||
.filter(|role| !role.is_empty())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if roles.is_empty() {
|
||||
return Err(JwtError::InvalidClaims("JWT roles 至少包含一个角色"));
|
||||
}
|
||||
|
||||
Ok(roles)
|
||||
}
|
||||
|
||||
fn map_verify_error(error: jsonwebtoken::errors::Error) -> JwtError {
|
||||
let message = match error.kind() {
|
||||
ErrorKind::ExpiredSignature => "JWT 已过期".to_string(),
|
||||
ErrorKind::InvalidIssuer => "JWT 发行者不匹配".to_string(),
|
||||
ErrorKind::InvalidSignature => "JWT 签名无效".to_string(),
|
||||
ErrorKind::InvalidAlgorithm => "JWT 算法不匹配".to_string(),
|
||||
ErrorKind::InvalidToken => "JWT 非法".to_string(),
|
||||
ErrorKind::ImmatureSignature => "JWT 尚未生效".to_string(),
|
||||
ErrorKind::MissingRequiredClaim(claim) => format!("JWT 缺少必填字段:{claim}"),
|
||||
_ => format!("JWT 校验失败:{error}"),
|
||||
};
|
||||
|
||||
JwtError::VerifyFailed(message)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn build_jwt_config() -> JwtConfig {
|
||||
JwtConfig::new(
|
||||
"https://auth.genarrative.local".to_string(),
|
||||
"genarrative-dev-secret".to_string(),
|
||||
DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
|
||||
)
|
||||
.expect("jwt config should be valid")
|
||||
}
|
||||
|
||||
fn build_claims_input() -> AccessTokenClaimsInput {
|
||||
AccessTokenClaimsInput {
|
||||
user_id: "usr_123".to_string(),
|
||||
session_id: "sess_456".to_string(),
|
||||
provider: AuthProvider::Wechat,
|
||||
roles: vec!["user".to_string()],
|
||||
token_version: 3,
|
||||
phone_verified: false,
|
||||
binding_status: BindingStatus::PendingBindPhone,
|
||||
display_name: Some("微信旅人".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_trip_sign_and_verify_access_token() {
|
||||
let config = build_jwt_config();
|
||||
let claims =
|
||||
AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
|
||||
.expect("claims should build");
|
||||
|
||||
let token = sign_access_token(&claims, &config).expect("token should sign");
|
||||
let verified = verify_access_token(&token, &config).expect("token should verify");
|
||||
|
||||
assert_eq!(verified, claims);
|
||||
assert_eq!(verified.user_id(), "usr_123");
|
||||
assert_eq!(verified.session_id(), "sess_456");
|
||||
assert_eq!(verified.token_version(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_rejects_invalid_issuer() {
|
||||
let config = build_jwt_config();
|
||||
let claims =
|
||||
AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc())
|
||||
.expect("claims should build");
|
||||
let token = sign_access_token(&claims, &config).expect("token should sign");
|
||||
let other_config = JwtConfig::new(
|
||||
"https://auth.other.local".to_string(),
|
||||
"genarrative-dev-secret".to_string(),
|
||||
DEFAULT_ACCESS_TOKEN_TTL_SECONDS,
|
||||
)
|
||||
.expect("other config should be valid");
|
||||
|
||||
let error = verify_access_token(&token, &other_config).expect_err("issuer should mismatch");
|
||||
|
||||
assert_eq!(
|
||||
error,
|
||||
JwtError::VerifyFailed("JWT 发行者不匹配".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_claims_rejects_empty_roles() {
|
||||
let error = AccessTokenClaims::from_input(
|
||||
AccessTokenClaimsInput {
|
||||
roles: Vec::new(),
|
||||
..build_claims_input()
|
||||
},
|
||||
&build_jwt_config(),
|
||||
OffsetDateTime::now_utc(),
|
||||
)
|
||||
.expect_err("empty roles should be rejected");
|
||||
|
||||
assert_eq!(error, JwtError::InvalidClaims("JWT roles 至少包含一个角色"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user