Files
Genarrative/server-rs/crates/api-server/src/auth.rs

213 lines
6.0 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
Json,
extract::{Extension, Request, State},
http::{
HeaderMap, StatusCode,
header::{AUTHORIZATION, COOKIE},
},
middleware::Next,
response::Response,
};
use platform_auth::{AccessTokenClaims, read_refresh_session_token, verify_access_token};
use serde_json::{Value, json};
use tracing::warn;
use crate::{
api_response::json_success_body, http_error::AppError, request_context::RequestContext,
state::AppState,
};
/// Verified access-token claims for the current request.
///
/// `require_bearer_auth` inserts this into the request extensions after the
/// Bearer token has been verified, so later handlers can read the claims
/// without parsing the Bearer token again.
#[derive(Clone, Debug)]
pub struct AuthenticatedAccessToken {
/// Claims produced by `verify_access_token` in `require_bearer_auth`.
claims: AccessTokenClaims,
}
/// Raw refresh-session token read from the request's cookie header by
/// `attach_refresh_session_token` and stored in the request extensions.
///
/// `Debug` is implemented by hand instead of derived: the token is a session
/// credential, and a derived `Debug` would print it verbatim into any log line
/// or panic message that formats this value (or the request extensions).
#[derive(Clone)]
pub struct RefreshSessionToken {
    token: String,
}

