This commit is contained in:
2026-06-05 23:44:36 +08:00
53 changed files with 2823 additions and 929 deletions

View File

@@ -54,7 +54,7 @@ shared-kernel = { workspace = true }
shared-logging = { workspace = true }
socket2 = { workspace = true }
spacetime-client = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "net", "time", "sync", "fs", "io-util"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "net", "time", "sync", "fs", "io-util", "signal"] }
tokio-stream = { workspace = true }
futures-util = { workspace = true }
time = { workspace = true, features = ["formatting"] }

View File

@@ -877,6 +877,46 @@ mod tests {
);
}
#[tokio::test]
async fn readyz_reports_readiness_and_draining_state() {
let state = AppState::new(AppConfig::default()).expect("state should build");
let app = build_router(state.clone());
let ready_response = app
.clone()
.oneshot(
Request::builder()
.uri("/readyz")
.header("x-request-id", "req-ready")
.body(Body::empty())
.expect("readyz request should build"),
)
.await
.expect("readyz request should succeed");
assert_eq!(ready_response.status(), StatusCode::OK);
let ready_body = read_json_response(ready_response).await;
assert_eq!(ready_body["ok"], Value::Bool(true));
assert_eq!(ready_body["ready"], Value::Bool(true));
state.mark_not_ready();
let draining_response = app
.oneshot(
Request::builder()
.uri("/readyz")
.header("x-request-id", "req-draining")
.body(Body::empty())
.expect("readyz request should build"),
)
.await
.expect("readyz request should succeed");
assert_eq!(draining_response.status(), StatusCode::SERVICE_UNAVAILABLE);
let draining_body = read_json_response(draining_response).await;
assert_eq!(
draining_body["error"]["details"]["reason"],
"api_server_draining"
);
}
#[tokio::test]
async fn creative_agent_draft_edit_rejects_unconfirmed_template_session() {
let app = build_internal_creative_agent_app();

View File

@@ -102,7 +102,7 @@ fn reject_overloaded_request(request: &Request<Body>) -> Response {
}
fn should_bypass_backpressure(request: &Request<Body>) -> bool {
request.uri().path() == "/healthz"
matches!(request.uri().path(), "/healthz" | "/readyz")
}
fn classify_request_permit_pool(path: &str) -> HttpRequestPermitPoolKind {
@@ -200,6 +200,7 @@ mod tests {
.route("/held", get(held_request))
.route("/fast", get(fast_request))
.route("/healthz", get(fast_request))
.route("/readyz", get(fast_request))
.layer(middleware::from_fn_with_state(
backpressure_state,
limit_concurrent_requests,
@@ -297,6 +298,13 @@ mod tests {
.expect("healthz request should complete");
assert_eq!(health_response.status(), StatusCode::OK);
let ready_response = app
.clone()
.oneshot(test_request("/readyz"))
.await
.expect("readyz request should complete");
assert_eq!(ready_response.status(), StatusCode::OK);
gate.release.notify_one();
let completed_response = held_response
.await

View File

@@ -25,6 +25,7 @@ pub struct AppConfig {
pub gallery_max_concurrent_requests: Option<usize>,
pub detail_max_concurrent_requests: Option<usize>,
pub admin_max_concurrent_requests: Option<usize>,
pub shutdown_outbox_flush_timeout: Duration,
pub tracking_outbox_enabled: bool,
pub tracking_outbox_dir: PathBuf,
pub tracking_outbox_batch_size: usize,
@@ -169,6 +170,7 @@ impl Default for AppConfig {
gallery_max_concurrent_requests: None,
detail_max_concurrent_requests: None,
admin_max_concurrent_requests: None,
shutdown_outbox_flush_timeout: Duration::from_millis(5_000),
tracking_outbox_enabled: true,
tracking_outbox_dir: PathBuf::from("server-rs/.data/tracking-outbox"),
tracking_outbox_batch_size: 500,
@@ -365,6 +367,11 @@ impl AppConfig {
{
config.admin_max_concurrent_requests = Some(max_concurrent_requests);
}
if let Some(timeout_ms) =
read_first_positive_u64_env(&["GENARRATIVE_API_SHUTDOWN_OUTBOX_FLUSH_TIMEOUT_MS"])
{
config.shutdown_outbox_flush_timeout = Duration::from_millis(timeout_ms);
}
if let Some(enabled) = read_first_bool_env(&["GENARRATIVE_TRACKING_OUTBOX_ENABLED"]) {
config.tracking_outbox_enabled = enabled;
}
@@ -1324,6 +1331,7 @@ mod tests {
std::env::remove_var("GENARRATIVE_API_GALLERY_MAX_CONCURRENT_REQUESTS");
std::env::remove_var("GENARRATIVE_API_DETAIL_MAX_CONCURRENT_REQUESTS");
std::env::remove_var("GENARRATIVE_API_ADMIN_MAX_CONCURRENT_REQUESTS");
std::env::remove_var("GENARRATIVE_API_SHUTDOWN_OUTBOX_FLUSH_TIMEOUT_MS");
std::env::remove_var("GENARRATIVE_TRACKING_OUTBOX_ENABLED");
std::env::remove_var("GENARRATIVE_TRACKING_OUTBOX_DIR");
std::env::remove_var("GENARRATIVE_TRACKING_OUTBOX_BATCH_SIZE");
@@ -1336,6 +1344,7 @@ mod tests {
std::env::set_var("GENARRATIVE_API_GALLERY_MAX_CONCURRENT_REQUESTS", "64");
std::env::set_var("GENARRATIVE_API_DETAIL_MAX_CONCURRENT_REQUESTS", "32");
std::env::set_var("GENARRATIVE_API_ADMIN_MAX_CONCURRENT_REQUESTS", "16");
std::env::set_var("GENARRATIVE_API_SHUTDOWN_OUTBOX_FLUSH_TIMEOUT_MS", "3000");
std::env::set_var("GENARRATIVE_TRACKING_OUTBOX_ENABLED", "false");
std::env::set_var(
"GENARRATIVE_TRACKING_OUTBOX_DIR",
@@ -1354,6 +1363,10 @@ mod tests {
assert_eq!(config.gallery_max_concurrent_requests, Some(64));
assert_eq!(config.detail_max_concurrent_requests, Some(32));
assert_eq!(config.admin_max_concurrent_requests, Some(16));
assert_eq!(
config.shutdown_outbox_flush_timeout,
std::time::Duration::from_millis(3_000)
);
assert!(!config.tracking_outbox_enabled);
assert_eq!(
config.tracking_outbox_dir,
@@ -1374,6 +1387,7 @@ mod tests {
std::env::remove_var("GENARRATIVE_API_GALLERY_MAX_CONCURRENT_REQUESTS");
std::env::remove_var("GENARRATIVE_API_DETAIL_MAX_CONCURRENT_REQUESTS");
std::env::remove_var("GENARRATIVE_API_ADMIN_MAX_CONCURRENT_REQUESTS");
std::env::remove_var("GENARRATIVE_API_SHUTDOWN_OUTBOX_FLUSH_TIMEOUT_MS");
std::env::remove_var("GENARRATIVE_TRACKING_OUTBOX_ENABLED");
std::env::remove_var("GENARRATIVE_TRACKING_OUTBOX_DIR");
std::env::remove_var("GENARRATIVE_TRACKING_OUTBOX_BATCH_SIZE");

View File

@@ -1,7 +1,15 @@
use axum::{Json, extract::Extension};
use axum::{
Json,
extract::{Extension, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde_json::{Value, json};
use crate::{api_response::json_success_body, request_context::RequestContext};
use crate::{
api_response::json_success_body, http_error::AppError, request_context::RequestContext,
state::AppState,
};
pub async fn health_check(Extension(request_context): Extension<RequestContext>) -> Json<Value> {
json_success_body(
@@ -12,3 +20,28 @@ pub async fn health_check(Extension(request_context): Extension<RequestContext>)
}),
)
}
pub async fn readiness_check(
State(state): State<AppState>,
Extension(request_context): Extension<RequestContext>,
) -> Response {
if state.is_ready() {
return json_success_body(
Some(&request_context),
json!({
"ok": true,
"ready": true,
"service": "genarrative-api-server",
}),
)
.into_response();
}
AppError::from_status(StatusCode::SERVICE_UNAVAILABLE)
.with_message("api-server 正在退出,不再接收新流量")
.with_details(json!({
"reason": "api_server_draining",
"ready": false,
}))
.into_response_with_context(Some(&request_context))
}

View File

@@ -99,25 +99,35 @@ use shared_logging::{OtelConfig, init_tracing};
use socket2::{Domain, Protocol, Socket, Type};
use std::{
collections::HashSet,
env, fs, io,
env, fs, future, io,
net::{SocketAddr, TcpListener as StdTcpListener},
panic, thread,
panic,
sync::Arc,
thread,
time::Duration,
};
use tokio::net::TcpListener;
use tokio::runtime::Builder as TokioRuntimeBuilder;
use tokio::time::timeout;
use tracing::{error, info};
use tracing::{error, info, warn};
use crate::{
app::{build_router, build_spacetime_unavailable_router},
config::AppConfig,
state::{AppState, AppStateInitError},
tracking_outbox::TrackingOutbox,
};
const API_SERVER_STARTUP_STACK_SIZE_BYTES: usize = 32 * 1024 * 1024;
const AUTH_STORE_STARTUP_RESTORE_TIMEOUT: Duration = Duration::from_secs(8);
#[derive(Clone)]
struct ShutdownContext {
app_state: Option<AppState>,
tracking_outbox: Option<Arc<TrackingOutbox>>,
outbox_flush_timeout: Duration,
}
fn main() -> Result<(), io::Error> {
// Windows 本地调试下 Axum 路由树和启动恢复链较重,显式放大启动线程栈,避免 debug 构建在进入监听前栈溢出。
let server_thread = thread::Builder::new()
@@ -158,19 +168,33 @@ async fn run_server(config: AppConfig) -> Result<(), io::Error> {
let listen_backlog = config.listen_backlog;
let worker_threads = config.worker_threads;
let otel_enabled = config.otel_enabled;
let outbox_flush_timeout = config.shutdown_outbox_flush_timeout;
let listener = build_tcp_listener(bind_address, listen_backlog)?;
let router = match restore_app_state_for_startup(config).await {
let (router, shutdown_context) = match restore_app_state_for_startup(config).await {
Ok(state) => {
state.puzzle_gallery_cache().spawn_cleanup_task();
if let Some(outbox) = state.tracking_outbox() {
let tracking_outbox = state.tracking_outbox();
if let Some(outbox) = tracking_outbox.clone() {
outbox.spawn_worker();
}
build_router(state)
}
Err(AppStateInitError::DependencyUnavailable(message)) => {
build_spacetime_unavailable_router(message)
(
build_router(state.clone()),
ShutdownContext {
app_state: Some(state),
tracking_outbox,
outbox_flush_timeout,
},
)
}
Err(AppStateInitError::DependencyUnavailable(message)) => (
build_spacetime_unavailable_router(message),
ShutdownContext {
app_state: None,
tracking_outbox: None,
outbox_flush_timeout,
},
),
Err(error) => {
return Err(std::io::Error::other(format!(
"初始化应用状态失败:{error}"
@@ -186,7 +210,98 @@ async fn run_server(config: AppConfig) -> Result<(), io::Error> {
"api-server 已完成 tracing 初始化并开始监听"
);
axum::serve(listener, router).await
let result = axum::serve(listener, router)
.with_graceful_shutdown(shutdown_signal(shutdown_context.clone()))
.await;
finalize_shutdown(shutdown_context).await;
result
}
async fn shutdown_signal(context: ShutdownContext) {
let signal = wait_for_shutdown_signal().await;
if let Some(state) = context.app_state.as_ref() {
state.mark_not_ready();
}
info!(
signal,
"api-server 收到退出信号,已标记 readiness 不可用并开始排空 HTTP 请求"
);
}
async fn wait_for_shutdown_signal() -> &'static str {
#[cfg(unix)]
{
tokio::select! {
signal = wait_for_ctrl_c_signal() => signal,
signal = wait_for_sigterm_signal() => signal,
}
}
#[cfg(not(unix))]
{
wait_for_ctrl_c_signal().await
}
}
async fn wait_for_ctrl_c_signal() -> &'static str {
if let Err(error) = tokio::signal::ctrl_c().await {
error!(error = %error, "监听 SIGINT 失败,无法通过 Ctrl-C 触发优雅退出");
future::pending::<()>().await;
}
"sigint"
}
#[cfg(unix)]
async fn wait_for_sigterm_signal() -> &'static str {
let mut signal = match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
{
Ok(signal) => signal,
Err(error) => {
error!(error = %error, "监听 SIGTERM 失败,无法通过 systemd terminate 触发优雅退出");
future::pending::<()>().await;
unreachable!("pending future never returns");
}
};
signal.recv().await;
"sigterm"
}
async fn finalize_shutdown(context: ShutdownContext) {
if let Some(state) = context.app_state.as_ref() {
state.mark_not_ready();
}
let Some(outbox) = context.tracking_outbox else {
return;
};
if context.outbox_flush_timeout.is_zero() {
warn!("api-server 退出时 tracking outbox flush timeout 为 0跳过主动 flush");
return;
}
let timeout_ms = context
.outbox_flush_timeout
.as_millis()
.min(u128::from(u64::MAX)) as u64;
info!(timeout_ms, "api-server 退出前封存并 flush tracking outbox");
match timeout(context.outbox_flush_timeout, outbox.flush_for_shutdown()).await {
Ok(Ok(())) => {
info!("api-server 退出前 tracking outbox flush 完成");
}
Ok(Err(error)) => {
warn!(
error = %error,
"api-server 退出前 tracking outbox flush 未完成,已保留本地文件等待下次启动重试"
);
}
Err(_) => {
warn!(
timeout_ms,
"api-server 退出前 tracking outbox flush 超时,已保留本地文件等待下次启动重试"
);
}
}
}
fn build_tcp_listener(

View File

@@ -1,7 +1,12 @@
use axum::{Router, routing::get};
use crate::{health::health_check, state::AppState};
use crate::{
health::{health_check, readiness_check},
state::AppState,
};
pub fn router(_state: AppState) -> Router<AppState> {
Router::new().route("/healthz", get(health_check))
Router::new()
.route("/healthz", get(health_check))
.route("/readyz", get(readiness_check))
}

View File

@@ -2,7 +2,10 @@ use std::{
collections::HashMap,
error::Error,
fmt,
sync::{Arc, Mutex},
sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
},
};
use axum::extract::FromRef;
@@ -229,6 +232,7 @@ pub struct AppStateInner {
// 配置会在后续中间件、路由和平台适配接入时逐步消费。
#[allow(dead_code)]
pub config: AppConfig,
ready: AtomicBool,
http_request_permit_pools: HttpRequestPermitPools,
auth_jwt_config: JwtConfig,
admin_runtime: Option<AdminRuntime>,
@@ -399,6 +403,7 @@ impl AppState {
Ok(Self(Arc::new(AppStateInner {
config,
ready: AtomicBool::new(true),
http_request_permit_pools,
auth_jwt_config,
admin_runtime,
@@ -447,6 +452,14 @@ impl AppState {
self.http_request_permit_pools.clone()
}
pub fn is_ready(&self) -> bool {
self.ready.load(Ordering::Acquire)
}
pub fn mark_not_ready(&self) {
self.ready.store(false, Ordering::Release);
}
pub async fn upsert_creation_entry_type_config(
&self,
input: module_runtime::CreationEntryTypeAdminUpsertInput,

View File

@@ -159,6 +159,16 @@ impl TrackingOutbox {
});
}
pub async fn flush_for_shutdown(&self) -> Result<(), TrackingOutboxError> {
{
let mut inner = self.inner.lock().await;
self.ensure_initialized_locked(&mut inner).await?;
self.seal_active_locked(&mut inner, "shutdown").await?;
}
self.flush_sealed_files_once().await
}
async fn seal_active_if_due(&self) -> Result<(), TrackingOutboxError> {
let mut inner = self.inner.lock().await;
self.ensure_initialized_locked(&mut inner).await?;
@@ -176,7 +186,11 @@ impl TrackingOutbox {
crate::telemetry::update_tracking_outbox_pending_files(sealed_files.len());
for path in sealed_files {
let started_at = Instant::now();
let metadata = fs::metadata(&path).await?;
let metadata = match fs::metadata(&path).await {
Ok(metadata) => metadata,
Err(error) if error.kind() == std::io::ErrorKind::NotFound => continue,
Err(error) => return Err(error.into()),
};
let file_bytes = metadata.len();
let events = match read_outbox_events(&path).await {
Ok(events) => events,
@@ -203,7 +217,11 @@ impl TrackingOutbox {
match self.spacetime_client.record_tracking_events(events).await {
Ok(accepted_count) => {
fs::remove_file(&path).await?;
match fs::remove_file(&path).await {
Ok(()) => {}
Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
Err(error) => return Err(error.into()),
}
self.subtract_total_bytes(file_bytes).await;
crate::telemetry::record_tracking_outbox_flush(
started_at.elapsed(),
@@ -596,6 +614,34 @@ mod tests {
let _ = std::fs::remove_dir_all(dir);
}
#[tokio::test]
async fn shutdown_flush_seals_active_file_for_later_retry() {
let dir = test_dir("shutdown");
let outbox = test_outbox(dir.clone(), 500, 1024 * 1024);
outbox.enqueue(sample_event("event-1")).await.unwrap();
let result = outbox.flush_for_shutdown().await;
assert!(
matches!(result, Err(TrackingOutboxError::Spacetime(_))),
"missing test SpacetimeDB should keep sealed file for retry"
);
assert!(!dir.join(ACTIVE_FILE_NAME).exists());
let sealed_count = std::fs::read_dir(&dir)
.unwrap()
.filter_map(Result::ok)
.filter(|entry| {
entry
.file_name()
.to_str()
.is_some_and(|name| name.starts_with(SEALED_FILE_PREFIX))
})
.count();
assert_eq!(sealed_count, 1);
let _ = std::fs::remove_dir_all(dir);
}
#[test]
fn directory_size_excludes_quarantined_corrupt_files() {
let dir = test_dir("directory-size");

View File

@@ -6,9 +6,10 @@ license.workspace = true
[dependencies]
base64 = { workspace = true }
curl = { workspace = true }
image = { workspace = true, features = ["jpeg", "png", "webp"] }
reqwest = { workspace = true, features = ["json", "multipart", "rustls-tls"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["time"] }
tokio = { workspace = true, features = ["io-util", "macros", "net", "time"] }
tracing = { workspace = true }
platform-oss = { workspace = true }

View File

@@ -1,16 +1,22 @@
use reqwest::header;
use std::time::{SystemTime, UNIX_EPOCH};
const VECTOR_ENGINE_SEND_MAX_ATTEMPTS: u32 = 5;
const VECTOR_ENGINE_SEND_RETRY_BASE_DELAY_MS: u64 = 500;
const VECTOR_ENGINE_SEND_RETRY_MAX_JITTER_MS: u64 = 999;
use super::{
constants::{GPT_IMAGE_2_MODEL, VECTOR_ENGINE_PROVIDER},
curl_transport::{
map_curl_error, send_vector_engine_json_request_with_curl,
send_vector_engine_multipart_edit_request_with_curl,
},
error::PlatformImageError,
image_source::resolve_reference_images,
request::{
build_prompt_with_negative, build_vector_engine_image_edit_request_log_params,
build_vector_engine_image_request_body, normalize_image_size,
vector_engine_images_edit_url, vector_engine_images_generation_url,
build_vector_engine_image_edit_request_log_params, build_vector_engine_image_request_body,
normalize_image_size, vector_engine_images_edit_url, vector_engine_images_generation_url,
},
response::handle_vector_engine_response,
transport::map_reqwest_error,
types::{GeneratedImages, ReferenceImage, VectorEngineImageSettings},
};
@@ -50,63 +56,69 @@ pub async fn create_vector_engine_image_generation(
reference_images,
);
let started_at = std::time::Instant::now();
let response = match http_client
.post(request_url.as_str())
.header(
header::AUTHORIZATION,
format!("Bearer {}", settings.api_key),
let mut attempt = 1;
let response = loop {
match send_vector_engine_json_request_with_curl(
request_url.as_str(),
settings.api_key.as_str(),
&request_body,
settings.request_timeout_ms,
)
.header(header::ACCEPT, "application/json")
.header(header::CONTENT_TYPE, "application/json")
.json(&request_body)
.send()
.await
{
Ok(response) => response,
Err(error) => {
return Err(map_reqwest_error(
format!("{failure_context}:创建图片生成任务失败").as_str(),
request_url.as_str(),
"request_send",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_images.len()),
Some(&request_body),
));
{
Ok(response) => break response,
Err(error) => {
if should_retry_vector_engine_curl_send_error(&error, attempt) {
retry_vector_engine_send_after_delay(
"generation",
request_url.as_str(),
"request_send",
attempt,
error.is_timeout(),
error.is_connect(),
true,
false,
error.to_string().as_str(),
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_images.len()),
Some(&request_body),
)
.await;
attempt += 1;
continue;
}
return Err(map_curl_error(
format!("{failure_context}:创建图片生成任务失败").as_str(),
request_url.as_str(),
"request_send",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_images.len()),
Some(&request_body),
));
}
}
};
let response_status = response.status();
let response_status = response.status;
tracing::info!(
provider = VECTOR_ENGINE_PROVIDER,
endpoint = %request_url,
status = response_status.as_u16(),
status = response_status,
prompt_chars = prompt.chars().count(),
size = %normalized_size,
reference_image_count = reference_images.len(),
attempt,
elapsed_ms = started_at.elapsed().as_millis() as u64,
failure_context,
"VectorEngine 图片生成 HTTP 返回"
);
let response_text = match response.text().await {
Ok(response_text) => response_text,
Err(error) => {
return Err(map_reqwest_error(
format!("{failure_context}:读取图片生成响应失败").as_str(),
request_url.as_str(),
"response_body",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_images.len()),
Some(&request_body),
));
}
};
let response_text = response.body;
handle_vector_engine_response(
http_client,
request_url.as_str(),
response_status.as_u16(),
response_status,
response_text.as_str(),
failure_context,
started_at.elapsed().as_millis() as u64,
@@ -167,26 +179,6 @@ pub async fn create_vector_engine_image_edit_with_references(
reference_images,
);
let mut form = reqwest::multipart::Form::new()
.text("model", GPT_IMAGE_2_MODEL.to_string())
.text(
"prompt",
build_prompt_with_negative(prompt, negative_prompt),
)
.text("n", candidate_count.clamp(1, 4).to_string())
.text("size", normalized_size.clone());
for reference_image in reference_images.iter().take(5) {
let image_part = reqwest::multipart::Part::bytes(reference_image.bytes.clone())
.file_name(reference_image.file_name.clone())
.mime_str(reference_image.mime_type.as_str())
.map_err(|error| PlatformImageError::InvalidRequest {
provider: VECTOR_ENGINE_PROVIDER,
message: format!("{failure_context}:构造参考图失败:{error}"),
})?;
form = form.part("image", image_part);
}
let reference_image_count = reference_images.iter().take(5).count();
let reference_image_bytes_total: usize = reference_images
.iter()
@@ -214,64 +206,75 @@ pub async fn create_vector_engine_image_edit_with_references(
failure_context,
"VectorEngine 图片编辑请求参数"
);
let response = match http_client
.post(request_url.as_str())
.header(
header::AUTHORIZATION,
format!("Bearer {}", settings.api_key),
let mut attempt = 1;
let response = loop {
match send_vector_engine_multipart_edit_request_with_curl(
request_url.as_str(),
settings.api_key.as_str(),
prompt,
negative_prompt,
normalized_size.as_str(),
candidate_count,
reference_images,
settings.request_timeout_ms,
)
.header(header::ACCEPT, "application/json")
.multipart(form)
.send()
.await
{
Ok(response) => response,
Err(error) => {
return Err(map_reqwest_error(
format!("{failure_context}:创建图片编辑任务失败").as_str(),
request_url.as_str(),
"request_send",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_image_count),
Some(&request_params),
));
{
Ok(response) => break response,
Err(error) => {
if should_retry_vector_engine_curl_send_error(&error, attempt) {
retry_vector_engine_send_after_delay(
"edit",
request_url.as_str(),
"request_send",
attempt,
error.is_timeout(),
error.is_connect(),
true,
false,
error.to_string().as_str(),
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_image_count),
Some(&request_params),
)
.await;
attempt += 1;
continue;
}
return Err(map_curl_error(
format!("{failure_context}:创建图片编辑任务失败").as_str(),
request_url.as_str(),
"request_send",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_image_count),
Some(&request_params),
));
}
}
};
let response_status = response.status();
let response_status = response.status;
tracing::info!(
provider = VECTOR_ENGINE_PROVIDER,
endpoint = %request_url,
status = response_status.as_u16(),
status = response_status,
prompt_chars = prompt.chars().count(),
size = %normalized_size,
reference_image_count,
reference_image_bytes_total,
request_params = %request_params,
attempt,
elapsed_ms = started_at.elapsed().as_millis() as u64,
failure_context,
"VectorEngine 图片编辑 HTTP 返回"
);
let response_text = match response.text().await {
Ok(response_text) => response_text,
Err(error) => {
return Err(map_reqwest_error(
format!("{failure_context}:读取图片编辑响应失败").as_str(),
request_url.as_str(),
"response_body",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_image_count),
Some(&request_params),
));
}
};
let response_text = response.body;
handle_vector_engine_response(
http_client,
request_url.as_str(),
response_status.as_u16(),
response_status,
response_text.as_str(),
failure_context,
started_at.elapsed().as_millis() as u64,
@@ -282,3 +285,84 @@ pub async fn create_vector_engine_image_edit_with_references(
)
.await
}
fn should_retry_vector_engine_curl_send_error(
error: &super::curl_transport::VectorEngineCurlError,
attempt: u32,
) -> bool {
attempt < VECTOR_ENGINE_SEND_MAX_ATTEMPTS && (error.is_timeout() || error.is_connect())
}
async fn retry_vector_engine_send_after_delay(
request_kind: &'static str,
request_url: &str,
failure_stage: &'static str,
attempt: u32,
timeout: bool,
connect: bool,
request: bool,
body: bool,
error: &str,
elapsed_ms: u64,
prompt_chars: Option<usize>,
reference_image_count: Option<usize>,
request_params: Option<&serde_json::Value>,
) {
let delay_ms = vector_engine_send_retry_delay_ms(attempt, vector_engine_send_retry_jitter_ms());
tracing::warn!(
provider = VECTOR_ENGINE_PROVIDER,
endpoint = %request_url,
request_kind,
failure_stage,
attempt,
max_attempts = VECTOR_ENGINE_SEND_MAX_ATTEMPTS,
retry_delay_ms = delay_ms,
timeout,
connect,
request,
body,
status = 0,
error,
elapsed_ms,
prompt_chars,
reference_image_count,
request_params = %request_params
.map(|value| value.to_string())
.unwrap_or_default(),
"VectorEngine 图片请求发送失败,准备重试"
);
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
fn vector_engine_send_retry_delay_ms(attempt: u32, jitter_ms: u64) -> u64 {
let exponential_factor = 1_u64 << attempt.saturating_sub(1).min(10);
let bounded_jitter_ms = jitter_ms.min(VECTOR_ENGINE_SEND_RETRY_MAX_JITTER_MS);
VECTOR_ENGINE_SEND_RETRY_BASE_DELAY_MS * exponential_factor + bounded_jitter_ms
}
fn vector_engine_send_retry_jitter_ms() -> u64 {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.subsec_nanos())
.unwrap_or_default();
u64::from(nanos) % (VECTOR_ENGINE_SEND_RETRY_MAX_JITTER_MS + 1)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn vector_engine_send_retry_policy_allows_four_retries_before_final_attempt() {
assert_eq!(VECTOR_ENGINE_SEND_MAX_ATTEMPTS, 5);
}
#[test]
fn vector_engine_send_retry_delay_uses_exponential_backoff_with_bounded_jitter() {
assert_eq!(vector_engine_send_retry_delay_ms(1, 0), 500);
assert_eq!(vector_engine_send_retry_delay_ms(2, 0), 1_000);
assert_eq!(vector_engine_send_retry_delay_ms(3, 0), 2_000);
assert_eq!(vector_engine_send_retry_delay_ms(4, 0), 4_000);
assert_eq!(vector_engine_send_retry_delay_ms(4, 999), 4_999);
}
}

View File

@@ -0,0 +1,406 @@
use std::{error::Error, fmt, time::Duration};
use curl::{
FormError,
easy::{Easy, Form, List},
};
use serde_json::Value;
use super::{
audit::build_failure_audit,
constants::{GPT_IMAGE_2_MODEL, VECTOR_ENGINE_PROVIDER},
error::PlatformImageError,
request::build_prompt_with_negative,
types::ReferenceImage,
};
#[derive(Debug)]
pub(crate) struct VectorEngineCurlResponse {
pub(crate) status: u16,
pub(crate) body: String,
}
#[derive(Debug)]
pub(crate) enum VectorEngineCurlError {
Curl(curl::Error),
Form(FormError),
WorkerJoin(tokio::task::JoinError),
}
impl VectorEngineCurlError {
pub(crate) fn is_timeout(&self) -> bool {
match self {
Self::Curl(error) => error.is_operation_timedout(),
Self::Form(_) | Self::WorkerJoin(_) => false,
}
}
pub(crate) fn is_connect(&self) -> bool {
match self {
Self::Curl(error) => {
error.is_couldnt_connect()
|| error.is_couldnt_resolve_host()
|| error.is_couldnt_resolve_proxy()
}
Self::Form(_) | Self::WorkerJoin(_) => false,
}
}
}
impl fmt::Display for VectorEngineCurlError {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Curl(error) => write!(formatter, "{error}"),
Self::Form(error) => write!(formatter, "multipart form error: {error}"),
Self::WorkerJoin(error) => write!(formatter, "curl worker join failed: {error}"),
}
}
}
impl Error for VectorEngineCurlError {}
impl From<curl::Error> for VectorEngineCurlError {
fn from(error: curl::Error) -> Self {
Self::Curl(error)
}
}
impl From<FormError> for VectorEngineCurlError {
fn from(error: FormError) -> Self {
Self::Form(error)
}
}
pub(crate) async fn send_vector_engine_json_request_with_curl(
request_url: &str,
api_key: &str,
request_body: &Value,
timeout_ms: u64,
) -> Result<VectorEngineCurlResponse, VectorEngineCurlError> {
let request_url = request_url.to_string();
let api_key = api_key.to_string();
let request_body = request_body.to_string();
tokio::task::spawn_blocking(move || {
send_json_request_with_curl_blocking(
request_url.as_str(),
api_key.as_str(),
request_body.as_str(),
timeout_ms,
)
})
.await
.map_err(VectorEngineCurlError::WorkerJoin)?
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn send_vector_engine_multipart_edit_request_with_curl(
request_url: &str,
api_key: &str,
prompt: &str,
negative_prompt: Option<&str>,
normalized_size: &str,
candidate_count: u32,
reference_images: &[ReferenceImage],
timeout_ms: u64,
) -> Result<VectorEngineCurlResponse, VectorEngineCurlError> {
let request_url = request_url.to_string();
let api_key = api_key.to_string();
let prompt = prompt.to_string();
let negative_prompt = negative_prompt.map(str::to_string);
let normalized_size = normalized_size.to_string();
let reference_images = reference_images.iter().take(5).cloned().collect::<Vec<_>>();
tokio::task::spawn_blocking(move || {
send_multipart_edit_request_with_curl_blocking(
request_url.as_str(),
api_key.as_str(),
prompt.as_str(),
negative_prompt.as_deref(),
normalized_size.as_str(),
candidate_count,
reference_images.as_slice(),
timeout_ms,
)
})
.await
.map_err(VectorEngineCurlError::WorkerJoin)?
}
pub(crate) fn map_curl_error(
context: &str,
request_url: &str,
failure_stage: &'static str,
error: VectorEngineCurlError,
latency_ms: u64,
prompt_chars: Option<usize>,
reference_image_count: Option<usize>,
request_params: Option<&Value>,
) -> PlatformImageError {
let is_timeout = error.is_timeout();
let is_connect = error.is_connect();
let source = error.to_string();
let message = format!("{context}{source}");
let audit = build_failure_audit(
request_url,
context,
failure_stage,
None,
None,
is_timeout,
is_connect,
message.as_str(),
Some(source.clone()),
None,
Some(latency_ms),
prompt_chars,
reference_image_count,
);
tracing::warn!(
provider = VECTOR_ENGINE_PROVIDER,
endpoint = %request_url,
failure_stage,
timeout = is_timeout,
connect = is_connect,
request = true,
body = false,
status = 0,
source = %source,
source_chain = %source,
source_chain_depth = 1,
message = %message,
elapsed_ms = latency_ms,
prompt_chars,
reference_image_count,
request_params = %request_params
.map(|value| value.to_string())
.unwrap_or_default(),
"VectorEngine 图片 libcurl 请求失败"
);
PlatformImageError::Request {
provider: VECTOR_ENGINE_PROVIDER,
message,
endpoint: Some(request_url.to_string()),
timeout: is_timeout,
connect: is_connect,
request: true,
body: false,
status_code: None,
source: Some(source),
audit: Some(audit),
}
}
fn send_json_request_with_curl_blocking(
request_url: &str,
api_key: &str,
request_body: &str,
timeout_ms: u64,
) -> Result<VectorEngineCurlResponse, VectorEngineCurlError> {
let mut headers = vector_engine_curl_headers(api_key)?;
headers.append("Content-Type: application/json")?;
let mut easy = Easy::new();
easy.url(request_url)?;
easy.post(true)?;
easy.http_headers(headers)?;
easy.timeout(Duration::from_millis(timeout_ms.max(1)))?;
easy.post_fields_copy(request_body.as_bytes())?;
Ok(perform_curl_request(easy)?)
}
#[allow(clippy::too_many_arguments)]
fn send_multipart_edit_request_with_curl_blocking(
request_url: &str,
api_key: &str,
prompt: &str,
negative_prompt: Option<&str>,
normalized_size: &str,
candidate_count: u32,
reference_images: &[ReferenceImage],
timeout_ms: u64,
) -> Result<VectorEngineCurlResponse, VectorEngineCurlError> {
let mut form = Form::new();
form.part("model")
.contents(GPT_IMAGE_2_MODEL.as_bytes())
.add()?;
form.part("prompt")
.contents(build_prompt_with_negative(prompt, negative_prompt).as_bytes())
.add()?;
form.part("n")
.contents(candidate_count.clamp(1, 4).to_string().as_bytes())
.add()?;
form.part("size")
.contents(normalized_size.as_bytes())
.add()?;
for reference_image in reference_images {
form.part("image")
.buffer(
reference_image.file_name.as_str(),
reference_image.bytes.clone(),
)
.content_type(reference_image.mime_type.as_str())
.add()?;
}
let headers = vector_engine_curl_headers(api_key)?;
let mut easy = Easy::new();
easy.url(request_url)?;
easy.httppost(form)?;
easy.http_headers(headers)?;
easy.timeout(Duration::from_millis(timeout_ms.max(1)))?;
Ok(perform_curl_request(easy)?)
}
fn vector_engine_curl_headers(api_key: &str) -> Result<List, curl::Error> {
let mut headers = List::new();
headers.append(format!("Authorization: Bearer {api_key}").as_str())?;
headers.append("Accept: application/json")?;
Ok(headers)
}
fn perform_curl_request(mut easy: Easy) -> Result<VectorEngineCurlResponse, curl::Error> {
let mut body = Vec::new();
{
let mut transfer = easy.transfer();
transfer.write_function(|data| {
body.extend_from_slice(data);
Ok(data.len())
})?;
transfer.perform()?;
}
let status = easy.response_code()? as u16;
let body = String::from_utf8_lossy(body.as_slice()).into_owned();
Ok(VectorEngineCurlResponse { status, body })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::vector_engine::types::ReferenceImage;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpListener,
sync::oneshot,
};
#[tokio::test]
async fn vector_engine_curl_transport_posts_json_request() {
let (base_url, server, request_rx) = start_single_response_server().await;
let response = send_vector_engine_json_request_with_curl(
format!("{base_url}/v1/images/generations").as_str(),
"test-key",
&serde_json::json!({"model":"gpt-image-2","prompt":"测试"}),
1_000,
)
.await
.expect("curl json request should succeed");
assert_eq!(response.status, 200);
assert_eq!(response.body, "{\"data\":[]}");
let request = request_rx
.await
.expect("mock server should capture request");
let request_text = String::from_utf8_lossy(request.as_slice());
assert!(request_text.contains("Content-Type: application/json"));
server.abort();
}
#[tokio::test]
async fn vector_engine_curl_transport_posts_multipart_request() {
let (base_url, server, request_rx) = start_single_response_server().await;
let response = send_vector_engine_multipart_edit_request_with_curl(
format!("{base_url}/v1/images/edits").as_str(),
"test-key",
"测试提示词",
None,
"1024x1024",
1,
&[ReferenceImage {
bytes: b"reference".to_vec(),
mime_type: "image/png".to_string(),
file_name: "reference.png".to_string(),
}],
1_000,
)
.await
.expect("curl multipart request should succeed");
assert_eq!(response.status, 200);
assert_eq!(response.body, "{\"data\":[]}");
let request = request_rx
.await
.expect("mock server should capture request");
let request_text = String::from_utf8_lossy(request.as_slice());
assert!(request_text.contains("name=\"image\"; filename=\"reference.png\""));
assert!(request_text.contains("Content-Type: image/png"));
assert!(request_text.contains("reference"));
server.abort();
}
async fn start_single_response_server() -> (
String,
tokio::task::JoinHandle<()>,
oneshot::Receiver<Vec<u8>>,
) {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("mock server should bind");
let addr = listener
.local_addr()
.expect("mock server addr should be readable");
let (request_tx, request_rx) = oneshot::channel();
let server = tokio::spawn(async move {
let Ok((mut stream, _)) = listener.accept().await else {
return;
};
let mut request = Vec::new();
let mut buffer = [0_u8; 4096];
loop {
let Ok(read) = stream.read(&mut buffer).await else {
return;
};
if read == 0 {
return;
}
request.extend_from_slice(&buffer[..read]);
if request.windows(4).any(|window| window == b"\r\n\r\n") {
break;
}
}
let header_end = request
.windows(4)
.position(|window| window == b"\r\n\r\n")
.map(|index| index + 4)
.unwrap_or(request.len());
let headers = String::from_utf8_lossy(&request[..header_end]);
let content_length = headers
.lines()
.find_map(|line| {
line.strip_prefix("Content-Length:")
.or_else(|| line.strip_prefix("content-length:"))
})
.and_then(|value| value.trim().parse::<usize>().ok())
.unwrap_or_default();
let expected_len = header_end + content_length;
while request.len() < expected_len {
let Ok(read) = stream.read(&mut buffer).await else {
return;
};
if read == 0 {
break;
}
request.extend_from_slice(&buffer[..read]);
}
let _ = request_tx.send(request);
let body = "{\"data\":[]}";
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
);
let _ = stream.write_all(response.as_bytes()).await;
});
(format!("http://{addr}"), server, request_rx)
}
}

View File

@@ -1,6 +1,7 @@
mod audit;
mod client;
mod constants;
mod curl_transport;
mod error;
mod image_source;
mod payload;

View File

@@ -1,10 +1,7 @@
use std::{error::Error, time::Duration};
use serde_json::Value;
use std::time::Duration;
use super::{
audit::build_failure_audit, constants::VECTOR_ENGINE_PROVIDER, error::PlatformImageError,
types::VectorEngineImageSettings,
constants::VECTOR_ENGINE_PROVIDER, error::PlatformImageError, types::VectorEngineImageSettings,
};
pub fn build_vector_engine_image_http_client(
@@ -20,130 +17,3 @@ pub fn build_vector_engine_image_http_client(
message: format!("构造 VectorEngine 图片生成 HTTP 客户端失败:{error}"),
})
}
pub(super) fn map_reqwest_error(
context: &str,
request_url: &str,
failure_stage: &'static str,
error: reqwest::Error,
latency_ms: u64,
prompt_chars: Option<usize>,
reference_image_count: Option<usize>,
request_params: Option<&Value>,
) -> PlatformImageError {
let is_timeout = error.is_timeout();
let is_connect = error.is_connect();
let source_chain_parts = collect_error_source_chain(&error);
let source = source_chain_parts.first().cloned();
let source_chain_depth = source_chain_parts.len();
let source_chain = if source_chain_parts.is_empty() {
None
} else {
Some(source_chain_parts.join(" -> "))
};
let message = format!("{context}{error}");
let audit = build_failure_audit(
request_url,
context,
failure_stage,
error.status().map(|status| status.as_u16()),
None,
is_timeout,
is_connect,
message.as_str(),
source_chain.clone().or_else(|| source.clone()),
None,
Some(latency_ms),
prompt_chars,
reference_image_count,
);
tracing::warn!(
provider = VECTOR_ENGINE_PROVIDER,
endpoint = %request_url,
failure_stage,
timeout = is_timeout,
connect = is_connect,
request = error.is_request(),
body = error.is_body(),
status = error.status().map(|status| status.as_u16()).unwrap_or_default(),
source = %source.clone().unwrap_or_default(),
source_chain = %source_chain.clone().unwrap_or_default(),
source_chain_depth,
message = %message,
elapsed_ms = latency_ms,
prompt_chars,
reference_image_count,
request_params = %request_params
.map(|value| value.to_string())
.unwrap_or_default(),
"VectorEngine 图片请求发送失败"
);
PlatformImageError::Request {
provider: VECTOR_ENGINE_PROVIDER,
message,
endpoint: Some(request_url.to_string()),
timeout: is_timeout,
connect: is_connect,
request: error.is_request(),
body: error.is_body(),
status_code: error.status().map(|status| status.as_u16()),
source: source_chain.or(source),
audit: Some(audit),
}
}
fn collect_error_source_chain(error: &(dyn Error + 'static)) -> Vec<String> {
let mut chain = Vec::new();
let mut next = error.source();
while let Some(source) = next {
chain.push(source.to_string());
next = source.source();
}
chain
}
#[cfg(test)]
mod tests {
use super::*;
use std::fmt;
#[derive(Debug)]
struct TestError {
message: &'static str,
source: Option<Box<TestError>>,
}
impl fmt::Display for TestError {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(self.message)
}
}
impl Error for TestError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source
.as_deref()
.map(|source| source as &(dyn Error + 'static))
}
}
#[test]
fn collect_error_source_chain_keeps_nested_causes() {
let error = TestError {
message: "top",
source: Some(Box::new(TestError {
message: "middle",
source: Some(Box::new(TestError {
message: "bottom",
source: None,
})),
})),
};
assert_eq!(
collect_error_source_chain(&error),
vec!["middle".to_string(), "bottom".to_string()]
);
}
}

View File

@@ -1,8 +1,20 @@
use platform_image::vector_engine::{
GPT_IMAGE_2_MODEL, VECTOR_ENGINE_PROVIDER, VectorEngineImageSettings,
build_vector_engine_image_request_body, vector_engine_images_edit_url,
GPT_IMAGE_2_MODEL, ReferenceImage, VECTOR_ENGINE_PROVIDER, VectorEngineImageSettings,
build_vector_engine_image_http_client, build_vector_engine_image_request_body,
create_vector_engine_image_edit, vector_engine_images_edit_url,
vector_engine_images_generation_url,
};
use std::{
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::Duration,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpListener,
};
#[test]
fn vector_engine_module_exposes_provider_protocol_helpers() {
@@ -30,3 +42,70 @@ fn vector_engine_module_exposes_provider_protocol_helpers() {
"https://vector.example/v1/images/edits"
);
}
#[tokio::test]
async fn vector_engine_image_edit_retries_send_timeout_once_and_succeeds() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("mock server should bind");
let server_addr = listener
.local_addr()
.expect("mock server address should be readable");
let request_count = Arc::new(AtomicUsize::new(0));
let request_count_for_server = Arc::clone(&request_count);
let server = tokio::spawn(async move {
loop {
let Ok((mut stream, _)) = listener.accept().await else {
break;
};
let request_index = request_count_for_server.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
let mut buffer = [0_u8; 4096];
let _ = stream.read(&mut buffer).await;
if request_index == 0 {
tokio::time::sleep(Duration::from_millis(120)).await;
return;
}
let body = r#"{"data":[{"b64_json":"iVBORw0KGgpyZXN0"}]}"#;
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
);
let _ = stream.write_all(response.as_bytes()).await;
});
}
});
let settings = VectorEngineImageSettings {
base_url: format!("http://{server_addr}/v1"),
api_key: "test-key".to_string(),
request_timeout_ms: 40,
};
let http_client =
build_vector_engine_image_http_client(&settings).expect("client should build");
let reference_image = ReferenceImage {
bytes: b"reference".to_vec(),
mime_type: "image/png".to_string(),
file_name: "reference.png".to_string(),
};
let generated = create_vector_engine_image_edit(
&http_client,
&settings,
"测试提示词",
None,
"1024x1024",
&reference_image,
"测试 VectorEngine 图片编辑失败",
)
.await
.expect("second attempt should return generated image");
assert_eq!(generated.images.len(), 1);
assert_eq!(generated.images[0].mime_type, "image/png");
assert_eq!(request_count.load(Ordering::SeqCst), 2);
server.abort();
}

View File

@@ -12,6 +12,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
sha2 = { workspace = true }
time = { workspace = true, features = ["formatting"] }
tracing = { workspace = true }
[dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt"] }

View File

@@ -22,6 +22,7 @@
5. 服务端 `PutObject` 上传 helper
6. `x-oss-meta-*` 元数据归一化与大小限制校验
7. `content-type``content-length-range``success_action_status` policy 条件生成
8. `PostObject` 签名、`GetObject` 读签名、`HEAD Object``PutObject` 的结构化日志
当前仍未落地的内容:
@@ -34,8 +35,9 @@
1. 当前产品口径为服务器上传 AI 生成资源、Web 端只负责读取。
2. 因此 `STS` 不作为默认上传主链,`api-server` 只暴露禁用式 contract避免浏览器拿到 OSS 写权限。
3. 服务端生成资源应优先复用 `OssClient::put_object`,上传成功后再走对象确认链路写入 `asset_object`
4. 读签名和 `HEAD Object` 的入参必须直接传 object_key不要把 bucket 名拼进路径;例如 `generated-square-hole-assets/.../image.png` 才是正确入参,`xushi-dev/...` 这类前缀不属于 object_key。
5. OSS V4 `x-oss-date` 必须固定为 `yyyyMMdd'T'HHmmss'Z'`,不能依赖 `time::Time::to_string()`;后者在小时小于 10 时可能输出非补零时间,导致签名格式错误。
4. 读签名和 `HEAD Object` 的入参必须直接传 object_key不要把 bucket 名拼进路径;例如 `generated-square-hole-assets/.../image.png` 才是正确入参,`xushi-dev/...` 这类前缀不属于 object_key。
5. OSS V4 `x-oss-date` 必须固定为 `yyyyMMdd'T'HHmmss'Z'`,不能依赖 `time::Time::to_string()`;后者在小时小于 10 时可能输出非补零时间,导致签名格式错误。
6. 结构化日志只记录 `provider``operation``bucket``endpoint``object_key` / `key_prefix``access``content_type``content_length``status``status_class``error_kind``elapsed_ms` 等排障字段;禁止输出 AccessKey、policy、signature、Authorization header 或完整 signed URL。
## 3. 边界约束

View File

@@ -1,4 +1,4 @@
use std::{collections::BTreeMap, error::Error, fmt};
use std::{collections::BTreeMap, error::Error, fmt, time::Instant};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use hmac::{Hmac, Mac};
@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use sha2::{Digest, Sha256};
use time::{Duration, OffsetDateTime, format_description::well_known::Rfc3339};
use tracing::{info, warn};
type HmacSha256 = Hmac<Sha256>;
@@ -19,6 +20,7 @@ const OSS_V4_ALGORITHM: &str = "OSS4-HMAC-SHA256";
const OSS_V4_REQUEST: &str = "aliyun_v4_request";
const OSS_V4_SERVICE: &str = "oss";
const OSS_UNSIGNED_PAYLOAD: &str = "UNSIGNED-PAYLOAD";
const OSS_PROVIDER: &str = "aliyun-oss";
pub const LEGACY_PUBLIC_PREFIXES: [&str; 13] = [
"generated-character-drafts",
@@ -369,105 +371,154 @@ impl OssClient {
&self,
request: OssPostObjectRequest,
) -> Result<OssPostObjectResponse, OssError> {
let max_size_bytes = request
.max_size_bytes
.unwrap_or(self.config.default_post_max_size_bytes);
let expire_seconds = request
.expire_seconds
.unwrap_or(self.config.default_post_expire_seconds);
let success_action_status = request
.success_action_status
.unwrap_or(self.config.default_success_action_status);
let started_at = Instant::now();
let requested_prefix = request.prefix.as_str();
let requested_content_type = request
.content_type
.as_deref()
.map(str::trim)
.unwrap_or("")
.to_string();
let requested_metadata_count = request.metadata.len();
if max_size_bytes == 0 {
return Err(OssError::InvalidRequest(
"maxSizeBytes 必须大于 0".to_string(),
));
let result = (|| {
let max_size_bytes = request
.max_size_bytes
.unwrap_or(self.config.default_post_max_size_bytes);
let expire_seconds = request
.expire_seconds
.unwrap_or(self.config.default_post_expire_seconds);
let success_action_status = request
.success_action_status
.unwrap_or(self.config.default_success_action_status);
if max_size_bytes == 0 {
return Err(OssError::InvalidRequest(
"maxSizeBytes 必须大于 0".to_string(),
));
}
if expire_seconds == 0 {
return Err(OssError::InvalidRequest(
"expireSeconds 必须大于 0".to_string(),
));
}
if !(100..=999).contains(&success_action_status) {
return Err(OssError::InvalidRequest(
"successActionStatus 必须是三位 HTTP 状态码".to_string(),
));
}
let sanitized_segments = request
.path_segments
.iter()
.map(|segment| sanitize_path_segment(segment))
.filter(|segment| !segment.is_empty())
.collect::<Vec<_>>();
let file_name = sanitize_file_name(&request.file_name)?;
let object_key = build_object_key(request.prefix, &sanitized_segments, &file_name);
let legacy_public_path = format!("/{}", object_key);
let content_type = normalize_optional_value(request.content_type);
let metadata = normalize_metadata(request.metadata)?;
let expires_at = OffsetDateTime::now_utc()
.checked_add(Duration::seconds(i64::try_from(expire_seconds).map_err(
|_| OssError::InvalidRequest("expireSeconds 超出可支持范围".to_string()),
)?))
.ok_or_else(|| {
OssError::InvalidRequest("expireSeconds 计算结果溢出".to_string())
})?;
let expires_at = expires_at.format(&Rfc3339).map_err(|error| {
OssError::SerializePolicy(format!("格式化过期时间失败:{error}"))
})?;
let signed_at = OffsetDateTime::now_utc();
let signature_scope = build_v4_signature_scope(&self.config.endpoint, signed_at)?;
let signature_date = build_v4_signature_date(signed_at)?;
let credential = format!("{}/{}", self.config.access_key_id, signature_scope);
let policy_json = build_policy_json(
&self.config.bucket,
&object_key,
&expires_at,
max_size_bytes,
success_action_status,
content_type.as_deref(),
&metadata,
&credential,
&signature_date,
);
let policy = serde_json::to_string(&policy_json).map_err(|error| {
OssError::SerializePolicy(format!("序列化 policy 失败:{error}"))
})?;
let encoded_policy = BASE64_STANDARD.encode(policy.as_bytes());
let signature = sign_v4_content(
&self.config.access_key_secret,
&signature_scope,
&encoded_policy,
)?;
Ok(OssPostObjectResponse {
signature_version: "v4",
provider: OSS_PROVIDER,
bucket: self.config.bucket.clone(),
endpoint: self.config.endpoint.clone(),
host: self.config.upload_host(),
object_key: object_key.clone(),
legacy_public_path,
content_type: content_type.clone(),
access: request.access,
key_prefix: build_key_prefix(request.prefix, &sanitized_segments),
expires_at,
max_size_bytes,
success_action_status,
form_fields: OssPostObjectFormFields {
key: object_key,
policy: encoded_policy,
signature_version: OSS_V4_ALGORITHM.to_string(),
credential,
date: signature_date,
signature,
success_action_status: success_action_status.to_string(),
content_type,
metadata,
},
})
})();
match &result {
Ok(response) => info!(
provider = OSS_PROVIDER,
operation = "sign_post_object",
bucket = %response.bucket,
endpoint = %response.endpoint,
object_key = %response.object_key,
key_prefix = %response.key_prefix,
access = oss_access_label(response.access),
content_type = %response.content_type.as_deref().unwrap_or(""),
max_size_bytes = response.max_size_bytes,
success_action_status = response.success_action_status,
metadata_count = response.form_fields.metadata.len(),
expires_at = %response.expires_at,
elapsed_ms = elapsed_ms(started_at),
"OSS PostObject 签名完成"
),
Err(error) => warn!(
provider = OSS_PROVIDER,
operation = "sign_post_object",
bucket = %self.config.bucket(),
endpoint = %self.config.endpoint(),
key_prefix = requested_prefix,
content_type = %requested_content_type,
metadata_count = requested_metadata_count,
error_kind = oss_error_kind_label(error),
message = %error,
elapsed_ms = elapsed_ms(started_at),
"OSS PostObject 签名失败"
),
}
if expire_seconds == 0 {
return Err(OssError::InvalidRequest(
"expireSeconds 必须大于 0".to_string(),
));
}
if !(100..=999).contains(&success_action_status) {
return Err(OssError::InvalidRequest(
"successActionStatus 必须是三位 HTTP 状态码".to_string(),
));
}
let sanitized_segments = request
.path_segments
.iter()
.map(|segment| sanitize_path_segment(segment))
.filter(|segment| !segment.is_empty())
.collect::<Vec<_>>();
let file_name = sanitize_file_name(&request.file_name)?;
let object_key = build_object_key(request.prefix, &sanitized_segments, &file_name);
let legacy_public_path = format!("/{}", object_key);
let content_type = normalize_optional_value(request.content_type);
let metadata = normalize_metadata(request.metadata)?;
let expires_at = OffsetDateTime::now_utc()
.checked_add(Duration::seconds(i64::try_from(expire_seconds).map_err(
|_| OssError::InvalidRequest("expireSeconds 超出可支持范围".to_string()),
)?))
.ok_or_else(|| OssError::InvalidRequest("expireSeconds 计算结果溢出".to_string()))?;
let expires_at = expires_at
.format(&Rfc3339)
.map_err(|error| OssError::SerializePolicy(format!("格式化过期时间失败:{error}")))?;
let signed_at = OffsetDateTime::now_utc();
let signature_scope = build_v4_signature_scope(&self.config.endpoint, signed_at)?;
let signature_date = build_v4_signature_date(signed_at)?;
let credential = format!("{}/{}", self.config.access_key_id, signature_scope);
let policy_json = build_policy_json(
&self.config.bucket,
&object_key,
&expires_at,
max_size_bytes,
success_action_status,
content_type.as_deref(),
&metadata,
&credential,
&signature_date,
);
let policy = serde_json::to_string(&policy_json)
.map_err(|error| OssError::SerializePolicy(format!("序列化 policy 失败:{error}")))?;
let encoded_policy = BASE64_STANDARD.encode(policy.as_bytes());
let signature = sign_v4_content(
&self.config.access_key_secret,
&signature_scope,
&encoded_policy,
)?;
Ok(OssPostObjectResponse {
signature_version: "v4",
provider: "aliyun-oss",
bucket: self.config.bucket.clone(),
endpoint: self.config.endpoint.clone(),
host: self.config.upload_host(),
object_key: object_key.clone(),
legacy_public_path,
content_type: content_type.clone(),
access: request.access,
key_prefix: build_key_prefix(request.prefix, &sanitized_segments),
expires_at,
max_size_bytes,
success_action_status,
form_fields: OssPostObjectFormFields {
key: object_key,
policy: encoded_policy,
signature_version: OSS_V4_ALGORITHM.to_string(),
credential,
date: signature_date,
signature,
success_action_status: success_action_status.to_string(),
content_type,
metadata,
},
})
result
}
// 私有 bucket 的对象读取统一走短期签名 URL避免把长期主凭证下发给浏览器。
@@ -475,81 +526,119 @@ impl OssClient {
&self,
request: OssSignedGetObjectUrlRequest,
) -> Result<OssSignedGetObjectUrlResponse, OssError> {
let expire_seconds = request
.expire_seconds
.unwrap_or(self.config.default_read_expire_seconds);
let started_at = Instant::now();
let requested_object_key = request
.object_key
.trim()
.trim_start_matches('/')
.trim()
.to_string();
if expire_seconds == 0 {
return Err(OssError::InvalidRequest(
"expireSeconds 必须大于 0".to_string(),
));
let result = (|| {
let expire_seconds = request
.expire_seconds
.unwrap_or(self.config.default_read_expire_seconds);
if expire_seconds == 0 {
return Err(OssError::InvalidRequest(
"expireSeconds 必须大于 0".to_string(),
));
}
let object_key = normalize_object_key(&request.object_key)?;
let expires_at = OffsetDateTime::now_utc()
.checked_add(Duration::seconds(i64::try_from(expire_seconds).map_err(
|_| OssError::InvalidRequest("expireSeconds 超出可支持范围".to_string()),
)?))
.ok_or_else(|| {
OssError::InvalidRequest("expireSeconds 计算结果溢出".to_string())
})?;
let expires_at_text = expires_at
.format(&Rfc3339)
.map_err(|error| OssError::Sign(format!("格式化过期时间失败:{error}")))?;
let signed_at = OffsetDateTime::now_utc();
let signed_at_text = build_v4_signature_date(signed_at)?;
let signature_scope = build_v4_signature_scope(&self.config.endpoint, signed_at)?;
let credential = format!("{}/{}", self.config.access_key_id, signature_scope);
let mut query = BTreeMap::from([
("x-oss-additional-headers".to_string(), "host".to_string()),
(
"x-oss-signature-version".to_string(),
OSS_V4_ALGORITHM.to_string(),
),
("x-oss-credential".to_string(), credential),
("x-oss-date".to_string(), signed_at_text),
("x-oss-expires".to_string(), expire_seconds.to_string()),
]);
let canonical_uri = build_v4_canonical_uri(&self.config.bucket, Some(&object_key));
let object_url_path = format!("/{}", encode_url_path(&object_key));
let additional_headers = "host";
let canonical_headers =
format!("host:{}.{}\n", self.config.bucket(), self.config.endpoint());
let canonical_query = build_canonical_query_string(&query);
let canonical_request = build_v4_canonical_request(
Method::GET.as_str(),
&canonical_uri,
&canonical_query,
&canonical_headers,
additional_headers,
OSS_UNSIGNED_PAYLOAD,
);
let string_to_sign = build_v4_string_to_sign(
query["x-oss-date"].as_str(),
&signature_scope,
&canonical_request,
);
let signature = sign_v4_content(
&self.config.access_key_secret,
&signature_scope,
&string_to_sign,
)?;
query.insert("x-oss-signature".to_string(), signature);
let signed_url = format!(
"{}{}?{}",
self.config.upload_host(),
object_url_path,
build_canonical_query_string(&query)
);
Ok(OssSignedGetObjectUrlResponse {
provider: OSS_PROVIDER,
bucket: self.config.bucket.clone(),
endpoint: self.config.endpoint.clone(),
host: self.config.upload_host(),
object_key,
expires_at: expires_at_text,
signed_url,
})
})();
match &result {
Ok(response) => info!(
provider = OSS_PROVIDER,
operation = "sign_get_object_url",
bucket = %response.bucket,
endpoint = %response.endpoint,
object_key = %response.object_key,
expires_at = %response.expires_at,
elapsed_ms = elapsed_ms(started_at),
"OSS GetObject 读签名完成"
),
Err(error) => warn!(
provider = OSS_PROVIDER,
operation = "sign_get_object_url",
bucket = %self.config.bucket(),
endpoint = %self.config.endpoint(),
object_key = %requested_object_key,
error_kind = oss_error_kind_label(error),
message = %error,
elapsed_ms = elapsed_ms(started_at),
"OSS GetObject 读签名失败"
),
}
let object_key = normalize_object_key(&request.object_key)?;
let expires_at = OffsetDateTime::now_utc()
.checked_add(Duration::seconds(i64::try_from(expire_seconds).map_err(
|_| OssError::InvalidRequest("expireSeconds 超出可支持范围".to_string()),
)?))
.ok_or_else(|| OssError::InvalidRequest("expireSeconds 计算结果溢出".to_string()))?;
let expires_at_text = expires_at
.format(&Rfc3339)
.map_err(|error| OssError::Sign(format!("格式化过期时间失败:{error}")))?;
let signed_at = OffsetDateTime::now_utc();
let signed_at_text = build_v4_signature_date(signed_at)?;
let signature_scope = build_v4_signature_scope(&self.config.endpoint, signed_at)?;
let credential = format!("{}/{}", self.config.access_key_id, signature_scope);
let mut query = BTreeMap::from([
("x-oss-additional-headers".to_string(), "host".to_string()),
(
"x-oss-signature-version".to_string(),
OSS_V4_ALGORITHM.to_string(),
),
("x-oss-credential".to_string(), credential),
("x-oss-date".to_string(), signed_at_text),
("x-oss-expires".to_string(), expire_seconds.to_string()),
]);
let canonical_uri = build_v4_canonical_uri(&self.config.bucket, Some(&object_key));
let object_url_path = format!("/{}", encode_url_path(&object_key));
let additional_headers = "host";
let canonical_headers =
format!("host:{}.{}\n", self.config.bucket(), self.config.endpoint());
let canonical_query = build_canonical_query_string(&query);
let canonical_request = build_v4_canonical_request(
Method::GET.as_str(),
&canonical_uri,
&canonical_query,
&canonical_headers,
additional_headers,
OSS_UNSIGNED_PAYLOAD,
);
let string_to_sign = build_v4_string_to_sign(
query["x-oss-date"].as_str(),
&signature_scope,
&canonical_request,
);
let signature = sign_v4_content(
&self.config.access_key_secret,
&signature_scope,
&string_to_sign,
)?;
query.insert("x-oss-signature".to_string(), signature);
let signed_url = format!(
"{}{}?{}",
self.config.upload_host(),
object_url_path,
build_canonical_query_string(&query)
);
Ok(OssSignedGetObjectUrlResponse {
provider: "aliyun-oss",
bucket: self.config.bucket.clone(),
endpoint: self.config.endpoint.clone(),
host: self.config.upload_host(),
object_key,
expires_at: expires_at_text,
signed_url,
})
result
}
// 上传完成确认前,服务端必须自己探测一次对象,不能只相信客户端回传的 object_key。
@@ -558,59 +647,107 @@ impl OssClient {
client: &reqwest::Client,
request: OssHeadObjectRequest,
) -> Result<OssHeadObjectResponse, OssError> {
let object_key = normalize_object_key(&request.object_key)?;
let target_url = build_object_url(&self.config.bucket, &self.config.endpoint, &object_key)
.map_err(|error| OssError::Request(format!("构造 OSS 对象 URL 失败:{error}")))?;
let response = send_signed_request(
client,
&self.config,
Method::HEAD,
Some(&object_key),
target_url,
)
.await?;
let started_at = Instant::now();
let requested_object_key = request
.object_key
.trim()
.trim_start_matches('/')
.trim()
.to_string();
let mut response_status = None;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Err(OssError::ObjectNotFound(format!(
"OSS 对象不存在:{}",
request.object_key
)));
let result = async {
let object_key = normalize_object_key(&request.object_key)?;
let target_url =
build_object_url(&self.config.bucket, &self.config.endpoint, &object_key).map_err(
|error| OssError::Request(format!("构造 OSS 对象 URL 失败:{error}")),
)?;
let response = send_signed_request(
client,
&self.config,
Method::HEAD,
Some(&object_key),
target_url,
)
.await?;
response_status = Some(response.status().as_u16());
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Err(OssError::ObjectNotFound(format!(
"OSS 对象不存在:{}",
request.object_key
)));
}
if !response.status().is_success() {
return Err(OssError::Request(format!(
"OSS HEAD Object 失败,状态码:{}",
response.status()
)));
}
let headers = response.headers();
let content_length = headers
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse::<u64>().ok())
.unwrap_or(0);
let content_type = headers
.get(reqwest::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map(|value| value.to_string());
let etag = headers
.get(reqwest::header::ETAG)
.and_then(|value| value.to_str().ok())
.map(|value| value.trim_matches('"').to_string());
let last_modified = headers
.get(reqwest::header::LAST_MODIFIED)
.and_then(|value| value.to_str().ok())
.map(|value| value.to_string());
Ok(OssHeadObjectResponse {
bucket: self.config.bucket.clone(),
object_key,
content_length,
content_type,
etag,
last_modified,
})
}
.await;
match &result {
Ok(response) => info!(
provider = OSS_PROVIDER,
operation = "head_object",
bucket = %response.bucket,
endpoint = %self.config.endpoint(),
object_key = %response.object_key,
status = response_status.unwrap_or(reqwest::StatusCode::OK.as_u16()),
status_class = http_status_class_from_option(response_status),
content_length = response.content_length,
content_type = %response.content_type.as_deref().unwrap_or(""),
etag_present = response.etag.is_some(),
last_modified_present = response.last_modified.is_some(),
elapsed_ms = elapsed_ms(started_at),
"OSS HEAD Object 完成"
),
Err(error) => warn!(
provider = OSS_PROVIDER,
operation = "head_object",
bucket = %self.config.bucket(),
endpoint = %self.config.endpoint(),
object_key = %requested_object_key,
status = response_status.unwrap_or_default(),
status_class = http_status_class_from_option(response_status),
error_kind = oss_error_kind_label(error),
message = %error,
elapsed_ms = elapsed_ms(started_at),
"OSS HEAD Object 失败"
),
}
if !response.status().is_success() {
return Err(OssError::Request(format!(
"OSS HEAD Object 失败,状态码:{}",
response.status()
)));
}
let headers = response.headers();
let content_length = headers
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse::<u64>().ok())
.unwrap_or(0);
let content_type = headers
.get(reqwest::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map(|value| value.to_string());
let etag = headers
.get(reqwest::header::ETAG)
.and_then(|value| value.to_str().ok())
.map(|value| value.trim_matches('"').to_string());
let last_modified = headers
.get(reqwest::header::LAST_MODIFIED)
.and_then(|value| value.to_str().ok())
.map(|value| value.to_string());
Ok(OssHeadObjectResponse {
bucket: self.config.bucket.clone(),
object_key,
content_length,
content_type,
etag,
last_modified,
})
result
}
// AI 生成资源默认由服务端上传 OSSWeb 端只拿签名读地址,不直接持有写权限。
@@ -619,73 +756,128 @@ impl OssClient {
client: &reqwest::Client,
request: OssPutObjectRequest,
) -> Result<OssPutObjectResponse, OssError> {
if request.body.is_empty() {
return Err(OssError::InvalidRequest(
"服务端上传对象内容不能为空".to_string(),
));
let started_at = Instant::now();
let requested_prefix = request.prefix.as_str();
let requested_content_type = request
.content_type
.as_deref()
.map(str::trim)
.unwrap_or("")
.to_string();
let requested_content_length = request.body.len();
let requested_metadata_count = request.metadata.len();
let mut response_status = None;
let result = async {
if request.body.is_empty() {
return Err(OssError::InvalidRequest(
"服务端上传对象内容不能为空".to_string(),
));
}
let sanitized_segments = request
.path_segments
.iter()
.map(|segment| sanitize_path_segment(segment))
.filter(|segment| !segment.is_empty())
.collect::<Vec<_>>();
let file_name = sanitize_file_name(&request.file_name)?;
let object_key = build_object_key(request.prefix, &sanitized_segments, &file_name);
let content_type = normalize_optional_value(request.content_type);
let metadata = normalize_metadata(request.metadata)?;
let target_url =
build_object_url(&self.config.bucket, &self.config.endpoint, &object_key).map_err(
|error| OssError::Request(format!("构造 OSS 对象 URL 失败:{error}")),
)?;
let content_length = u64::try_from(request.body.len())
.map_err(|_| OssError::InvalidRequest("上传对象大小超出可支持范围".to_string()))?;
let builder = signed_request_builder(
client,
&self.config,
Method::PUT,
Some(&object_key),
target_url,
content_type.as_deref(),
&metadata,
)?
.header(reqwest::header::CONTENT_LENGTH, content_length)
.body(request.body);
let response = builder
.send()
.await
.map_err(|error| OssError::Request(format!("请求 OSS 失败:{error}")))?;
response_status = Some(response.status().as_u16());
if !response.status().is_success() {
return Err(OssError::Request(format!(
"OSS PutObject 失败,状态码:{}",
response.status()
)));
}
let headers = response.headers();
let etag = headers
.get(reqwest::header::ETAG)
.and_then(|value| value.to_str().ok())
.map(|value| value.trim_matches('"').to_string());
let last_modified = headers
.get(reqwest::header::LAST_MODIFIED)
.and_then(|value| value.to_str().ok())
.map(|value| value.to_string());
Ok(OssPutObjectResponse {
provider: OSS_PROVIDER,
bucket: self.config.bucket.clone(),
endpoint: self.config.endpoint.clone(),
host: self.config.upload_host(),
legacy_public_path: format!("/{object_key}"),
object_key,
content_type,
content_length,
access: request.access,
etag,
last_modified,
})
}
.await;
match &result {
Ok(response) => info!(
provider = OSS_PROVIDER,
operation = "put_object",
bucket = %response.bucket,
endpoint = %response.endpoint,
object_key = %response.object_key,
access = oss_access_label(response.access),
status = response_status.unwrap_or(reqwest::StatusCode::OK.as_u16()),
status_class = http_status_class_from_option(response_status),
content_length = response.content_length,
content_type = %response.content_type.as_deref().unwrap_or(""),
etag_present = response.etag.is_some(),
last_modified_present = response.last_modified.is_some(),
elapsed_ms = elapsed_ms(started_at),
"OSS PutObject 上传完成"
),
Err(error) => warn!(
provider = OSS_PROVIDER,
operation = "put_object",
bucket = %self.config.bucket(),
endpoint = %self.config.endpoint(),
key_prefix = requested_prefix,
content_length = requested_content_length,
content_type = %requested_content_type,
metadata_count = requested_metadata_count,
status = response_status.unwrap_or_default(),
status_class = http_status_class_from_option(response_status),
error_kind = oss_error_kind_label(error),
message = %error,
elapsed_ms = elapsed_ms(started_at),
"OSS PutObject 上传失败"
),
}
let sanitized_segments = request
.path_segments
.iter()
.map(|segment| sanitize_path_segment(segment))
.filter(|segment| !segment.is_empty())
.collect::<Vec<_>>();
let file_name = sanitize_file_name(&request.file_name)?;
let object_key = build_object_key(request.prefix, &sanitized_segments, &file_name);
let content_type = normalize_optional_value(request.content_type);
let metadata = normalize_metadata(request.metadata)?;
let target_url = build_object_url(&self.config.bucket, &self.config.endpoint, &object_key)
.map_err(|error| OssError::Request(format!("构造 OSS 对象 URL 失败:{error}")))?;
let content_length = u64::try_from(request.body.len())
.map_err(|_| OssError::InvalidRequest("上传对象大小超出可支持范围".to_string()))?;
let builder = signed_request_builder(
client,
&self.config,
Method::PUT,
Some(&object_key),
target_url,
content_type.as_deref(),
&metadata,
)?
.header(reqwest::header::CONTENT_LENGTH, content_length)
.body(request.body);
let response = builder
.send()
.await
.map_err(|error| OssError::Request(format!("请求 OSS 失败:{error}")))?;
if !response.status().is_success() {
return Err(OssError::Request(format!(
"OSS PutObject 失败,状态码:{}",
response.status()
)));
}
let headers = response.headers();
let etag = headers
.get(reqwest::header::ETAG)
.and_then(|value| value.to_str().ok())
.map(|value| value.trim_matches('"').to_string());
let last_modified = headers
.get(reqwest::header::LAST_MODIFIED)
.and_then(|value| value.to_str().ok())
.map(|value| value.to_string());
Ok(OssPutObjectResponse {
provider: "aliyun-oss",
bucket: self.config.bucket.clone(),
endpoint: self.config.endpoint.clone(),
host: self.config.upload_host(),
legacy_public_path: format!("/{object_key}"),
object_key,
content_type,
content_length,
access: request.access,
etag,
last_modified,
})
result
}
}
@@ -717,6 +909,43 @@ impl OssError {
}
}
fn elapsed_ms(started_at: Instant) -> u64 {
started_at.elapsed().as_millis().min(u64::MAX as u128) as u64
}
fn oss_access_label(access: OssObjectAccess) -> &'static str {
match access {
OssObjectAccess::Public => "public",
OssObjectAccess::Private => "private",
}
}
fn oss_error_kind_label(error: &OssError) -> &'static str {
match error.kind() {
OssErrorKind::InvalidConfig => "invalid_config",
OssErrorKind::InvalidRequest => "invalid_request",
OssErrorKind::ObjectNotFound => "object_not_found",
OssErrorKind::Request => "request",
OssErrorKind::SerializePolicy => "serialize_policy",
OssErrorKind::Sign => "sign",
}
}
fn http_status_class_from_option(status: Option<u16>) -> &'static str {
status.map(http_status_class).unwrap_or("unknown")
}
fn http_status_class(status: u16) -> &'static str {
match status {
100..=199 => "1xx",
200..=299 => "2xx",
300..=399 => "3xx",
400..=499 => "4xx",
500..=599 => "5xx",
_ => "unknown",
}
}
fn build_policy_json(
bucket: &str,
object_key: &str,
@@ -1295,6 +1524,18 @@ mod tests {
);
}
#[test]
fn structured_log_labels_are_stable() {
assert_eq!(
oss_error_kind_label(&OssError::InvalidRequest("bad input".to_string())),
"invalid_request"
);
assert_eq!(oss_access_label(OssObjectAccess::Private), "private");
assert_eq!(http_status_class(204), "2xx");
assert_eq!(http_status_class(404), "4xx");
assert_eq!(http_status_class_from_option(None), "unknown");
}
fn build_client() -> OssClient {
OssClient::new(
OssConfig::new(