183 lines
4.9 KiB
Rust
183 lines
4.9 KiB
Rust
use axum::{
|
||
Json,
|
||
extract::{Extension, Request, State},
|
||
http::{
|
||
HeaderMap, StatusCode,
|
||
header::{AUTHORIZATION, COOKIE},
|
||
},
|
||
middleware::Next,
|
||
response::Response,
|
||
};
|
||
use platform_auth::{AccessTokenClaims, read_refresh_session_token, verify_access_token};
|
||
use serde_json::{Value, json};
|
||
use tracing::warn;
|
||
|
||
use crate::{
|
||
api_response::json_success_body, http_error::AppError, request_context::RequestContext,
|
||
state::AppState,
|
||
};
|
||
|
||
// 统一把已校验的 claims 写入 request extensions,避免后续 handler 再次重复解析 Bearer token。
|
||
#[derive(Clone, Debug)]
|
||
pub struct AuthenticatedAccessToken {
|
||
claims: AccessTokenClaims,
|
||
}
|
||
|
||
/// Raw refresh-session token value, stored as a request extension.
///
/// `Debug` is implemented by hand instead of derived so that the secret token
/// value is never written out by `{:?}` formatting (e.g. in logs or traces).
#[derive(Clone)]
pub struct RefreshSessionToken {
    token: String,
}

impl std::fmt::Debug for RefreshSessionToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the token itself; it is a bearer credential.
        f.debug_struct("RefreshSessionToken")
            .field("token", &"<redacted>")
            .finish()
    }
}
|
||
|
||
impl AuthenticatedAccessToken {
|
||
pub fn new(claims: AccessTokenClaims) -> Self {
|
||
Self { claims }
|
||
}
|
||
|
||
pub fn claims(&self) -> &AccessTokenClaims {
|
||
&self.claims
|
||
}
|
||
}
|
||
|
||
impl RefreshSessionToken {
|
||
pub fn new(token: String) -> Self {
|
||
Self { token }
|
||
}
|
||
|
||
pub fn token(&self) -> &str {
|
||
&self.token
|
||
}
|
||
}
|
||
|
||
pub async fn require_bearer_auth(
|
||
State(state): State<AppState>,
|
||
mut request: Request,
|
||
next: Next,
|
||
) -> Result<Response, AppError> {
|
||
let bearer_token = extract_bearer_token(request.headers())?;
|
||
let request_id = request
|
||
.extensions()
|
||
.get::<RequestContext>()
|
||
.map(|context| context.request_id().to_string())
|
||
.unwrap_or_else(|| "unknown".to_string());
|
||
let claims = verify_access_token(&bearer_token, state.auth_jwt_config()).map_err(|error| {
|
||
warn!(
|
||
%request_id,
|
||
error = %error,
|
||
"Bearer JWT 校验失败"
|
||
);
|
||
AppError::from_status(StatusCode::UNAUTHORIZED)
|
||
})?;
|
||
|
||
request
|
||
.extensions_mut()
|
||
.insert(AuthenticatedAccessToken::new(claims));
|
||
|
||
Ok(next.run(request).await)
|
||
}
|
||
|
||
pub async fn inspect_auth_claims(
|
||
Extension(request_context): Extension<RequestContext>,
|
||
Extension(authenticated): Extension<AuthenticatedAccessToken>,
|
||
) -> Json<Value> {
|
||
json_success_body(
|
||
Some(&request_context),
|
||
json!({
|
||
"claims": authenticated.claims(),
|
||
}),
|
||
)
|
||
}
|
||
|
||
pub async fn attach_refresh_session_token(
|
||
State(state): State<AppState>,
|
||
mut request: Request,
|
||
next: Next,
|
||
) -> Response {
|
||
if let Some(token) = request
|
||
.headers()
|
||
.get(COOKIE)
|
||
.and_then(|value| value.to_str().ok())
|
||
.and_then(|cookie_header| {
|
||
read_refresh_session_token(cookie_header, state.refresh_cookie_config())
|
||
})
|
||
{
|
||
request
|
||
.extensions_mut()
|
||
.insert(RefreshSessionToken::new(token));
|
||
}
|
||
|
||
next.run(request).await
|
||
}
|
||
|
||
pub async fn inspect_refresh_session_cookie(
|
||
State(state): State<AppState>,
|
||
Extension(request_context): Extension<RequestContext>,
|
||
request: Request,
|
||
) -> Json<Value> {
|
||
let maybe_token = request.extensions().get::<RefreshSessionToken>();
|
||
|
||
json_success_body(
|
||
Some(&request_context),
|
||
json!({
|
||
"cookieName": state.refresh_cookie_config().cookie_name(),
|
||
"present": maybe_token.is_some(),
|
||
"tokenLength": maybe_token.map(|token| token.token().len()),
|
||
}),
|
||
)
|
||
}
|
||
|
||
fn extract_bearer_token(headers: &HeaderMap) -> Result<String, AppError> {
|
||
let authorization = headers
|
||
.get(AUTHORIZATION)
|
||
.and_then(|value| value.to_str().ok())
|
||
.map(str::trim)
|
||
.ok_or_else(|| AppError::from_status(StatusCode::UNAUTHORIZED))?;
|
||
|
||
let token = authorization
|
||
.strip_prefix("Bearer ")
|
||
.or_else(|| authorization.strip_prefix("bearer "))
|
||
.map(str::trim)
|
||
.filter(|token| !token.is_empty())
|
||
.ok_or_else(|| AppError::from_status(StatusCode::UNAUTHORIZED))?;
|
||
|
||
Ok(token.to_string())
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::{RefreshSessionToken, extract_bearer_token};
    use axum::{
        http::{HeaderMap, HeaderValue, StatusCode, header::AUTHORIZATION},
        response::IntoResponse,
    };

    /// Builds a header map containing only the given `Authorization` value.
    fn headers_with_authorization(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static(value));
        headers
    }

    /// Asserts that bearer extraction fails with `401 Unauthorized`.
    fn assert_unauthorized(headers: &HeaderMap) {
        let error = extract_bearer_token(headers).expect_err("extraction should be rejected");
        assert_eq!(error.into_response().status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn extract_bearer_token_accepts_standard_header() {
        let headers = headers_with_authorization("Bearer token-value");

        let token = extract_bearer_token(&headers).expect("bearer token should be extracted");

        assert_eq!(token, "token-value");
    }

    #[test]
    fn extract_bearer_token_accepts_lowercase_scheme() {
        let headers = headers_with_authorization("bearer token-value");

        let token = extract_bearer_token(&headers).expect("lowercase scheme should be accepted");

        assert_eq!(token, "token-value");
    }

    #[test]
    fn extract_bearer_token_trims_surrounding_whitespace() {
        let headers = headers_with_authorization("Bearer   token-value  ");

        let token = extract_bearer_token(&headers).expect("padded token should be extracted");

        assert_eq!(token, "token-value");
    }

    #[test]
    fn extract_bearer_token_rejects_missing_scheme() {
        assert_unauthorized(&headers_with_authorization("Basic abc"));
    }

    #[test]
    fn extract_bearer_token_rejects_missing_header() {
        assert_unauthorized(&HeaderMap::new());
    }

    #[test]
    fn extract_bearer_token_rejects_empty_token() {
        assert_unauthorized(&headers_with_authorization("Bearer   "));
    }

    #[test]
    fn refresh_session_token_retains_original_value() {
        let token = RefreshSessionToken::new("refresh-token-01".to_string());

        assert_eq!(token.token(), "refresh-token-01");
    }
}
|