246 lines
7.3 KiB
Rust
246 lines
7.3 KiB
Rust
use std::sync::Arc;
|
|
|
|
use axum::{
|
|
body::Body,
|
|
extract::{Request, State},
|
|
http::{HeaderValue, StatusCode, header::RETRY_AFTER},
|
|
middleware::Next,
|
|
response::Response,
|
|
};
|
|
use http_body_util::BodyExt;
|
|
use tokio::sync::{OwnedSemaphorePermit, TryAcquireError};
|
|
|
|
use crate::{
|
|
http_error::AppError,
|
|
request_context::RequestContext,
|
|
state::{AppState, HttpRequestPermitPool},
|
|
};
|
|
|
|
pub async fn limit_concurrent_requests(
|
|
State(state): State<AppState>,
|
|
request: Request,
|
|
next: Next,
|
|
) -> Response {
|
|
if should_bypass_backpressure(&request) {
|
|
return next.run(request).await;
|
|
}
|
|
|
|
let Some(permit_pool) = state.http_request_permit_pool() else {
|
|
return next.run(request).await;
|
|
};
|
|
|
|
match acquire_http_request_permit(permit_pool) {
|
|
Ok(permit) => hold_permit_until_response_body_dropped(next.run(request).await, permit),
|
|
Err(_) => reject_overloaded_request(&request),
|
|
}
|
|
}
|
|
|
|
fn acquire_http_request_permit(
|
|
permit_pool: Arc<HttpRequestPermitPool>,
|
|
) -> Result<HttpRequestPermitGuard, TryAcquireError> {
|
|
match permit_pool.clone().try_acquire_owned() {
|
|
Ok(permit) => {
|
|
crate::telemetry::update_http_request_permits_available(permit_pool.available_permits());
|
|
Ok(HttpRequestPermitGuard {
|
|
permit: Some(permit),
|
|
permit_pool,
|
|
})
|
|
}
|
|
Err(error) => {
|
|
crate::telemetry::update_http_request_permits_available(permit_pool.available_permits());
|
|
Err(error)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn hold_permit_until_response_body_dropped(
|
|
response: Response,
|
|
permit: HttpRequestPermitGuard,
|
|
) -> Response {
|
|
response.map(|body| {
|
|
Body::new(body.map_frame(move |frame| {
|
|
let _permit_guard = &permit;
|
|
frame
|
|
}))
|
|
})
|
|
}
|
|
|
|
struct HttpRequestPermitGuard {
|
|
permit: Option<OwnedSemaphorePermit>,
|
|
permit_pool: Arc<HttpRequestPermitPool>,
|
|
}
|
|
|
|
impl Drop for HttpRequestPermitGuard {
|
|
fn drop(&mut self) {
|
|
drop(self.permit.take());
|
|
crate::telemetry::update_http_request_permits_available(self.permit_pool.available_permits());
|
|
}
|
|
}
|
|
|
|
fn reject_overloaded_request(request: &Request<Body>) -> Response {
|
|
let request_context = request.extensions().get::<RequestContext>().cloned();
|
|
let mut response = AppError::from_status(StatusCode::TOO_MANY_REQUESTS)
|
|
.with_message("服务繁忙,请稍后重试")
|
|
.into_response_with_context(request_context.as_ref());
|
|
response
|
|
.headers_mut()
|
|
.insert(RETRY_AFTER, HeaderValue::from_static("1"));
|
|
response
|
|
}
|
|
|
|
/// Returns `true` for requests that must never be subject to the concurrency limit.
///
/// Only an exact path match on `/healthz` qualifies.
fn should_bypass_backpressure(request: &Request<Body>) -> bool {
    matches!(request.uri().path(), "/healthz")
}
|
|
|
|
#[cfg(test)]
mod tests {
    // Drives `limit_concurrent_requests` through a small router with a single permit.

    use std::sync::Arc;

    use axum::{
        Router,
        body::Body,
        extract::Extension,
        http::{Request, StatusCode, header::RETRY_AFTER},
        middleware,
        routing::get,
    };
    use tokio::sync::Notify;
    use tower::ServiceExt;

    use crate::{config::AppConfig, state::AppState};

    use super::limit_concurrent_requests;

    /// Pair of signals used to park a handler in the middle of a request.
    #[derive(Clone)]
    struct HeldRequestGate {
        /// Notified by the handler once it has started running.
        entered: Arc<Notify>,
        /// Awaited by the handler before it produces its response.
        release: Arc<Notify>,
    }

    /// Creates a gate with fresh, unsignalled notifiers.
    fn new_gate() -> HeldRequestGate {
        HeldRequestGate {
            entered: Arc::new(Notify::new()),
            release: Arc::new(Notify::new()),
        }
    }

    /// Handler that announces its start and then waits until the test releases it.
    async fn held_request(Extension(gate): Extension<HeldRequestGate>) -> &'static str {
        gate.entered.notify_one();
        gate.release.notified().await;
        "ok"
    }

    /// Handler that answers immediately.
    async fn fast_request() -> &'static str {
        "ok"
    }

    /// Builds a body-less request for `path`.
    fn test_request(path: &str) -> Request<Body> {
        Request::builder()
            .uri(path)
            .body(Body::empty())
            .expect("test request should build")
    }

    /// Builds a router with `/held`, `/fast` and `/healthz` behind the middleware under test.
    fn build_test_app(max_concurrent_requests: usize, gate: HeldRequestGate) -> Router {
        let mut config = AppConfig::default();
        config.max_concurrent_requests = Some(max_concurrent_requests);
        let state = AppState::new(config).expect("state should build");

        let backpressure = middleware::from_fn_with_state(state.clone(), limit_concurrent_requests);

        Router::new()
            .route("/held", get(held_request))
            .route("/fast", get(fast_request))
            .route("/healthz", get(fast_request))
            .layer(backpressure)
            .layer(Extension(gate))
            .with_state(state)
    }

    #[tokio::test]
    async fn returns_429_when_concurrency_permits_are_exhausted() {
        let gate = new_gate();
        let app = build_test_app(1, gate.clone());
        let handler_entered = gate.entered.notified();

        // Occupy the only permit with a request parked inside its handler.
        let held_task = tokio::spawn(app.clone().oneshot(test_request("/held")));
        handler_entered.await;

        let rejected = app
            .clone()
            .oneshot(test_request("/fast"))
            .await
            .expect("rejected request should complete");
        assert_eq!(rejected.status(), StatusCode::TOO_MANY_REQUESTS);

        let retry_after = rejected
            .headers()
            .get(RETRY_AFTER)
            .and_then(|value| value.to_str().ok());
        assert_eq!(retry_after, Some("1"));

        // Let the parked request finish and make sure it was served normally.
        gate.release.notify_one();
        let completed = held_task
            .await
            .expect("held request task should join")
            .expect("held request should complete");
        assert_eq!(completed.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn healthz_bypasses_concurrency_backpressure() {
        let gate = new_gate();
        let app = build_test_app(1, gate.clone());
        let handler_entered = gate.entered.notified();

        // Occupy the only permit with a request parked inside its handler.
        let held_task = tokio::spawn(app.clone().oneshot(test_request("/held")));
        handler_entered.await;

        let health = app
            .clone()
            .oneshot(test_request("/healthz"))
            .await
            .expect("healthz request should complete");
        assert_eq!(health.status(), StatusCode::OK);

        gate.release.notify_one();
        let completed = held_task
            .await
            .expect("held request task should join")
            .expect("held request should complete");
        assert_eq!(completed.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn permit_is_held_until_response_body_is_dropped() {
        let app = build_test_app(1, new_gate());

        // Keeping the first response alive keeps its body, and so the permit, alive.
        let first = app
            .clone()
            .oneshot(test_request("/fast"))
            .await
            .expect("first request should complete");
        assert_eq!(first.status(), StatusCode::OK);

        let rejected = app
            .clone()
            .oneshot(test_request("/fast"))
            .await
            .expect("second request should complete");
        assert_eq!(rejected.status(), StatusCode::TOO_MANY_REQUESTS);

        drop(first);

        let accepted = app
            .oneshot(test_request("/fast"))
            .await
            .expect("third request should complete");
        assert_eq!(accepted.status(), StatusCode::OK);
    }
}
|