use std::{collections::HashSet, error::Error, fmt}; use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString}; use jsonwebtoken::{ Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::ErrorKind, }; use rand_core::OsRng; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use shared_kernel::{new_uuid_simple_string, normalize_optional_string, normalize_required_string}; use time::{Duration, OffsetDateTime}; pub const ACCESS_TOKEN_ALGORITHM: Algorithm = Algorithm::HS256; pub const DEFAULT_ACCESS_TOKEN_TTL_SECONDS: u64 = 2 * 60 * 60; pub const DEFAULT_REFRESH_COOKIE_NAME: &str = "genarrative_refresh_session"; pub const DEFAULT_REFRESH_COOKIE_PATH: &str = "/api/auth"; pub const DEFAULT_REFRESH_SESSION_TTL_DAYS: u32 = 30; // 鉴权 provider 直接冻结成文档中约定的枚举,避免后续在多个 crate 内重复发明字符串字面量。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AuthProvider { Password, Phone, Wechat, } // 绑定状态只保留当前 JWT 需要透传的最小快照,不把完整账号状态枚举直接泄漏到 token 中。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum BindingStatus { Active, PendingBindPhone, } // 用于签发 access token 的领域输入,和最终 JWT claims 解耦,避免业务层手动拼 iat/exp/iss。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct AccessTokenClaimsInput { pub user_id: String, pub session_id: String, pub provider: AuthProvider, pub roles: Vec, pub token_version: u64, pub phone_verified: bool, pub binding_status: BindingStatus, pub display_name: Option, } // 直接映射最终 JWT payload,字段名与文档冻结口径保持一致。 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct AccessTokenClaims { pub iss: String, pub sub: String, pub sid: String, pub provider: AuthProvider, pub roles: Vec, pub ver: u64, pub phone_verified: bool, pub binding_status: BindingStatus, #[serde(skip_serializing_if = "Option::is_none")] pub display_name: Option, pub iat: u64, pub exp: u64, } // 统一承载 JWT 配置,避免 secret、issuer、ttl 在 api-server 与后续模块里散落。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct JwtConfig { issuer: String, secret: String, access_token_ttl_seconds: u64, } // refresh cookie 的 SameSite 固定约束成枚举,避免各层直接使用大小写不一致的字符串。 #[derive(Clone, Debug, PartialEq, Eq)] pub enum RefreshCookieSameSite { Lax, Strict, None, } // refresh cookie 的平台配置统一收口到 platform-auth,避免 api-server 直接散落 cookie 细节。 #[derive(Clone, Debug, PartialEq, Eq)] pub struct RefreshCookieConfig { cookie_name: String, cookie_path: String, cookie_secure: bool, cookie_same_site: RefreshCookieSameSite, refresh_session_ttl_days: u32, } #[derive(Debug, PartialEq, Eq)] pub enum JwtError { InvalidConfig(&'static str), InvalidClaims(&'static str), SignFailed(String), VerifyFailed(String), } #[derive(Debug, PartialEq, Eq)] pub enum RefreshCookieError { InvalidConfig(&'static str), } #[derive(Debug, PartialEq, Eq)] pub enum PasswordHashError { HashFailed(String), VerifyFailed(String), } impl JwtConfig { pub fn new( issuer: String, secret: String, access_token_ttl_seconds: u64, ) -> Result { let issuer = normalize_required_string(&issuer) .ok_or(JwtError::InvalidConfig("JWT issuer 不能为空"))?; let secret = normalize_required_string(&secret) .ok_or(JwtError::InvalidConfig("JWT secret 不能为空"))?; if access_token_ttl_seconds == 0 { return Err(JwtError::InvalidConfig( "JWT access token 过期时间必须大于 0", )); } Ok(Self { issuer, secret, access_token_ttl_seconds, }) } pub fn issuer(&self) -> &str { &self.issuer } pub fn access_token_ttl_seconds(&self) -> u64 { self.access_token_ttl_seconds } } impl RefreshCookieSameSite { pub fn parse(raw: &str) -> Option { match raw.trim().to_ascii_lowercase().as_str() { "lax" => Some(Self::Lax), "strict" => Some(Self::Strict), "none" => Some(Self::None), _ => None, } } pub fn as_str(&self) -> &'static str { match self { Self::Lax => "Lax", Self::Strict => "Strict", Self::None => "None", } } } impl RefreshCookieConfig { pub fn new( cookie_name: String, cookie_path: String, cookie_secure: bool, cookie_same_site: RefreshCookieSameSite, refresh_session_ttl_days: u32, ) -> Result { let cookie_name = normalize_required_string(&cookie_name).ok_or( RefreshCookieError::InvalidConfig("refresh cookie 名称不能为空"), )?; let cookie_path = normalize_required_string(&cookie_path).ok_or( RefreshCookieError::InvalidConfig("refresh cookie path 不能为空"), )?; if refresh_session_ttl_days == 0 { return Err(RefreshCookieError::InvalidConfig( "refresh session TTL 天数必须大于 0", )); } Ok(Self { cookie_name, cookie_path, cookie_secure, cookie_same_site, refresh_session_ttl_days, }) } pub fn cookie_name(&self) -> &str { &self.cookie_name } pub fn cookie_path(&self) -> &str { &self.cookie_path } pub fn cookie_secure(&self) -> bool { self.cookie_secure } pub fn cookie_same_site(&self) -> &RefreshCookieSameSite { &self.cookie_same_site } pub fn refresh_session_ttl_days(&self) -> u32 { self.refresh_session_ttl_days } } impl AccessTokenClaims { pub fn from_input( input: AccessTokenClaimsInput, config: &JwtConfig, issued_at: OffsetDateTime, ) -> Result { let user_id = normalize_required_field(input.user_id, "JWT sub 不能为空")?; let session_id = normalize_required_field(input.session_id, "JWT sid 不能为空")?; let roles = normalize_roles(input.roles)?; let display_name = normalize_optional_field(input.display_name); let issued_at_unix = issued_at.unix_timestamp(); if issued_at_unix < 0 { return Err(JwtError::InvalidClaims("JWT iat 不能早于 Unix epoch")); } let expires_at = issued_at .checked_add(Duration::seconds( i64::try_from(config.access_token_ttl_seconds()).map_err(|_| { JwtError::InvalidConfig("JWT access token 过期时间超出 i64 上限") })?, )) .ok_or(JwtError::InvalidConfig("JWT 过期时间计算溢出"))?; let expires_at_unix = expires_at.unix_timestamp(); if expires_at_unix <= issued_at_unix { return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat")); } let claims = Self { iss: config.issuer().to_string(), sub: user_id, sid: session_id, provider: input.provider, roles, ver: input.token_version, phone_verified: input.phone_verified, binding_status: input.binding_status, display_name, iat: issued_at_unix as u64, exp: expires_at_unix as u64, }; claims.validate_for_config(config)?; Ok(claims) } pub fn user_id(&self) -> &str { &self.sub } pub fn session_id(&self) -> &str { &self.sid } pub fn token_version(&self) -> u64 { self.ver } pub fn validate_for_config(&self, config: &JwtConfig) -> Result<(), JwtError> { if self.iss.trim() != config.issuer() { return Err(JwtError::InvalidClaims("JWT iss 与当前配置不一致")); } normalize_required_field(self.sub.clone(), "JWT sub 不能为空")?; normalize_required_field(self.sid.clone(), "JWT sid 不能为空")?; normalize_roles(self.roles.clone())?; if self.exp <= self.iat { return Err(JwtError::InvalidClaims("JWT exp 必须晚于 iat")); } Ok(()) } } pub fn sign_access_token( claims: &AccessTokenClaims, config: &JwtConfig, ) -> Result { claims.validate_for_config(config)?; let header = Header { alg: ACCESS_TOKEN_ALGORITHM, typ: Some("JWT".to_string()), ..Header::default() }; encode( &header, claims, &EncodingKey::from_secret(config.secret.as_bytes()), ) .map_err(|error| JwtError::SignFailed(format!("JWT 签发失败:{error}"))) } pub fn verify_access_token(token: &str, config: &JwtConfig) -> Result { let token = token.trim(); if token.is_empty() { return Err(JwtError::VerifyFailed("JWT 不能为空".to_string())); } let mut validation = Validation::new(ACCESS_TOKEN_ALGORITHM); validation.required_spec_claims = HashSet::from([ "exp".to_string(), "iat".to_string(), "iss".to_string(), "sub".to_string(), ]); validation.set_issuer(&[config.issuer()]); let decoded = decode::( token, &DecodingKey::from_secret(config.secret.as_bytes()), &validation, ) .map_err(map_verify_error)?; decoded.claims.validate_for_config(config)?; Ok(decoded.claims) } pub fn read_refresh_session_token( cookie_header: &str, config: &RefreshCookieConfig, ) -> Option { let cookie_header = cookie_header.trim(); if cookie_header.is_empty() { return None; } for entry in cookie_header.split(';') { let entry = entry.trim(); if entry.is_empty() { continue; } let (raw_name, raw_value) = entry.split_once('=')?; if raw_name.trim() != config.cookie_name() { continue; } let raw_value = raw_value.trim(); if raw_value.is_empty() { return None; } return urlencoding::decode(raw_value) .ok() .map(|decoded| decoded.into_owned()); } None } pub async fn hash_password(password: &str) -> Result { let salt = SaltString::generate(&mut OsRng); Argon2::default() .hash_password(password.as_bytes(), &salt) .map(|hash| hash.to_string()) .map_err(|error| PasswordHashError::HashFailed(format!("密码哈希失败:{error}"))) } pub async fn verify_password( password_hash: &str, password: &str, ) -> Result { let parsed_hash = PasswordHash::new(password_hash) .map_err(|error| PasswordHashError::VerifyFailed(format!("密码哈希格式非法:{error}")))?; Ok(Argon2::default() .verify_password(password.as_bytes(), &parsed_hash) .is_ok()) } pub fn create_refresh_session_token() -> String { new_uuid_simple_string() } pub fn hash_refresh_session_token(token: &str) -> String { let mut hasher = Sha256::new(); hasher.update(token.as_bytes()); format!("{:x}", hasher.finalize()) } pub fn build_refresh_session_set_cookie(token: &str, config: &RefreshCookieConfig) -> String { let mut parts = vec![ format!( "{}={}", config.cookie_name(), urlencoding::encode(token).into_owned() ), format!("Path={}", config.cookie_path()), "HttpOnly".to_string(), format!("SameSite={}", config.cookie_same_site().as_str()), format!( "Max-Age={}", u64::from(config.refresh_session_ttl_days()) * 24 * 60 * 60 ), ]; if config.cookie_secure() { parts.push("Secure".to_string()); } parts.join("; ") } pub fn build_refresh_session_clear_cookie(config: &RefreshCookieConfig) -> String { let mut parts = vec![ format!("{}=", config.cookie_name()), format!("Path={}", config.cookie_path()), "HttpOnly".to_string(), format!("SameSite={}", config.cookie_same_site().as_str()), "Max-Age=0".to_string(), ]; if config.cookie_secure() { parts.push("Secure".to_string()); } parts.join("; ") } impl fmt::Display for JwtError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidConfig(message) | Self::InvalidClaims(message) => f.write_str(message), Self::SignFailed(message) | Self::VerifyFailed(message) => f.write_str(message), } } } impl Error for JwtError {} impl fmt::Display for RefreshCookieError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidConfig(message) => f.write_str(message), } } } impl Error for RefreshCookieError {} impl fmt::Display for PasswordHashError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::HashFailed(message) | Self::VerifyFailed(message) => f.write_str(message), } } } impl Error for PasswordHashError {} fn normalize_required_field( value: String, error_message: &'static str, ) -> Result { normalize_required_string(&value).ok_or(JwtError::InvalidClaims(error_message)) } fn normalize_optional_field(value: Option) -> Option { normalize_optional_string(value) } fn normalize_roles(roles: Vec) -> Result, JwtError> { let roles = roles .into_iter() .map(|role| role.trim().to_string()) .filter(|role| !role.is_empty()) .collect::>(); if roles.is_empty() { return Err(JwtError::InvalidClaims("JWT roles 至少包含一个角色")); } Ok(roles) } fn map_verify_error(error: jsonwebtoken::errors::Error) -> JwtError { let message = match error.kind() { ErrorKind::ExpiredSignature => "JWT 已过期".to_string(), ErrorKind::InvalidIssuer => "JWT 发行者不匹配".to_string(), ErrorKind::InvalidSignature => "JWT 签名无效".to_string(), ErrorKind::InvalidAlgorithm => "JWT 算法不匹配".to_string(), ErrorKind::InvalidToken => "JWT 非法".to_string(), ErrorKind::ImmatureSignature => "JWT 尚未生效".to_string(), ErrorKind::MissingRequiredClaim(claim) => format!("JWT 缺少必填字段:{claim}"), _ => format!("JWT 校验失败:{error}"), }; JwtError::VerifyFailed(message) } #[cfg(test)] mod tests { use super::*; fn build_jwt_config() -> JwtConfig { JwtConfig::new( "https://auth.genarrative.local".to_string(), "genarrative-dev-secret".to_string(), DEFAULT_ACCESS_TOKEN_TTL_SECONDS, ) .expect("jwt config should be valid") } fn build_claims_input() -> AccessTokenClaimsInput { AccessTokenClaimsInput { user_id: "usr_123".to_string(), session_id: "sess_456".to_string(), provider: AuthProvider::Wechat, roles: vec!["user".to_string()], token_version: 3, phone_verified: false, binding_status: BindingStatus::PendingBindPhone, display_name: Some("微信旅人".to_string()), } } fn build_refresh_cookie_config() -> RefreshCookieConfig { RefreshCookieConfig::new( DEFAULT_REFRESH_COOKIE_NAME.to_string(), DEFAULT_REFRESH_COOKIE_PATH.to_string(), false, RefreshCookieSameSite::Lax, DEFAULT_REFRESH_SESSION_TTL_DAYS, ) .expect("refresh cookie config should be valid") } #[test] fn round_trip_sign_and_verify_access_token() { let config = build_jwt_config(); let claims = AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc()) .expect("claims should build"); let token = sign_access_token(&claims, &config).expect("token should sign"); let verified = verify_access_token(&token, &config).expect("token should verify"); assert_eq!(verified, claims); assert_eq!(verified.user_id(), "usr_123"); assert_eq!(verified.session_id(), "sess_456"); assert_eq!(verified.token_version(), 3); } #[test] fn verify_rejects_invalid_issuer() { let config = build_jwt_config(); let claims = AccessTokenClaims::from_input(build_claims_input(), &config, OffsetDateTime::now_utc()) .expect("claims should build"); let token = sign_access_token(&claims, &config).expect("token should sign"); let other_config = JwtConfig::new( "https://auth.other.local".to_string(), "genarrative-dev-secret".to_string(), DEFAULT_ACCESS_TOKEN_TTL_SECONDS, ) .expect("other config should be valid"); let error = verify_access_token(&token, &other_config).expect_err("issuer should mismatch"); assert_eq!( error, JwtError::VerifyFailed("JWT 发行者不匹配".to_string()) ); } #[test] fn build_claims_rejects_empty_roles() { let error = AccessTokenClaims::from_input( AccessTokenClaimsInput { roles: Vec::new(), ..build_claims_input() }, &build_jwt_config(), OffsetDateTime::now_utc(), ) .expect_err("empty roles should be rejected"); assert_eq!(error, JwtError::InvalidClaims("JWT roles 至少包含一个角色")); } #[test] fn read_refresh_session_token_returns_matching_cookie() { let token = read_refresh_session_token( "theme=dark; genarrative_refresh_session=refresh-token-01; locale=zh-CN", &build_refresh_cookie_config(), ); assert_eq!(token.as_deref(), Some("refresh-token-01")); } #[test] fn read_refresh_session_token_decodes_urlencoded_value() { let token = read_refresh_session_token( "genarrative_refresh_session=refresh%2Ftoken%3D01", &build_refresh_cookie_config(), ); assert_eq!(token.as_deref(), Some("refresh/token=01")); } #[test] fn read_refresh_session_token_returns_none_when_missing() { let token = read_refresh_session_token("theme=dark; locale=zh-CN", &build_refresh_cookie_config()); assert!(token.is_none()); } #[tokio::test] async fn hash_and_verify_password_round_trip() { let password_hash = hash_password("secret123") .await .expect("password hash should build"); let is_valid = verify_password(&password_hash, "secret123") .await .expect("password hash should verify"); assert!(is_valid); } #[test] fn build_refresh_session_cookie_respects_config() { let cookie = build_refresh_session_set_cookie("refresh/token=01", &build_refresh_cookie_config()); assert!(cookie.contains("genarrative_refresh_session=refresh%2Ftoken%3D01")); assert!(cookie.contains("Path=/api/auth")); assert!(cookie.contains("HttpOnly")); assert!(cookie.contains("SameSite=Lax")); assert!(cookie.contains("Max-Age=2592000")); } #[test] fn hash_refresh_session_token_matches_sha256_hex() { let hash = hash_refresh_session_token("refresh-token-01"); assert_eq!( hash, "9fab76f9100ec6c151c8caa0c42ab10e10fbc7618f15e24cf3dffc93e19c4c4e" ); } #[test] fn build_refresh_session_clear_cookie_respects_config() { let cookie = build_refresh_session_clear_cookie(&build_refresh_cookie_config()); assert!(cookie.contains("genarrative_refresh_session=")); assert!(cookie.contains("Path=/api/auth")); assert!(cookie.contains("HttpOnly")); assert!(cookie.contains("SameSite=Lax")); assert!(cookie.contains("Max-Age=0")); } }