feat: add refresh token rotation flow

This commit is contained in:
2026-04-21 15:27:04 +08:00
parent 70dbefda2b
commit 584a77e572
16 changed files with 1048 additions and 85 deletions

View File

@@ -19,6 +19,7 @@ use crate::{
error_middleware::normalize_error_response,
health::health_check,
password_entry::password_entry,
refresh_session::refresh_session,
request_context::{attach_request_context, resolve_request_id},
response_headers::propagate_request_id_header,
state::AppState,
@@ -54,6 +55,13 @@ pub fn build_router(state: AppState) -> Router {
require_bearer_auth,
)),
)
.route(
"/api/auth/refresh",
post(refresh_session).route_layer(middleware::from_fn_with_state(
state.clone(),
attach_refresh_session_token,
)),
)
.route(
"/api/assets/direct-upload-tickets",
post(create_direct_upload_ticket),
@@ -616,4 +624,112 @@ mod tests {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn refresh_session_rotates_cookie_and_returns_new_access_token() {
let app = build_router(AppState::new(AppConfig::default()).expect("state should build"));
let login_response = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/api/auth/entry")
.header("content-type", "application/json")
.body(Body::from(
serde_json::json!({
"username": "guest_refresh",
"password": "secret123"
})
.to_string(),
))
.expect("login request should build"),
)
.await
.expect("login request should succeed");
let first_cookie = login_response
.headers()
.get("set-cookie")
.and_then(|value| value.to_str().ok())
.expect("refresh cookie should exist")
.to_string();
let refresh_response = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/api/auth/refresh")
.header("cookie", first_cookie.clone())
.body(Body::empty())
.expect("refresh request should build"),
)
.await
.expect("refresh request should succeed");
assert_eq!(refresh_response.status(), StatusCode::OK);
let second_cookie = refresh_response
.headers()
.get("set-cookie")
.and_then(|value| value.to_str().ok())
.expect("rotated refresh cookie should exist")
.to_string();
assert_ne!(first_cookie, second_cookie);
let refresh_body = refresh_response
.into_body()
.collect()
.await
.expect("refresh body should collect")
.to_bytes();
let refresh_payload: Value =
serde_json::from_slice(&refresh_body).expect("refresh payload should be json");
assert!(refresh_payload["token"].as_str().is_some());
let stale_refresh_response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/api/auth/refresh")
.header("cookie", first_cookie)
.body(Body::empty())
.expect("stale refresh request should build"),
)
.await
.expect("stale refresh request should succeed");
assert_eq!(stale_refresh_response.status(), StatusCode::UNAUTHORIZED);
assert!(
stale_refresh_response
.headers()
.get("set-cookie")
.and_then(|value| value.to_str().ok())
.is_some_and(|value| value.contains("Max-Age=0"))
);
}
#[tokio::test]
async fn refresh_session_rejects_missing_cookie_and_clears_cookie() {
let app = build_router(AppState::new(AppConfig::default()).expect("state should build"));
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/api/auth/refresh")
.body(Body::empty())
.expect("request should build"),
)
.await
.expect("request should succeed");
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
assert!(
response
.headers()
.get("set-cookie")
.and_then(|value| value.to_str().ok())
.is_some_and(|value| value.contains("Max-Age=0"))
);
}
}

View File

