use std::sync::Arc; use axum::{ body::Body, extract::{Request, State}, http::{HeaderValue, StatusCode, header::RETRY_AFTER}, middleware::Next, response::Response, }; use http_body_util::BodyExt; use tokio::sync::{OwnedSemaphorePermit, TryAcquireError}; use crate::{ http_error::AppError, request_context::RequestContext, state::{AppState, HttpRequestPermitPool}, }; pub async fn limit_concurrent_requests( State(state): State, request: Request, next: Next, ) -> Response { if should_bypass_backpressure(&request) { return next.run(request).await; } let Some(permit_pool) = state.http_request_permit_pool() else { return next.run(request).await; }; match acquire_http_request_permit(permit_pool) { Ok(permit) => hold_permit_until_response_body_dropped(next.run(request).await, permit), Err(_) => reject_overloaded_request(&request), } } fn acquire_http_request_permit( permit_pool: Arc, ) -> Result { match permit_pool.clone().try_acquire_owned() { Ok(permit) => { crate::telemetry::update_http_request_permits_available(permit_pool.available_permits()); Ok(HttpRequestPermitGuard { permit: Some(permit), permit_pool, }) } Err(error) => { crate::telemetry::update_http_request_permits_available(permit_pool.available_permits()); Err(error) } } } fn hold_permit_until_response_body_dropped( response: Response, permit: HttpRequestPermitGuard, ) -> Response { response.map(|body| { Body::new(body.map_frame(move |frame| { let _permit_guard = &permit; frame })) }) } struct HttpRequestPermitGuard { permit: Option, permit_pool: Arc, } impl Drop for HttpRequestPermitGuard { fn drop(&mut self) { drop(self.permit.take()); crate::telemetry::update_http_request_permits_available(self.permit_pool.available_permits()); } } fn reject_overloaded_request(request: &Request) -> Response { let request_context = request.extensions().get::().cloned(); let mut response = AppError::from_status(StatusCode::TOO_MANY_REQUESTS) .with_message("服务繁忙,请稍后重试") .into_response_with_context(request_context.as_ref()); response .headers_mut() .insert(RETRY_AFTER, HeaderValue::from_static("1")); response } fn should_bypass_backpressure(request: &Request) -> bool { request.uri().path() == "/healthz" } #[cfg(test)] mod tests { use std::sync::Arc; use axum::{ Router, body::Body, extract::Extension, http::{Request, StatusCode, header::RETRY_AFTER}, middleware, routing::get, }; use tokio::sync::Notify; use tower::ServiceExt; use crate::{config::AppConfig, state::AppState}; use super::limit_concurrent_requests; #[derive(Clone)] struct HeldRequestGate { entered: Arc, release: Arc, } async fn held_request(Extension(gate): Extension) -> &'static str { gate.entered.notify_one(); gate.release.notified().await; "ok" } async fn fast_request() -> &'static str { "ok" } fn test_request(path: &str) -> Request { Request::builder() .uri(path) .body(Body::empty()) .expect("test request should build") } fn build_test_app(max_concurrent_requests: usize, gate: HeldRequestGate) -> Router { let mut config = AppConfig::default(); config.max_concurrent_requests = Some(max_concurrent_requests); let state = AppState::new(config).expect("state should build"); Router::new() .route("/held", get(held_request)) .route("/fast", get(fast_request)) .route("/healthz", get(fast_request)) .layer(middleware::from_fn_with_state( state.clone(), limit_concurrent_requests, )) .layer(Extension(gate)) .with_state(state) } #[tokio::test] async fn returns_429_when_concurrency_permits_are_exhausted() { let gate = HeldRequestGate { entered: Arc::new(Notify::new()), release: Arc::new(Notify::new()), }; let app = build_test_app(1, gate.clone()); let entered = gate.entered.notified(); let held_response = tokio::spawn(app.clone().oneshot(test_request("/held"))); entered.await; let rejected_response = app .clone() .oneshot(test_request("/fast")) .await .expect("rejected request should complete"); assert_eq!(rejected_response.status(), StatusCode::TOO_MANY_REQUESTS); assert_eq!( rejected_response .headers() .get(RETRY_AFTER) .and_then(|value| value.to_str().ok()), Some("1") ); gate.release.notify_one(); let completed_response = held_response .await .expect("held request task should join") .expect("held request should complete"); assert_eq!(completed_response.status(), StatusCode::OK); } #[tokio::test] async fn healthz_bypasses_concurrency_backpressure() { let gate = HeldRequestGate { entered: Arc::new(Notify::new()), release: Arc::new(Notify::new()), }; let app = build_test_app(1, gate.clone()); let entered = gate.entered.notified(); let held_response = tokio::spawn(app.clone().oneshot(test_request("/held"))); entered.await; let health_response = app .clone() .oneshot(test_request("/healthz")) .await .expect("healthz request should complete"); assert_eq!(health_response.status(), StatusCode::OK); gate.release.notify_one(); let completed_response = held_response .await .expect("held request task should join") .expect("held request should complete"); assert_eq!(completed_response.status(), StatusCode::OK); } #[tokio::test] async fn permit_is_held_until_response_body_is_dropped() { let gate = HeldRequestGate { entered: Arc::new(Notify::new()), release: Arc::new(Notify::new()), }; let app = build_test_app(1, gate); let first_response = app .clone() .oneshot(test_request("/fast")) .await .expect("first request should complete"); assert_eq!(first_response.status(), StatusCode::OK); let rejected_response = app .clone() .oneshot(test_request("/fast")) .await .expect("second request should complete"); assert_eq!(rejected_response.status(), StatusCode::TOO_MANY_REQUESTS); drop(first_response); let accepted_response = app .oneshot(test_request("/fast")) .await .expect("third request should complete"); assert_eq!(accepted_response.status(), StatusCode::OK); } }