feat: add refresh token rotation flow
This commit is contained in:
@@ -6,6 +6,8 @@ use std::{
|
||||
};
|
||||
|
||||
use platform_auth::{hash_password, verify_password};
|
||||
use time::{Duration, OffsetDateTime};
|
||||
use uuid::Uuid;
|
||||
|
||||
const USERNAME_MIN_LENGTH: usize = 3;
|
||||
const USERNAME_MAX_LENGTH: usize = 24;
|
||||
@@ -37,6 +39,11 @@ pub struct AuthUser {
|
||||
pub token_version: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct AuthMeResult {
|
||||
pub user: AuthUser,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct PasswordEntryInput {
|
||||
pub username: String,
|
||||
@@ -50,7 +57,39 @@ pub struct PasswordEntryResult {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct AuthMeResult {
|
||||
pub struct CreateRefreshSessionInput {
|
||||
pub user_id: String,
|
||||
pub refresh_token_hash: String,
|
||||
pub issued_by_provider: AuthLoginMethod,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct RefreshSessionRecord {
|
||||
pub session_id: String,
|
||||
pub user_id: String,
|
||||
pub refresh_token_hash: String,
|
||||
pub issued_by_provider: AuthLoginMethod,
|
||||
pub expires_at: String,
|
||||
pub revoked_at: Option<String>,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
pub last_seen_at: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct CreateRefreshSessionResult {
|
||||
pub session: RefreshSessionRecord,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct RotateRefreshSessionInput {
|
||||
pub refresh_token_hash: String,
|
||||
pub next_refresh_token_hash: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct RotateRefreshSessionResult {
|
||||
pub session: RefreshSessionRecord,
|
||||
pub user: AuthUser,
|
||||
}
|
||||
|
||||
@@ -63,15 +102,26 @@ pub enum PasswordEntryError {
|
||||
PasswordHash(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum RefreshSessionError {
|
||||
MissingToken,
|
||||
SessionNotFound,
|
||||
SessionExpired,
|
||||
UserNotFound,
|
||||
Store(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct InMemoryPasswordUserStore {
|
||||
inner: Arc<Mutex<InMemoryPasswordUserStoreState>>,
|
||||
pub struct InMemoryAuthStore {
|
||||
inner: Arc<Mutex<InMemoryAuthStoreState>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InMemoryPasswordUserStoreState {
|
||||
next_id: u64,
|
||||
struct InMemoryAuthStoreState {
|
||||
next_user_id: u64,
|
||||
users_by_username: HashMap<String, StoredPasswordUser>,
|
||||
sessions_by_id: HashMap<String, StoredRefreshSession>,
|
||||
session_id_by_refresh_token_hash: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -80,13 +130,24 @@ struct StoredPasswordUser {
|
||||
password_hash: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct StoredRefreshSession {
|
||||
session: RefreshSessionRecord,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PasswordEntryService {
|
||||
store: InMemoryPasswordUserStore,
|
||||
store: InMemoryAuthStore,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RefreshSessionService {
|
||||
store: InMemoryAuthStore,
|
||||
refresh_session_ttl_days: u32,
|
||||
}
|
||||
|
||||
impl PasswordEntryService {
|
||||
pub fn new(store: InMemoryPasswordUserStore) -> Self {
|
||||
pub fn new(store: InMemoryAuthStore) -> Self {
|
||||
Self { store }
|
||||
}
|
||||
|
||||
@@ -114,10 +175,7 @@ impl PasswordEntryService {
|
||||
let password_hash = hash_password(&input.password)
|
||||
.await
|
||||
.map_err(|error| PasswordEntryError::PasswordHash(error.to_string()))?;
|
||||
match self
|
||||
.store
|
||||
.create_user(username.clone(), password_hash.clone())
|
||||
{
|
||||
match self.store.create_user(username.clone(), password_hash) {
|
||||
Ok(user) => Ok(PasswordEntryResult {
|
||||
user,
|
||||
created: true,
|
||||
@@ -141,9 +199,7 @@ impl PasswordEntryService {
|
||||
Err(CreateUserError::Store(message)) => Err(PasswordEntryError::Store(message)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PasswordEntryService {
|
||||
pub fn get_user_by_id(
|
||||
&self,
|
||||
user_id: &str,
|
||||
@@ -154,18 +210,129 @@ impl PasswordEntryService {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InMemoryPasswordUserStore {
|
||||
impl RefreshSessionService {
|
||||
pub fn new(store: InMemoryAuthStore, refresh_session_ttl_days: u32) -> Self {
|
||||
Self {
|
||||
store,
|
||||
refresh_session_ttl_days,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_session(
|
||||
&self,
|
||||
input: CreateRefreshSessionInput,
|
||||
now: OffsetDateTime,
|
||||
) -> Result<CreateRefreshSessionResult, RefreshSessionError> {
|
||||
self.store
|
||||
.find_by_user_id(&input.user_id)
|
||||
.map_err(map_password_store_error)?
|
||||
.ok_or(RefreshSessionError::UserNotFound)?;
|
||||
|
||||
let session_id = format!("usess_{}", Uuid::new_v4().simple());
|
||||
let expires_at = now
|
||||
.checked_add(Duration::days(i64::from(self.refresh_session_ttl_days)))
|
||||
.ok_or_else(|| RefreshSessionError::Store("refresh session 过期时间计算溢出".to_string()))?;
|
||||
let now_iso = now.format(&time::format_description::well_known::Rfc3339).map_err(
|
||||
|error| RefreshSessionError::Store(format!("refresh session 时间格式化失败:{error}")),
|
||||
)?;
|
||||
let expires_at_iso = expires_at
|
||||
.format(&time::format_description::well_known::Rfc3339)
|
||||
.map_err(|error| {
|
||||
RefreshSessionError::Store(format!("refresh session 过期时间格式化失败:{error}"))
|
||||
})?;
|
||||
let session = RefreshSessionRecord {
|
||||
session_id,
|
||||
user_id: input.user_id,
|
||||
refresh_token_hash: input.refresh_token_hash,
|
||||
issued_by_provider: input.issued_by_provider,
|
||||
expires_at: expires_at_iso,
|
||||
revoked_at: None,
|
||||
created_at: now_iso.clone(),
|
||||
updated_at: now_iso.clone(),
|
||||
last_seen_at: now_iso,
|
||||
};
|
||||
|
||||
self.store.insert_session(session.clone())?;
|
||||
|
||||
Ok(CreateRefreshSessionResult { session })
|
||||
}
|
||||
|
||||
pub fn rotate_session(
|
||||
&self,
|
||||
input: RotateRefreshSessionInput,
|
||||
now: OffsetDateTime,
|
||||
) -> Result<RotateRefreshSessionResult, RefreshSessionError> {
|
||||
let refresh_token_hash = input.refresh_token_hash.trim().to_string();
|
||||
if refresh_token_hash.is_empty() {
|
||||
return Err(RefreshSessionError::MissingToken);
|
||||
}
|
||||
|
||||
let session = self
|
||||
.store
|
||||
.find_session_by_refresh_token_hash(&refresh_token_hash)?
|
||||
.ok_or(RefreshSessionError::SessionNotFound)?;
|
||||
|
||||
if session.session.revoked_at.is_some() {
|
||||
return Err(RefreshSessionError::SessionNotFound);
|
||||
}
|
||||
|
||||
let expires_at = OffsetDateTime::parse(
|
||||
&session.session.expires_at,
|
||||
&time::format_description::well_known::Rfc3339,
|
||||
)
|
||||
.map_err(|error| RefreshSessionError::Store(format!("refresh session 过期时间解析失败:{error}")))?;
|
||||
if expires_at <= now {
|
||||
return Err(RefreshSessionError::SessionExpired);
|
||||
}
|
||||
|
||||
let user = self
|
||||
.store
|
||||
.find_by_user_id(&session.session.user_id)
|
||||
.map_err(map_password_store_error)?
|
||||
.ok_or(RefreshSessionError::UserNotFound)?;
|
||||
|
||||
let next_expires_at = now
|
||||
.checked_add(Duration::days(i64::from(self.refresh_session_ttl_days)))
|
||||
.ok_or_else(|| RefreshSessionError::Store("refresh session 过期时间计算溢出".to_string()))?;
|
||||
let now_iso = now.format(&time::format_description::well_known::Rfc3339).map_err(
|
||||
|error| RefreshSessionError::Store(format!("refresh session 时间格式化失败:{error}")),
|
||||
)?;
|
||||
let next_expires_at_iso = next_expires_at
|
||||
.format(&time::format_description::well_known::Rfc3339)
|
||||
.map_err(|error| {
|
||||
RefreshSessionError::Store(format!("refresh session 过期时间格式化失败:{error}"))
|
||||
})?;
|
||||
|
||||
let updated_session = self.store.rotate_session(
|
||||
&session.session.session_id,
|
||||
&session.session.refresh_token_hash,
|
||||
input.next_refresh_token_hash,
|
||||
next_expires_at_iso,
|
||||
now_iso.clone(),
|
||||
now_iso,
|
||||
)?;
|
||||
|
||||
Ok(RotateRefreshSessionResult {
|
||||
session: updated_session.session,
|
||||
user: user.user,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InMemoryAuthStore {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(InMemoryPasswordUserStoreState {
|
||||
next_id: 1,
|
||||
inner: Arc::new(Mutex::new(InMemoryAuthStoreState {
|
||||
next_user_id: 1,
|
||||
users_by_username: HashMap::new(),
|
||||
sessions_by_id: HashMap::new(),
|
||||
session_id_by_refresh_token_hash: HashMap::new(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InMemoryPasswordUserStore {
|
||||
impl InMemoryAuthStore {
|
||||
fn find_by_username(
|
||||
&self,
|
||||
username: &str,
|
||||
@@ -177,6 +344,22 @@ impl InMemoryPasswordUserStore {
|
||||
Ok(state.users_by_username.get(username).cloned())
|
||||
}
|
||||
|
||||
fn find_by_user_id(
|
||||
&self,
|
||||
user_id: &str,
|
||||
) -> Result<Option<StoredPasswordUser>, PasswordEntryError> {
|
||||
let state = self
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|_| PasswordEntryError::Store("用户仓储锁已中毒".to_string()))?;
|
||||
|
||||
Ok(state
|
||||
.users_by_username
|
||||
.values()
|
||||
.find(|stored_user| stored_user.user.id == user_id)
|
||||
.cloned())
|
||||
}
|
||||
|
||||
fn create_user(
|
||||
&self,
|
||||
username: String,
|
||||
@@ -191,8 +374,8 @@ impl InMemoryPasswordUserStore {
|
||||
return Err(CreateUserError::AlreadyExists);
|
||||
}
|
||||
|
||||
let user_id = format!("user_{:08}", state.next_id);
|
||||
state.next_id += 1;
|
||||
let user_id = format!("user_{:08}", state.next_user_id);
|
||||
state.next_user_id += 1;
|
||||
|
||||
let user = AuthUser {
|
||||
id: user_id,
|
||||
@@ -215,20 +398,102 @@ impl InMemoryPasswordUserStore {
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
fn find_by_user_id(
|
||||
fn insert_session(
|
||||
&self,
|
||||
user_id: &str,
|
||||
) -> Result<Option<StoredPasswordUser>, PasswordEntryError> {
|
||||
session: RefreshSessionRecord,
|
||||
) -> Result<(), RefreshSessionError> {
|
||||
let mut state = self
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|_| RefreshSessionError::Store("会话仓储锁已中毒".to_string()))?;
|
||||
|
||||
if state
|
||||
.session_id_by_refresh_token_hash
|
||||
.contains_key(&session.refresh_token_hash)
|
||||
{
|
||||
return Err(RefreshSessionError::Store(
|
||||
"refresh token hash 已存在,无法重复创建会话".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
state.session_id_by_refresh_token_hash.insert(
|
||||
session.refresh_token_hash.clone(),
|
||||
session.session_id.clone(),
|
||||
);
|
||||
state.sessions_by_id.insert(
|
||||
session.session_id.clone(),
|
||||
StoredRefreshSession { session },
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn find_session_by_refresh_token_hash(
|
||||
&self,
|
||||
refresh_token_hash: &str,
|
||||
) -> Result<Option<StoredRefreshSession>, RefreshSessionError> {
|
||||
let state = self
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|_| PasswordEntryError::Store("用户仓储锁已中毒".to_string()))?;
|
||||
.map_err(|_| RefreshSessionError::Store("会话仓储锁已中毒".to_string()))?;
|
||||
let Some(session_id) = state.session_id_by_refresh_token_hash.get(refresh_token_hash) else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(state
|
||||
.users_by_username
|
||||
.values()
|
||||
.find(|stored_user| stored_user.user.id == user_id)
|
||||
.cloned())
|
||||
Ok(state.sessions_by_id.get(session_id).cloned())
|
||||
}
|
||||
|
||||
fn rotate_session(
|
||||
&self,
|
||||
session_id: &str,
|
||||
previous_refresh_token_hash: &str,
|
||||
next_refresh_token_hash: String,
|
||||
next_expires_at: String,
|
||||
updated_at: String,
|
||||
last_seen_at: String,
|
||||
) -> Result<StoredRefreshSession, RefreshSessionError> {
|
||||
let mut state = self
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|_| RefreshSessionError::Store("会话仓储锁已中毒".to_string()))?;
|
||||
|
||||
if state
|
||||
.session_id_by_refresh_token_hash
|
||||
.contains_key(&next_refresh_token_hash)
|
||||
{
|
||||
return Err(RefreshSessionError::Store(
|
||||
"新 refresh token hash 已存在,无法轮换".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let current_refresh_token_hash = state
|
||||
.sessions_by_id
|
||||
.get(session_id)
|
||||
.ok_or(RefreshSessionError::SessionNotFound)?
|
||||
.session
|
||||
.refresh_token_hash
|
||||
.clone();
|
||||
if current_refresh_token_hash != previous_refresh_token_hash {
|
||||
return Err(RefreshSessionError::SessionNotFound);
|
||||
}
|
||||
|
||||
state
|
||||
.session_id_by_refresh_token_hash
|
||||
.remove(previous_refresh_token_hash);
|
||||
let stored = state
|
||||
.sessions_by_id
|
||||
.get_mut(session_id)
|
||||
.ok_or(RefreshSessionError::SessionNotFound)?;
|
||||
stored.session.refresh_token_hash = next_refresh_token_hash.clone();
|
||||
stored.session.expires_at = next_expires_at;
|
||||
stored.session.updated_at = updated_at;
|
||||
stored.session.last_seen_at = last_seen_at;
|
||||
let updated_session = stored.clone();
|
||||
state
|
||||
.session_id_by_refresh_token_hash
|
||||
.insert(next_refresh_token_hash, updated_session.session.session_id.clone());
|
||||
|
||||
Ok(updated_session)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,6 +535,32 @@ impl fmt::Display for PasswordEntryError {
|
||||
|
||||
impl Error for PasswordEntryError {}
|
||||
|
||||
impl fmt::Display for RefreshSessionError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::MissingToken => f.write_str("缺少刷新会话"),
|
||||
Self::SessionNotFound | Self::SessionExpired | Self::UserNotFound => {
|
||||
f.write_str("当前登录态已失效,请重新登录")
|
||||
}
|
||||
Self::Store(message) => f.write_str(message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for RefreshSessionError {}
|
||||
|
||||
fn map_password_store_error(error: PasswordEntryError) -> RefreshSessionError {
|
||||
match error {
|
||||
PasswordEntryError::Store(message) => RefreshSessionError::Store(message),
|
||||
PasswordEntryError::InvalidUsername
|
||||
| PasswordEntryError::InvalidPasswordLength
|
||||
| PasswordEntryError::InvalidCredentials
|
||||
| PasswordEntryError::PasswordHash(_) => {
|
||||
RefreshSessionError::Store("用户仓储读取失败".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_username(raw_username: &str) -> Result<String, PasswordEntryError> {
|
||||
let username = raw_username.trim().to_string();
|
||||
let valid_length =
|
||||
@@ -296,15 +587,25 @@ fn validate_password(password: &str) -> Result<(), PasswordEntryError> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use platform_auth::hash_refresh_session_token;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn build_service() -> PasswordEntryService {
|
||||
PasswordEntryService::new(InMemoryPasswordUserStore::default())
|
||||
fn build_store() -> InMemoryAuthStore {
|
||||
InMemoryAuthStore::default()
|
||||
}
|
||||
|
||||
fn build_password_service(store: InMemoryAuthStore) -> PasswordEntryService {
|
||||
PasswordEntryService::new(store)
|
||||
}
|
||||
|
||||
fn build_refresh_service(store: InMemoryAuthStore) -> RefreshSessionService {
|
||||
RefreshSessionService::new(store, 30)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn first_password_entry_creates_user() {
|
||||
let service = build_service();
|
||||
let service = build_password_service(build_store());
|
||||
|
||||
let result = service
|
||||
.execute(PasswordEntryInput {
|
||||
@@ -324,7 +625,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn repeated_password_entry_reuses_same_user() {
|
||||
let service = build_service();
|
||||
let store = build_store();
|
||||
let service = build_password_service(store);
|
||||
let first = service
|
||||
.execute(PasswordEntryInput {
|
||||
username: "guest_001".to_string(),
|
||||
@@ -348,7 +650,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn repeated_password_entry_rejects_wrong_password() {
|
||||
let service = build_service();
|
||||
let service = build_password_service(build_store());
|
||||
service
|
||||
.execute(PasswordEntryInput {
|
||||
username: "guest_001".to_string(),
|
||||
@@ -370,7 +672,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_username_returns_bad_request_error() {
|
||||
let service = build_service();
|
||||
let service = build_password_service(build_store());
|
||||
|
||||
let error = service
|
||||
.execute(PasswordEntryInput {
|
||||
@@ -382,4 +684,66 @@ mod tests {
|
||||
|
||||
assert_eq!(error, PasswordEntryError::InvalidUsername);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_session_creation_and_rotation_keep_same_session_id() {
|
||||
let store = build_store();
|
||||
let password_service = build_password_service(store.clone());
|
||||
let refresh_service = build_refresh_service(store);
|
||||
let user = password_service
|
||||
.execute(PasswordEntryInput {
|
||||
username: "guest_002".to_string(),
|
||||
password: "secret123".to_string(),
|
||||
})
|
||||
.await
|
||||
.expect("seed login should succeed")
|
||||
.user;
|
||||
let now = OffsetDateTime::now_utc();
|
||||
let first_token_hash = hash_refresh_session_token("refresh-token-01");
|
||||
let created = refresh_service
|
||||
.create_session(
|
||||
CreateRefreshSessionInput {
|
||||
user_id: user.id.clone(),
|
||||
refresh_token_hash: first_token_hash.clone(),
|
||||
issued_by_provider: AuthLoginMethod::Password,
|
||||
},
|
||||
now,
|
||||
)
|
||||
.expect("session should create");
|
||||
|
||||
let rotated = refresh_service
|
||||
.rotate_session(
|
||||
RotateRefreshSessionInput {
|
||||
refresh_token_hash: first_token_hash,
|
||||
next_refresh_token_hash: hash_refresh_session_token("refresh-token-02"),
|
||||
},
|
||||
now + Duration::minutes(10),
|
||||
)
|
||||
.expect("session should rotate");
|
||||
|
||||
assert_eq!(rotated.user.id, user.id);
|
||||
assert_eq!(rotated.session.session_id, created.session.session_id);
|
||||
assert_ne!(
|
||||
rotated.session.refresh_token_hash,
|
||||
created.session.refresh_token_hash
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_session_rejects_unknown_token_hash() {
|
||||
let store = build_store();
|
||||
let refresh_service = build_refresh_service(store);
|
||||
|
||||
let error = refresh_service
|
||||
.rotate_session(
|
||||
RotateRefreshSessionInput {
|
||||
refresh_token_hash: hash_refresh_session_token("missing"),
|
||||
next_refresh_token_hash: hash_refresh_session_token("next"),
|
||||
},
|
||||
OffsetDateTime::now_utc(),
|
||||
)
|
||||
.expect_err("unknown token should fail");
|
||||
|
||||
assert_eq!(error, RefreshSessionError::SessionNotFound);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user