@@ -0,0 +1,132 @@
use axum::http::{
HeaderMap, HeaderValue, StatusCode,
header::SET_COOKIE,
};
use module_auth::{
AuthLoginMethod, AuthUser, CreateRefreshSessionInput, RefreshSessionError,
};
use platform_auth::{
AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus,
build_refresh_session_clear_cookie, build_refresh_session_set_cookie,
create_refresh_session_token, hash_refresh_session_token, sign_access_token,
};
use time::OffsetDateTime;
use crate::{http_error::AppError, state::AppState};
#[derive(Debug, Clone)]
pub struct SignedAuthSession {
pub access_token: String,
pub refresh_token: String,
}
pub fn create_password_auth_session(
state: &AppState,
user: &AuthUser,
) -> Result<SignedAuthSession, AppError> {
let refresh_token = create_refresh_session_token();
let refresh_token_hash = hash_refresh_session_token(&refresh_token);
let session = state
.refresh_session_service()
.create_session(
CreateRefreshSessionInput {
user_id: user.id.clone(),
refresh_token_hash,
issued_by_provider: AuthLoginMethod::Password,
},
OffsetDateTime::now_utc(),
)
.map_err(map_refresh_session_error)?;
let access_token = sign_access_token_for_user(state, user, &session.session.session_id)?;
Ok(SignedAuthSession {
access_token,
refresh_token,
})
}
pub fn sign_access_token_for_user(
state: &AppState,
user: &AuthUser,
session_id: &str,
) -> Result<String, AppError> {
let access_claims = AccessTokenClaims::from_input(
AccessTokenClaimsInput {
user_id: user.id.clone(),
session_id: session_id.to_string(),
provider: map_auth_provider(&user.login_method),
roles: vec!["user".to_string()],
token_version: user.token_version,
phone_verified: user.phone_number_masked.is_some(),
binding_status: map_binding_status(&user.binding_status),
display_name: Some(user.display_name.clone()),
},
state.auth_jwt_config(),
OffsetDateTime::now_utc(),
)
.map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_message(error.to_string())
})?;
sign_access_token(&access_claims, state.auth_jwt_config()).map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_message(error.to_string())
})
}
pub fn build_refresh_session_cookie_header(
state: &AppState,
refresh_token: &str,
) -> Result<HeaderValue, AppError> {
let refresh_cookie =
build_refresh_session_set_cookie(refresh_token, state.refresh_cookie_config());
HeaderValue::from_str(&refresh_cookie).map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
.with_message(format!("refresh cookie 头构造失败:{error}"))
})
}
pub fn build_clear_refresh_session_cookie_header(
state: &AppState,
) -> Result<HeaderValue, AppError> {
let refresh_cookie = build_refresh_session_clear_cookie(state.refresh_cookie_config());
HeaderValue::from_str(&refresh_cookie).map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
.with_message(format!("refresh cookie 头构造失败:{error}"))
})
}
pub fn attach_set_cookie_header(
headers: &mut HeaderMap,
set_cookie: HeaderValue,
) {
headers.insert(SET_COOKIE, set_cookie);
}
pub fn map_refresh_session_error(error: RefreshSessionError) -> AppError {
match error {
RefreshSessionError::MissingToken
| RefreshSessionError::SessionNotFound
| RefreshSessionError::SessionExpired
| RefreshSessionError::UserNotFound => {
AppError::from_status(StatusCode::UNAUTHORIZED).with_message(error.to_string())
}
RefreshSessionError::Store(message) => {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_message(message)
}
}
}
fn map_auth_provider(login_method: &AuthLoginMethod) -> AuthProvider {
match login_method {
AuthLoginMethod::Password => AuthProvider::Password,
AuthLoginMethod::Phone => AuthProvider::Phone,
AuthLoginMethod::Wechat => AuthProvider::Wechat,
}
}
fn map_binding_status(binding_status: &module_auth::AuthBindingStatus) -> BindingStatus {
match binding_status {
module_auth::AuthBindingStatus::Active => BindingStatus::Active,
module_auth::AuthBindingStatus::PendingBindPhone => BindingStatus::PendingBindPhone,
}
}

View File

