Files
Genarrative/server-rs/crates/api-server/src/auth.rs

183 lines
4.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
Json,
extract::{Extension, Request, State},
http::{
HeaderMap, StatusCode,
header::{AUTHORIZATION, COOKIE},
},
middleware::Next,
response::Response,
};
use platform_auth::{AccessTokenClaims, read_refresh_session_token, verify_access_token};
use serde_json::{Value, json};
use tracing::warn;
use crate::{
api_response::json_success_body, http_error::AppError, request_context::RequestContext,
state::AppState,
};
/// Verified access-token claims carried in the request extensions.
///
/// `require_bearer_auth` stores one of these on the request once the Bearer JWT
/// has been verified, so later handlers can read the claims without parsing the
/// `Authorization` header a second time.
#[derive(Debug, Clone)]
pub struct AuthenticatedAccessToken {
    claims: AccessTokenClaims,
}
/// Refresh-session token value carried in the request extensions.
///
/// `attach_refresh_session_token` inserts one of these when the request cookies
/// contain a refresh-session token.
#[derive(Clone)]
pub struct RefreshSessionToken {
    token: String,
}

/// Hand-written so the token value never appears in `{:?}` output: the token is
/// a credential, and a derived `Debug` would print it verbatim wherever the
/// value (or a structure containing it) is logged.
impl std::fmt::Debug for RefreshSessionToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshSessionToken")
            .field("token", &"<redacted>")
            .finish()
    }
}
impl AuthenticatedAccessToken {
pub fn new(claims: AccessTokenClaims) -> Self {
Self { claims }
}
pub fn claims(&self) -> &AccessTokenClaims {
&self.claims
}
}
impl RefreshSessionToken {
pub fn new(token: String) -> Self {
Self { token }
}
pub fn token(&self) -> &str {
&self.token
}
}
/// Middleware that only lets requests with a verifiable Bearer access token through.
///
/// The token is taken from the `Authorization` header and checked with
/// `verify_access_token` against the JWT configuration held in `state`. On
/// success the claims are stored in the request extensions as an
/// `AuthenticatedAccessToken` and the request is passed to `next`.
///
/// # Errors
///
/// Returns an `AppError` built from `401 Unauthorized` when the header cannot be
/// turned into a bearer token or when verification fails. A verification failure
/// is also logged at `warn` level together with the request id.
pub async fn require_bearer_auth(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, AppError> {
    let bearer_token = extract_bearer_token(request.headers())?;

    // Use a placeholder when no RequestContext is attached, so the warning
    // below always carries a request id field.
    let request_id = request.extensions().get::<RequestContext>().map_or_else(
        || "unknown".to_string(),
        |context| context.request_id().to_string(),
    );

    let claims = match verify_access_token(&bearer_token, state.auth_jwt_config()) {
        Ok(verified) => verified,
        Err(error) => {
            warn!(
                %request_id,
                error = %error,
                "Bearer JWT 校验失败"
            );
            return Err(AppError::from_status(StatusCode::UNAUTHORIZED));
        }
    };

    let authenticated = AuthenticatedAccessToken::new(claims);
    request.extensions_mut().insert(authenticated);
    Ok(next.run(request).await)
}
/// Handler that echoes the claims of the authenticated access token.
///
/// Both extractors read from the request extensions: the `RequestContext` and
/// the `AuthenticatedAccessToken` stored by `require_bearer_auth`. The response
/// is the success envelope produced by `json_success_body`, with the claims
/// under the `"claims"` key.
pub async fn inspect_auth_claims(
    Extension(request_context): Extension<RequestContext>,
    Extension(authenticated): Extension<AuthenticatedAccessToken>,
) -> Json<Value> {
    let payload = json!({
        "claims": authenticated.claims(),
    });
    json_success_body(Some(&request_context), payload)
}
/// Middleware that attaches the refresh-session token, if any, to the request.
///
/// Every `Cookie` header field on the request is offered to
/// `read_refresh_session_token` (using the refresh-cookie configuration from
/// `state`), and the first token it yields is stored in the request extensions
/// as a `RefreshSessionToken`. The request is always forwarded to `next`,
/// whether or not a token was found.
pub async fn attach_refresh_session_token(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Response {
    // A request may carry several `Cookie` fields (HTTP/2 clients are allowed to
    // split cookie pairs across fields), so every field is inspected rather than
    // only the first. Fields that are not valid header text are skipped.
    let refresh_token = request
        .headers()
        .get_all(COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .find_map(|cookie_header| {
            read_refresh_session_token(cookie_header, state.refresh_cookie_config())
        });

    if let Some(token) = refresh_token {
        request
            .extensions_mut()
            .insert(RefreshSessionToken::new(token));
    }
    next.run(request).await
}
/// Handler that reports whether a refresh-session token was attached to the request.
///
/// The response is the success envelope produced by `json_success_body` with:
/// - `"cookieName"`: the cookie name from the refresh-cookie configuration,
/// - `"present"`: whether a `RefreshSessionToken` is in the request extensions,
/// - `"tokenLength"`: the byte length of that token, or `null` when absent.
///
/// The token value itself is never included in the response.
pub async fn inspect_refresh_session_cookie(
    State(state): State<AppState>,
    Extension(request_context): Extension<RequestContext>,
    request: Request,
) -> Json<Value> {
    // The length is `Some` exactly when a token is attached, so it also
    // answers the presence question.
    let token_length = request
        .extensions()
        .get::<RefreshSessionToken>()
        .map(|session| session.token().len());

    let payload = json!({
        "cookieName": state.refresh_cookie_config().cookie_name(),
        "present": token_length.is_some(),
        "tokenLength": token_length,
    });
    json_success_body(Some(&request_context), payload)
}
/// Extracts the bearer token from the `Authorization` header.
///
/// Returns the credentials that follow the `Bearer` scheme, with surrounding
/// whitespace removed.
///
/// # Errors
///
/// Returns an `AppError` built from `401 Unauthorized` when the header is
/// missing, is not valid header text, uses a scheme other than `Bearer`, or
/// carries no credentials.
fn extract_bearer_token(headers: &HeaderMap) -> Result<String, AppError> {
    headers
        .get(AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(parse_bearer_credentials)
        .map(str::to_owned)
        .ok_or_else(|| AppError::from_status(StatusCode::UNAUTHORIZED))
}

/// Splits an `Authorization` header value into scheme and credentials and
/// returns the credentials when the scheme is `Bearer`.
///
/// The scheme is compared ASCII-case-insensitively, because HTTP authentication
/// scheme names are case-insensitive (`Bearer`, `bearer`, `BEARER`, ...).
/// Returns `None` for any other scheme or when the credentials are empty.
fn parse_bearer_credentials(authorization: &str) -> Option<&str> {
    let (scheme, credentials) = authorization.trim().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    // Tolerate extra spaces between the scheme and the token.
    let credentials = credentials.trim();
    (!credentials.is_empty()).then_some(credentials)
}
#[cfg(test)]
mod tests {
    use super::{RefreshSessionToken, extract_bearer_token};
    use axum::{
        http::{HeaderMap, HeaderValue, StatusCode, header::AUTHORIZATION},
        response::IntoResponse,
    };

    /// Builds a header map whose only entry is the given `Authorization` value.
    fn authorization_headers(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static(value));
        headers
    }

    // A well-formed `Bearer <token>` header yields the token text.
    #[test]
    fn extract_bearer_token_accepts_standard_header() {
        let headers = authorization_headers("Bearer token-value");

        let token = extract_bearer_token(&headers).expect("bearer token should be extracted");

        assert_eq!(token, "token-value");
    }

    // A non-Bearer scheme is answered with 401 Unauthorized.
    #[test]
    fn extract_bearer_token_rejects_missing_scheme() {
        let headers = authorization_headers("Basic abc");

        let status = extract_bearer_token(&headers)
            .expect_err("basic auth should be rejected")
            .into_response()
            .status();

        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    // The wrapper hands back exactly the string it was constructed with.
    #[test]
    fn refresh_session_token_retains_original_value() {
        let original = "refresh-token-01".to_string();

        let token = RefreshSessionToken::new(original);

        assert_eq!(token.token(), "refresh-token-01");
    }
}