feat(api-server): add request backpressure controls

This commit is contained in:
kdletters
2026-05-17 04:56:45 +08:00
parent fb23ee79d8
commit 02271e6c73
11 changed files with 478 additions and 2 deletions

View File

@@ -11,6 +11,7 @@ base64 = { workspace = true }
bytes = { workspace = true }
dotenvy = { workspace = true }
image = { workspace = true, features = ["jpeg", "png", "webp"] }
http-body-util = { workspace = true }
reqwest = { workspace = true, features = ["json", "multipart", "rustls-tls"] }
webp = { workspace = true }
module-ai = { workspace = true }
@@ -45,7 +46,7 @@ shared-kernel = { workspace = true }
shared-logging = { workspace = true }
socket2 = { workspace = true }
spacetime-client = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "net", "time"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "net", "time", "sync"] }
tokio-stream = { workspace = true }
futures-util = { workspace = true }
time = { workspace = true, features = ["formatting"] }

View File

@@ -15,6 +15,7 @@ use tracing::{Level, Span, error, info_span};
use crate::{
auth::{AuthenticatedAccessToken, require_bearer_auth},
backpressure::limit_concurrent_requests,
creation_entry_config::require_creation_entry_route_enabled,
error_middleware::normalize_error_response,
modules,
@@ -76,6 +77,11 @@ pub fn build_router(state: AppState) -> Router {
state.clone(),
require_creation_entry_route_enabled,
))
// HTTP 背压在业务路由外侧快拒绝,避免过载请求继续占用 SpacetimeDB facade 与业务执行资源。
.layer(middleware::from_fn_with_state(
state.clone(),
limit_concurrent_requests,
))
// 错误归一化层放在 tracing 里侧,让 tracing 记录到最终对外返回的状态与错误体形态。
.layer(middleware::from_fn(normalize_error_response))
// 响应头回写放在错误归一化外侧,确保最终写回的是归一化后的最终响应。

View File