impl std::fmt::Debug for RefreshSessionToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never render the credential itself.
        f.debug_struct("RefreshSessionToken")
            .field("token", &"<redacted>")
            .finish()
    }
}
impl AuthenticatedAccessToken {
pub fn new(claims: AccessTokenClaims) -> Self {
Self { claims }
}
pub fn claims(&self) -> &AccessTokenClaims {
&self.claims
}
}
impl RefreshSessionToken {
pub fn new(token: String) -> Self {
Self { token }
}
pub fn token(&self) -> &str {
&self.token
}
}
/// Axum middleware that authenticates the request with a Bearer JWT.
///
/// Steps, in order:
/// 1. Extract the token from the `Authorization` header (401 if absent or malformed).
/// 2. Verify it with `verify_access_token` against the app's JWT config (401 on failure).
/// 3. Load the user named by the claims: 500 if the lookup itself fails,
///    401 if no such user exists.
/// 4. Reject with 401 and a re-login message if the user's current
///    `token_version` differs from the version embedded in the token.
///
/// On success the claims are inserted into the request extensions as
/// [`AuthenticatedAccessToken`] and the request is forwarded to `next`.
///
/// # Errors
/// Returns an [`AppError`] with status 401 or 500 as described above.
pub async fn require_bearer_auth(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, AppError> {
    let token = extract_bearer_token(request.headers())?;

    // Used only to correlate log lines; falls back to a placeholder when no
    // RequestContext has been attached to the request.
    let request_id = match request.extensions().get::<RequestContext>() {
        Some(context) => context.request_id().to_string(),
        None => "unknown".to_string(),
    };

    let claims = match verify_access_token(&token, state.auth_jwt_config()) {
        Ok(verified) => verified,
        Err(error) => {
            warn!(
                %request_id,
                error = %error,
                "Bearer JWT 校验失败"
            );
            return Err(AppError::from_status(StatusCode::UNAUTHORIZED));
        }
    };

    // A failed lookup is reported as a server error (500); a lookup that
    // succeeds but finds no user is treated as unauthenticated (401).
    let current_user = match state.auth_user_service().get_user_by_id(claims.user_id()) {
        Ok(Some(user)) => user,
        Ok(None) => {
            warn!(
                %request_id,
                user_id = %claims.user_id(),
                "Bearer JWT 对应用户不存在"
            );
            return Err(AppError::from_status(StatusCode::UNAUTHORIZED));
        }
        Err(error) => {
            warn!(
                %request_id,
                error = %error,
                "Bearer JWT 用户快照读取失败"
            );
            return Err(AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR));
        }
    };

    // A token whose embedded version no longer equals the user's current
    // token_version is rejected even though its signature verified.
    if current_user.token_version != claims.token_version() {
        warn!(
            %request_id,
            user_id = %claims.user_id(),
            token_version = claims.token_version(),
            current_token_version = current_user.token_version,
            "Bearer JWT 版本已失效"
        );
        let stale = AppError::from_status(StatusCode::UNAUTHORIZED)
            .with_message("当前登录态已失效,请重新登录");
        return Err(stale);
    }

    request
        .extensions_mut()
        .insert(AuthenticatedAccessToken::new(claims));
    let response = next.run(request).await;
    Ok(response)
}
/// Handler that returns the claims attached by `require_bearer_auth`,
/// wrapped by `json_success_body` under the `"claims"` key.
pub async fn inspect_auth_claims(
    Extension(request_context): Extension<RequestContext>,
    Extension(authenticated): Extension<AuthenticatedAccessToken>,
) -> Json<Value> {
    let payload = json!({ "claims": authenticated.claims() });
    json_success_body(Some(&request_context), payload)
}
/// Middleware that looks for the refresh-session cookie and, when it is found,
/// inserts it into the request extensions as [`RefreshSessionToken`].
///
/// This middleware never rejects a request. A missing or unreadable cookie
/// simply means no extension is inserted, and the request is always forwarded
/// to `next`.
///
/// Every `Cookie` header field is inspected, not just the first. HTTP/2 clients
/// may split cookies across several `cookie` fields (RFC 9113 §8.2.3), so the
/// refresh cookie can appear in any of them. The first field that yields a
/// token wins.
pub async fn attach_refresh_session_token(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Response {
    let refresh_token = request
        .headers()
        .get_all(COOKIE)
        .iter()
        // Values that cannot be read as a string are skipped.
        .filter_map(|value| value.to_str().ok())
        .find_map(|cookie_header| {
            read_refresh_session_token(cookie_header, state.refresh_cookie_config())
        });

    if let Some(token) = refresh_token {
        request
            .extensions_mut()
            .insert(RefreshSessionToken::new(token));
    }
    next.run(request).await
}
/// Handler that reports the configured refresh-cookie name, whether a
/// [`RefreshSessionToken`] is present in the request extensions, and its
/// length. The token value itself is never included in the response.
pub async fn inspect_refresh_session_cookie(
    State(state): State<AppState>,
    Extension(request_context): Extension<RequestContext>,
    request: Request,
) -> Json<Value> {
    // `Some(len)` exactly when the token extension is present.
    let token_length = request
        .extensions()
        .get::<RefreshSessionToken>()
        .map(|session| session.token().len());
    let cookie_name = state.refresh_cookie_config().cookie_name();

    json_success_body(
        Some(&request_context),
        json!({
            "cookieName": cookie_name,
            "present": token_length.is_some(),
            "tokenLength": token_length,
        }),
    )
}
/// Extracts the Bearer token from the `Authorization` header.
///
/// The trimmed header value must have the form `<scheme> <token>`. The scheme
/// is matched against `Bearer` ASCII-case-insensitively, because HTTP
/// authentication schemes are case-insensitive (RFC 9110 §11.1). Whitespace
/// around the header value and around the token is ignored.
///
/// # Errors
/// Returns a 401 [`AppError`] when any of the following holds:
/// - the header is missing;
/// - the header cannot be read as a string;
/// - the header has no space-separated credentials;
/// - the scheme is not `Bearer`;
/// - the token is empty.
fn extract_bearer_token(headers: &HeaderMap) -> Result<String, AppError> {
    let unauthorized = || AppError::from_status(StatusCode::UNAUTHORIZED);

    let authorization = headers
        .get(AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .ok_or_else(unauthorized)?;

    // Everything before the first space is the scheme; the rest are credentials.
    let (scheme, credentials) = authorization.split_once(' ').ok_or_else(unauthorized)?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return Err(unauthorized());
    }

    let token = credentials.trim();
    if token.is_empty() {
        return Err(unauthorized());
    }
    Ok(token.to_string())
}
#[cfg(test)]
mod tests {
    use super::{RefreshSessionToken, extract_bearer_token};
    use axum::{
        http::{HeaderMap, HeaderValue, StatusCode, header::AUTHORIZATION},
        response::IntoResponse,
    };

    /// Builds a header map carrying a single `Authorization` value.
    fn headers_with_authorization(value: HeaderValue) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, value);
        headers
    }

    /// Asserts that bearer extraction fails with 401 Unauthorized.
    fn assert_unauthorized(headers: &HeaderMap) {
        let error = extract_bearer_token(headers).expect_err("extraction should be rejected");
        assert_eq!(error.into_response().status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn extract_bearer_token_accepts_standard_header() {
        let headers = headers_with_authorization(HeaderValue::from_static("Bearer token-value"));
        let token = extract_bearer_token(&headers).expect("bearer token should be extracted");
        assert_eq!(token, "token-value");
    }

    #[test]
    fn extract_bearer_token_accepts_lowercase_scheme() {
        let headers = headers_with_authorization(HeaderValue::from_static("bearer token-value"));
        let token = extract_bearer_token(&headers).expect("lowercase scheme should be accepted");
        assert_eq!(token, "token-value");
    }

    #[test]
    fn extract_bearer_token_trims_surrounding_whitespace() {
        let headers =
            headers_with_authorization(HeaderValue::from_static("  Bearer   token-value  "));
        let token = extract_bearer_token(&headers).expect("padded header should be accepted");
        assert_eq!(token, "token-value");
    }

    #[test]
    fn extract_bearer_token_rejects_non_bearer_scheme() {
        let headers = headers_with_authorization(HeaderValue::from_static("Basic abc"));
        assert_unauthorized(&headers);
    }

    #[test]
    fn extract_bearer_token_rejects_missing_header() {
        assert_unauthorized(&HeaderMap::new());
    }

    #[test]
    fn extract_bearer_token_rejects_empty_token() {
        let headers = headers_with_authorization(HeaderValue::from_static("Bearer   "));
        assert_unauthorized(&headers);
    }

    #[test]
    fn extract_bearer_token_rejects_non_ascii_header() {
        let value = HeaderValue::from_bytes(b"Bearer \xfftoken")
            .expect("obs-text bytes are a valid header value");
        assert_unauthorized(&headers_with_authorization(value));
    }

    #[test]
    fn refresh_session_token_retains_original_value() {
        let token = RefreshSessionToken::new("refresh-token-01".to_string());
        assert_eq!(token.token(), "refresh-token-01");
    }
}