@@ -1,4 +1,5 @@
use axum::{
http::{HeaderMap, HeaderValue},
http::StatusCode,
response::{IntoResponse, Response},
};
@@ -13,6 +14,7 @@ pub struct AppError {
code: &'static str,
message: String,
details: Option<Value>,
headers: HeaderMap,
}
#[derive(Clone, Debug, Serialize)]
@@ -32,6 +34,7 @@ impl AppError {
code,
message: message.to_string(),
details: None,
headers: HeaderMap::new(),
}
}
@@ -49,11 +52,17 @@ impl AppError {
self
}
pub fn with_header(mut self, name: &'static str, value: HeaderValue) -> Self {
self.headers.insert(name, value);
self
}
pub fn into_response_with_context(self, request_context: Option<&RequestContext>) -> Response {
let status_code = self.status_code;
let payload = self.to_payload();
(status_code, json_error_body(request_context, &payload)).into_response()
let mut response = (status_code, json_error_body(request_context, &payload)).into_response();
response.headers_mut().extend(self.headers);
response
}
fn to_payload(&self) -> ApiErrorPayload {

View File

@@ -2,12 +2,14 @@ mod api_response;
mod app;
mod assets;
mod auth;
mod auth_session;
mod auth_me;
mod config;
mod error_middleware;
mod health;
mod http_error;
mod password_entry;
mod refresh_session;
mod request_context;
mod response_headers;
mod state;

View File

@@ -1,20 +1,20 @@
use axum::{
Json,
extract::{Extension, State},
http::{HeaderMap, HeaderValue, StatusCode, header::SET_COOKIE},
http::{HeaderMap, StatusCode},
response::IntoResponse,
};
use module_auth::{PasswordEntryError, PasswordEntryInput};
use platform_auth::{
AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus,
build_refresh_session_set_cookie, create_refresh_session_token, sign_access_token,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use time::OffsetDateTime;
use crate::{
api_response::json_success_body, http_error::AppError, request_context::RequestContext,
api_response::json_success_body,
auth_session::{
attach_set_cookie_header, build_refresh_session_cookie_header, create_password_auth_session,
},
http_error::AppError,
request_context::RequestContext,
state::AppState,
};
@@ -57,45 +57,20 @@ pub async fn password_entry(
})
.await
.map_err(map_password_entry_error)?;
let refresh_session_token = create_refresh_session_token();
let access_claims = AccessTokenClaims::from_input(
AccessTokenClaimsInput {
user_id: result.user.id.clone(),
session_id: refresh_session_token.clone(),
provider: AuthProvider::Password,
roles: vec!["user".to_string()],
token_version: result.user.token_version,
phone_verified: false,
binding_status: BindingStatus::Active,
display_name: Some(result.user.display_name.clone()),
},
state.auth_jwt_config(),
OffsetDateTime::now_utc(),
)
.map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_message(error.to_string())
})?;
let access_token =
sign_access_token(&access_claims, state.auth_jwt_config()).map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_message(error.to_string())
})?;
let refresh_cookie =
build_refresh_session_set_cookie(&refresh_session_token, state.refresh_cookie_config());
let signed_session = create_password_auth_session(&state, &result.user)?;
let mut headers = HeaderMap::new();
let set_cookie = HeaderValue::from_str(&refresh_cookie).map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
.with_message(format!("refresh cookie 头构造失败:{error}"))
})?;
headers.insert(SET_COOKIE, set_cookie);
attach_set_cookie_header(
&mut headers,
build_refresh_session_cookie_header(&state, &signed_session.refresh_token)?,
);
Ok((
headers,
json_success_body(
Some(&request_context),
PasswordEntryResponse {
token: access_token,
token: signed_session.access_token,
user: PasswordEntryUserPayload {
id: result.user.id,
username: result.user.username,

View File

@@ -0,0 +1,84 @@
use axum::{
extract::{Extension, State},
http::HeaderMap,
response::IntoResponse,
};
use module_auth::{RefreshSessionError, RotateRefreshSessionInput};
use platform_auth::hash_refresh_session_token;
use serde::Serialize;
use time::OffsetDateTime;
use crate::{
api_response::json_success_body,
auth::RefreshSessionToken,
auth_session::{
attach_set_cookie_header, build_clear_refresh_session_cookie_header,
build_refresh_session_cookie_header, map_refresh_session_error, sign_access_token_for_user,
},
http_error::AppError,
request_context::RequestContext,
state::AppState,
};
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RefreshSessionResponse {
pub token: String,
}
pub async fn refresh_session(
State(state): State<AppState>,
Extension(request_context): Extension<RequestContext>,
maybe_refresh_token: Option<Extension<RefreshSessionToken>>,
) -> Result<impl IntoResponse, AppError> {
let raw_refresh_token = maybe_refresh_token
.map(|token| token.0.token().to_string())
.unwrap_or_default();
if raw_refresh_token.trim().is_empty() {
return Err(map_refresh_error_with_clear_cookie(
&state,
RefreshSessionError::MissingToken,
));
}
let refresh_token_hash = hash_refresh_session_token(&raw_refresh_token);
let next_refresh_token = platform_auth::create_refresh_session_token();
let next_refresh_token_hash = hash_refresh_session_token(&next_refresh_token);
let rotated = state
.refresh_session_service()
.rotate_session(
RotateRefreshSessionInput {
refresh_token_hash,
next_refresh_token_hash,
},
OffsetDateTime::now_utc(),
)
.map_err(|error| map_refresh_error_with_clear_cookie(&state, error))?;
let access_token =
sign_access_token_for_user(&state, &rotated.user, &rotated.session.session_id)?;
let mut headers = HeaderMap::new();
attach_set_cookie_header(
&mut headers,
build_refresh_session_cookie_header(&state, &next_refresh_token)?,
);
Ok((
headers,
json_success_body(
Some(&request_context),
RefreshSessionResponse {
token: access_token,
},
),
))
}
fn map_refresh_error_with_clear_cookie(state: &AppState, error: RefreshSessionError) -> AppError {
let response_error = map_refresh_session_error(error);
if let Ok(set_cookie) = build_clear_refresh_session_cookie_header(state) {
return response_error.with_header("set-cookie", set_cookie);
}
response_error
}

View File

@@ -1,6 +1,6 @@
use std::{error::Error, fmt};
use module_auth::{InMemoryPasswordUserStore, PasswordEntryService};
use module_auth::{InMemoryAuthStore, PasswordEntryService, RefreshSessionService};
use platform_auth::{
JwtConfig, JwtError, RefreshCookieConfig, RefreshCookieError, RefreshCookieSameSite,
};
@@ -18,6 +18,7 @@ pub struct AppState {
refresh_cookie_config: RefreshCookieConfig,
oss_client: Option<OssClient>,
password_entry_service: PasswordEntryService,
refresh_session_service: RefreshSessionService,
}
#[derive(Debug)]
@@ -46,8 +47,10 @@ impl AppState {
config.refresh_session_ttl_days,
)?;
let oss_client = build_oss_client(&config)?;
let password_entry_service =
PasswordEntryService::new(InMemoryPasswordUserStore::default());
let auth_store = InMemoryAuthStore::default();
let password_entry_service = PasswordEntryService::new(auth_store.clone());
let refresh_session_service =
RefreshSessionService::new(auth_store, config.refresh_session_ttl_days);
Ok(Self {
config,
@@ -55,6 +58,7 @@ impl AppState {
refresh_cookie_config,
oss_client,
password_entry_service,
refresh_session_service,
})
}
@@ -73,6 +77,10 @@ impl AppState {
pub fn password_entry_service(&self) -> &PasswordEntryService {
&self.password_entry_service
}
pub fn refresh_session_service(&self) -> &RefreshSessionService {
&self.refresh_session_service
}
}
impl fmt::Display for AppStateInitError {
@@ -109,8 +117,7 @@ fn build_oss_client(config: &AppConfig) -> Result<Option<OssClient>, AppStateIni
let has_any_oss_field = config.oss_bucket.is_some()
|| config.oss_endpoint.is_some()
|| config.oss_access_key_id.is_some()
|| config.oss_access_key_secret.is_some()
|| config.oss_public_base_url.is_some();
|| config.oss_access_key_secret.is_some();
if !has_any_oss_field {
return Ok(None);
@@ -121,7 +128,7 @@ fn build_oss_client(config: &AppConfig) -> Result<Option<OssClient>, AppStateIni
config.oss_endpoint.clone().unwrap_or_default(),
config.oss_access_key_id.clone().unwrap_or_default(),
config.oss_access_key_secret.clone().unwrap_or_default(),
config.oss_public_base_url.clone(),
config.oss_read_expire_seconds,
config.oss_post_expire_seconds,
config.oss_post_max_size_bytes,
config.oss_success_action_status,