@@ -0,0 +1,221 @@
use std::sync::Arc;
use axum::{
body::Body,
extract::{Request, State},
http::{HeaderValue, StatusCode, header::RETRY_AFTER},
middleware::Next,
response::Response,
};
use http_body_util::BodyExt;
use tokio::sync::{OwnedSemaphorePermit, TryAcquireError};
use crate::{
http_error::AppError,
request_context::RequestContext,
state::{AppState, HttpRequestPermitPool},
};
pub async fn limit_concurrent_requests(
State(state): State<AppState>,
request: Request,
next: Next,
) -> Response {
if should_bypass_backpressure(&request) {
return next.run(request).await;
}
let Some(permit_pool) = state.http_request_permit_pool() else {
return next.run(request).await;
};
match acquire_http_request_permit(permit_pool) {
Ok(permit) => hold_permit_until_response_body_dropped(next.run(request).await, permit),
Err(_) => reject_overloaded_request(&request),
}
}
fn acquire_http_request_permit(
permit_pool: Arc<HttpRequestPermitPool>,
) -> Result<OwnedSemaphorePermit, TryAcquireError> {
permit_pool.try_acquire_owned()
}
fn hold_permit_until_response_body_dropped(
response: Response,
permit: OwnedSemaphorePermit,
) -> Response {
response.map(|body| {
Body::new(body.map_frame(move |frame| {
let _permit_guard = &permit;
frame
}))
})
}
fn reject_overloaded_request(request: &Request<Body>) -> Response {
let request_context = request.extensions().get::<RequestContext>().cloned();
let mut response = AppError::from_status(StatusCode::TOO_MANY_REQUESTS)
.with_message("服务繁忙,请稍后重试")
.into_response_with_context(request_context.as_ref());
response
.headers_mut()
.insert(RETRY_AFTER, HeaderValue::from_static("1"));
response
}
fn should_bypass_backpressure(request: &Request<Body>) -> bool {
request.uri().path() == "/healthz"
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use axum::{
Router,
body::Body,
extract::Extension,
http::{Request, StatusCode, header::RETRY_AFTER},
middleware,
routing::get,
};
use tokio::sync::Notify;
use tower::ServiceExt;
use crate::{config::AppConfig, state::AppState};
use super::limit_concurrent_requests;
#[derive(Clone)]
struct HeldRequestGate {
entered: Arc<Notify>,
release: Arc<Notify>,
}
async fn held_request(Extension(gate): Extension<HeldRequestGate>) -> &'static str {
gate.entered.notify_one();
gate.release.notified().await;
"ok"
}
async fn fast_request() -> &'static str {
"ok"
}
fn test_request(path: &str) -> Request<Body> {
Request::builder()
.uri(path)
.body(Body::empty())
.expect("test request should build")
}
fn build_test_app(max_concurrent_requests: usize, gate: HeldRequestGate) -> Router {
let mut config = AppConfig::default();
config.max_concurrent_requests = Some(max_concurrent_requests);
let state = AppState::new(config).expect("state should build");
Router::new()
.route("/held", get(held_request))
.route("/fast", get(fast_request))
.route("/healthz", get(fast_request))
.layer(middleware::from_fn_with_state(
state.clone(),
limit_concurrent_requests,
))
.layer(Extension(gate))
.with_state(state)
}
#[tokio::test]
async fn returns_429_when_concurrency_permits_are_exhausted() {
let gate = HeldRequestGate {
entered: Arc::new(Notify::new()),
release: Arc::new(Notify::new()),
};
let app = build_test_app(1, gate.clone());
let entered = gate.entered.notified();
let held_response = tokio::spawn(app.clone().oneshot(test_request("/held")));
entered.await;
let rejected_response = app
.clone()
.oneshot(test_request("/fast"))
.await
.expect("rejected request should complete");
assert_eq!(rejected_response.status(), StatusCode::TOO_MANY_REQUESTS);
assert_eq!(
rejected_response
.headers()
.get(RETRY_AFTER)
.and_then(|value| value.to_str().ok()),
Some("1")
);
gate.release.notify_one();
let completed_response = held_response
.await
.expect("held request task should join")
.expect("held request should complete");
assert_eq!(completed_response.status(), StatusCode::OK);
}
#[tokio::test]
async fn healthz_bypasses_concurrency_backpressure() {
let gate = HeldRequestGate {
entered: Arc::new(Notify::new()),
release: Arc::new(Notify::new()),
};
let app = build_test_app(1, gate.clone());
let entered = gate.entered.notified();
let held_response = tokio::spawn(app.clone().oneshot(test_request("/held")));
entered.await;
let health_response = app
.clone()
.oneshot(test_request("/healthz"))
.await
.expect("healthz request should complete");
assert_eq!(health_response.status(), StatusCode::OK);
gate.release.notify_one();
let completed_response = held_response
.await
.expect("held request task should join")
.expect("held request should complete");
assert_eq!(completed_response.status(), StatusCode::OK);
}
#[tokio::test]
async fn permit_is_held_until_response_body_is_dropped() {
let gate = HeldRequestGate {
entered: Arc::new(Notify::new()),
release: Arc::new(Notify::new()),
};
let app = build_test_app(1, gate);
let first_response = app
.clone()
.oneshot(test_request("/fast"))
.await
.expect("first request should complete");
assert_eq!(first_response.status(), StatusCode::OK);
let rejected_response = app
.clone()
.oneshot(test_request("/fast"))
.await
.expect("second request should complete");
assert_eq!(rejected_response.status(), StatusCode::TOO_MANY_REQUESTS);
drop(first_response);
let accepted_response = app
.oneshot(test_request("/fast"))
.await
.expect("third request should complete");
assert_eq!(accepted_response.status(), StatusCode::OK);
}
}

View File

