use std::sync::Arc; use axum::{ body::Body, extract::{Request, State}, http::{HeaderValue, StatusCode, header::RETRY_AFTER}, middleware::Next, response::Response, }; use http_body_util::BodyExt; use tokio::sync::{OwnedSemaphorePermit, TryAcquireError}; use crate::{ http_error::AppError, request_context::RequestContext, state::{BackpressureState, HttpRequestPermitPool, HttpRequestPermitPoolKind}, }; pub async fn limit_concurrent_requests( State(state): State, request: Request, next: Next, ) -> Response { if should_bypass_backpressure(&request) { return next.run(request).await; } let requested_pool = classify_request_permit_pool(request.uri().path()); let Some((permit_pool_kind, permit_pool)) = state.request_permit_pool(requested_pool) else { return next.run(request).await; }; match acquire_http_request_permit(permit_pool_kind, permit_pool) { Ok(permit) => hold_permit_until_response_body_dropped(next.run(request).await, permit), Err(_) => reject_overloaded_request(&request), } } fn acquire_http_request_permit( permit_pool_kind: HttpRequestPermitPoolKind, permit_pool: Arc, ) -> Result { match permit_pool.clone().try_acquire_owned() { Ok(permit) => { crate::telemetry::update_http_request_permits_available( permit_pool_kind, permit_pool.available_permits(), ); Ok(HttpRequestPermitGuard { permit_pool_kind, permit: Some(permit), permit_pool, }) } Err(error) => { crate::telemetry::update_http_request_permits_available( permit_pool_kind, permit_pool.available_permits(), ); Err(error) } } } fn hold_permit_until_response_body_dropped( response: Response, permit: HttpRequestPermitGuard, ) -> Response { response.map(|body| { Body::new(body.map_frame(move |frame| { let _permit_guard = &permit; frame })) }) } struct HttpRequestPermitGuard { permit_pool_kind: HttpRequestPermitPoolKind, permit: Option, permit_pool: Arc, } impl Drop for HttpRequestPermitGuard { fn drop(&mut self) { drop(self.permit.take()); crate::telemetry::update_http_request_permits_available( self.permit_pool_kind, self.permit_pool.available_permits(), ); } } fn reject_overloaded_request(request: &Request) -> Response { let request_context = request.extensions().get::().cloned(); let mut response = AppError::from_status(StatusCode::TOO_MANY_REQUESTS) .with_message("服务繁忙,请稍后重试") .into_response_with_context(request_context.as_ref()); response .headers_mut() .insert(RETRY_AFTER, HeaderValue::from_static("1")); response } fn should_bypass_backpressure(request: &Request) -> bool { request.uri().path() == "/healthz" } fn classify_request_permit_pool(path: &str) -> HttpRequestPermitPoolKind { if is_gallery_list_path(path) { HttpRequestPermitPoolKind::Gallery } else if is_gallery_detail_path(path) { HttpRequestPermitPoolKind::Detail } else if path.starts_with("/admin/api/") { HttpRequestPermitPoolKind::Admin } else { HttpRequestPermitPoolKind::Default } } fn is_gallery_list_path(path: &str) -> bool { matches!( path, "/api/runtime/puzzle/gallery" | "/api/runtime/custom-world-gallery" ) } fn is_gallery_detail_path(path: &str) -> bool { let puzzle_prefix = "/api/runtime/puzzle/gallery/"; if let Some(profile_id) = path.strip_prefix(puzzle_prefix) { return !profile_id.is_empty() && !profile_id.contains('/'); } let custom_world_prefix = "/api/runtime/custom-world-gallery/"; if let Some(remainder) = path.strip_prefix(custom_world_prefix) { let mut segments = remainder.split('/'); return matches!( (segments.next(), segments.next(), segments.next()), (Some(owner_user_id), Some(profile_id), None) if !owner_user_id.is_empty() && !profile_id.is_empty() ); } false } #[cfg(test)] mod tests { use std::sync::Arc; use axum::{ Router, body::Body, extract::Extension, http::{Request, StatusCode, header::RETRY_AFTER}, middleware, routing::get, }; use tokio::sync::Notify; use tower::ServiceExt; use axum::extract::FromRef; use crate::{ config::AppConfig, state::{AppState, BackpressureState}, }; use super::{classify_request_permit_pool, limit_concurrent_requests}; #[derive(Clone)] struct HeldRequestGate { entered: Arc, release: Arc, } async fn held_request(Extension(gate): Extension) -> &'static str { gate.entered.notify_one(); gate.release.notified().await; "ok" } async fn fast_request() -> &'static str { "ok" } fn test_request(path: &str) -> Request { Request::builder() .uri(path) .body(Body::empty()) .expect("test request should build") } fn build_test_app(max_concurrent_requests: usize, gate: HeldRequestGate) -> Router { let mut config = AppConfig::default(); config.max_concurrent_requests = Some(max_concurrent_requests); let state = AppState::new(config).expect("state should build"); let backpressure_state = BackpressureState::from_ref(&state); Router::new() .route("/held", get(held_request)) .route("/fast", get(fast_request)) .route("/healthz", get(fast_request)) .layer(middleware::from_fn_with_state( backpressure_state, limit_concurrent_requests, )) .layer(Extension(gate)) .with_state(state) } fn build_grouped_test_app( default_max_concurrent_requests: usize, gallery_max_concurrent_requests: usize, admin_max_concurrent_requests: usize, gate: HeldRequestGate, ) -> Router { let mut config = AppConfig::default(); config.max_concurrent_requests = Some(default_max_concurrent_requests); config.gallery_max_concurrent_requests = Some(gallery_max_concurrent_requests); config.admin_max_concurrent_requests = Some(admin_max_concurrent_requests); let state = AppState::new(config).expect("state should build"); let backpressure_state = BackpressureState::from_ref(&state); Router::new() .route("/held", get(held_request)) .route("/api/runtime/puzzle/gallery", get(held_request)) .route("/api/runtime/custom-world-gallery", get(held_request)) .route("/api/runtime/puzzle/gallery/profile-1", get(held_request)) .route( "/api/runtime/puzzle/gallery/profile-1/like", get(fast_request), ) .route( "/api/runtime/custom-world-gallery/user-1/profile-1", get(held_request), ) .route("/admin/api/overview", get(held_request)) .route("/fast", get(fast_request)) .layer(middleware::from_fn_with_state( backpressure_state, limit_concurrent_requests, )) .layer(Extension(gate)) .with_state(state) } #[tokio::test] async fn returns_429_when_concurrency_permits_are_exhausted() { let gate = HeldRequestGate { entered: Arc::new(Notify::new()), release: Arc::new(Notify::new()), }; let app = build_test_app(1, gate.clone()); let entered = gate.entered.notified(); let held_response = tokio::spawn(app.clone().oneshot(test_request("/held"))); entered.await; let rejected_response = app .clone() .oneshot(test_request("/fast")) .await .expect("rejected request should complete"); assert_eq!(rejected_response.status(), StatusCode::TOO_MANY_REQUESTS); assert_eq!( rejected_response .headers() .get(RETRY_AFTER) .and_then(|value| value.to_str().ok()), Some("1") ); gate.release.notify_one(); let completed_response = held_response .await .expect("held request task should join") .expect("held request should complete"); assert_eq!(completed_response.status(), StatusCode::OK); } #[tokio::test] async fn healthz_bypasses_concurrency_backpressure() { let gate = HeldRequestGate { entered: Arc::new(Notify::new()), release: Arc::new(Notify::new()), }; let app = build_test_app(1, gate.clone()); let entered = gate.entered.notified(); let held_response = tokio::spawn(app.clone().oneshot(test_request("/held"))); entered.await; let health_response = app .clone() .oneshot(test_request("/healthz")) .await .expect("healthz request should complete"); assert_eq!(health_response.status(), StatusCode::OK); gate.release.notify_one(); let completed_response = held_response .await .expect("held request task should join") .expect("held request should complete"); assert_eq!(completed_response.status(), StatusCode::OK); } #[tokio::test] async fn permit_is_held_until_response_body_is_dropped() { let gate = HeldRequestGate { entered: Arc::new(Notify::new()), release: Arc::new(Notify::new()), }; let app = build_test_app(1, gate); let first_response = app .clone() .oneshot(test_request("/fast")) .await .expect("first request should complete"); assert_eq!(first_response.status(), StatusCode::OK); let rejected_response = app .clone() .oneshot(test_request("/fast")) .await .expect("second request should complete"); assert_eq!(rejected_response.status(), StatusCode::TOO_MANY_REQUESTS); drop(first_response); let accepted_response = app .oneshot(test_request("/fast")) .await .expect("third request should complete"); assert_eq!(accepted_response.status(), StatusCode::OK); } #[tokio::test] async fn gallery_pool_rejects_gallery_without_blocking_default_routes() { let gate = HeldRequestGate { entered: Arc::new(Notify::new()), release: Arc::new(Notify::new()), }; let app = build_grouped_test_app(2, 1, 1, gate.clone()); let entered = gate.entered.notified(); let held_response = tokio::spawn( app.clone() .oneshot(test_request("/api/runtime/puzzle/gallery")), ); entered.await; let rejected_gallery_response = app .clone() .oneshot(test_request("/api/runtime/custom-world-gallery")) .await .expect("rejected gallery request should complete"); assert_eq!( rejected_gallery_response.status(), StatusCode::TOO_MANY_REQUESTS ); let accepted_default_response = app .clone() .oneshot(test_request("/fast")) .await .expect("default request should complete"); assert_eq!(accepted_default_response.status(), StatusCode::OK); gate.release.notify_one(); let completed_response = held_response .await .expect("held request task should join") .expect("held request should complete"); assert_eq!(completed_response.status(), StatusCode::OK); } #[tokio::test] async fn detail_pool_falls_back_to_default_when_unset() { let gate = HeldRequestGate { entered: Arc::new(Notify::new()), release: Arc::new(Notify::new()), }; let mut config = AppConfig::default(); config.max_concurrent_requests = Some(1); config.detail_max_concurrent_requests = None; let state = AppState::new(config).expect("state should build"); let backpressure_state = BackpressureState::from_ref(&state); let app = Router::new() .route("/api/runtime/puzzle/gallery/profile-1", get(held_request)) .route("/fast", get(fast_request)) .layer(middleware::from_fn_with_state( backpressure_state, limit_concurrent_requests, )) .layer(Extension(gate.clone())) .with_state(state); let entered = gate.entered.notified(); let held_response = tokio::spawn( app.clone() .oneshot(test_request("/api/runtime/puzzle/gallery/profile-1")), ); entered.await; let rejected_default_response = app .clone() .oneshot(test_request("/fast")) .await .expect("default request should complete"); assert_eq!( rejected_default_response.status(), StatusCode::TOO_MANY_REQUESTS ); gate.release.notify_one(); let completed_response = held_response .await .expect("held request task should join") .expect("held request should complete"); assert_eq!(completed_response.status(), StatusCode::OK); } #[tokio::test] async fn admin_pool_is_isolated_from_default_routes() { let gate = HeldRequestGate { entered: Arc::new(Notify::new()), release: Arc::new(Notify::new()), }; let app = build_grouped_test_app(2, 1, 1, gate.clone()); let entered = gate.entered.notified(); let held_response = tokio::spawn(app.clone().oneshot(test_request("/admin/api/overview"))); entered.await; let rejected_admin_response = app .clone() .oneshot(test_request("/admin/api/overview")) .await .expect("rejected admin request should complete"); assert_eq!( rejected_admin_response.status(), StatusCode::TOO_MANY_REQUESTS ); let accepted_default_response = app .clone() .oneshot(test_request("/fast")) .await .expect("default request should complete"); assert_eq!(accepted_default_response.status(), StatusCode::OK); gate.release.notify_one(); let completed_response = held_response .await .expect("held request task should join") .expect("held request should complete"); assert_eq!(completed_response.status(), StatusCode::OK); } #[test] fn classifies_only_exact_gallery_detail_paths_as_detail() { assert_eq!( classify_request_permit_pool("/api/runtime/puzzle/gallery/profile-1"), crate::state::HttpRequestPermitPoolKind::Detail ); assert_eq!( classify_request_permit_pool("/api/runtime/puzzle/gallery/profile-1/like"), crate::state::HttpRequestPermitPoolKind::Default ); assert_eq!( classify_request_permit_pool("/api/runtime/custom-world-gallery/user-1/profile-1"), crate::state::HttpRequestPermitPoolKind::Detail ); assert_eq!( classify_request_permit_pool("/api/runtime/custom-world-gallery/user-1/profile-1/like"), crate::state::HttpRequestPermitPoolKind::Default ); } }