@@ -22,6 +22,7 @@ pub struct AppConfig {
pub bind_port: u16,
pub listen_backlog: i32,
pub worker_threads: Option<usize>,
pub max_concurrent_requests: Option<usize>,
pub log_filter: String,
pub otel_enabled: bool,
pub admin_username: Option<String>,
@@ -152,6 +153,7 @@ impl Default for AppConfig {
bind_port: 3000,
listen_backlog: 1024,
worker_threads: None,
max_concurrent_requests: None,
log_filter: "info,tower_http=info".to_string(),
otel_enabled: false,
admin_username: None,
@@ -315,6 +317,11 @@ impl AppConfig {
if let Some(worker_threads) = read_first_usize_env(&["GENARRATIVE_API_WORKER_THREADS"]) {
config.worker_threads = Some(worker_threads);
}
if let Some(max_concurrent_requests) =
read_first_usize_env(&["GENARRATIVE_API_MAX_CONCURRENT_REQUESTS"])
{
config.max_concurrent_requests = Some(max_concurrent_requests);
}
if let Some(otel_enabled) = read_first_bool_env(&["GENARRATIVE_OTEL_ENABLED"]) {
config.otel_enabled = otel_enabled;
}
@@ -1195,20 +1202,24 @@ mod tests {
unsafe {
std::env::remove_var("GENARRATIVE_API_LISTEN_BACKLOG");
std::env::remove_var("GENARRATIVE_API_WORKER_THREADS");
std::env::remove_var("GENARRATIVE_API_MAX_CONCURRENT_REQUESTS");
std::env::remove_var("GENARRATIVE_OTEL_ENABLED");
std::env::set_var("GENARRATIVE_API_LISTEN_BACKLOG", "2048");
std::env::set_var("GENARRATIVE_API_WORKER_THREADS", "6");
std::env::set_var("GENARRATIVE_API_MAX_CONCURRENT_REQUESTS", "128");
std::env::set_var("GENARRATIVE_OTEL_ENABLED", "true");
}
let config = AppConfig::from_env();
assert_eq!(config.listen_backlog, 2048);
assert_eq!(config.worker_threads, Some(6));
assert_eq!(config.max_concurrent_requests, Some(128));
assert!(config.otel_enabled);
unsafe {
std::env::remove_var("GENARRATIVE_API_LISTEN_BACKLOG");
std::env::remove_var("GENARRATIVE_API_WORKER_THREADS");
std::env::remove_var("GENARRATIVE_API_MAX_CONCURRENT_REQUESTS");
std::env::remove_var("GENARRATIVE_OTEL_ENABLED");
}
}

View File

@@ -13,6 +13,7 @@ mod auth_payload;
mod auth_public_user;
mod auth_session;
mod auth_sessions;
mod backpressure;
mod bark_battle;
mod big_fish;
mod big_fish_agent_turn;

View File

@@ -27,6 +27,7 @@ use shared_contracts::creation_entry_config::CreationEntryConfigResponse;
use shared_contracts::creative_agent::CreativeAgentSessionSnapshot;
use spacetime_client::{SpacetimeClient, SpacetimeClientConfig, SpacetimeClientError};
use time::OffsetDateTime;
use tokio::sync::Semaphore;
use tracing::{info, warn};
use crate::config::AppConfig;
@@ -35,12 +36,15 @@ use crate::wechat_provider::build_wechat_provider;
const ADMIN_ROLE: &str = "admin";
pub type HttpRequestPermitPool = Semaphore;
// 当前阶段先保留最小共享状态壳,后续逐步接入配置、客户端与平台适配。
#[derive(Clone, Debug)]
pub struct AppState {
// 配置会在后续中间件、路由和平台适配接入时逐步消费。
#[allow(dead_code)]
pub config: AppConfig,
http_request_permit_pool: Option<Arc<HttpRequestPermitPool>>,
auth_jwt_config: JwtConfig,
admin_runtime: Option<AdminRuntime>,
refresh_cookie_config: RefreshCookieConfig,
@@ -192,9 +196,14 @@ impl AppState {
});
let llm_client = build_llm_client(&config)?;
let creative_agent_gpt5_client = build_creative_agent_gpt5_client(&config)?;
let http_request_permit_pool = config
.max_concurrent_requests
.map(HttpRequestPermitPool::new)
.map(Arc::new);
Ok(Self {
config,
http_request_permit_pool,
auth_jwt_config,
admin_runtime,
refresh_cookie_config,
@@ -235,6 +244,10 @@ impl AppState {
&self.refresh_cookie_config
}
pub fn http_request_permit_pool(&self) -> Option<Arc<HttpRequestPermitPool>> {
self.http_request_permit_pool.clone()
}
pub async fn upsert_creation_entry_type_config(
&self,
input: module_runtime::CreationEntryTypeAdminUpsertInput,