This commit is contained in:
2026-05-03 00:17:50 +08:00
parent 5831703156
commit 801d1d534a
16 changed files with 1337 additions and 449 deletions

View File

@@ -0,0 +1,632 @@
use std::time::{Duration, Instant};
use axum::http::StatusCode;
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use reqwest::header;
use serde_json::{Map, Value, json};
use tokio::time::sleep;
use crate::{http_error::AppError, state::AppState};
pub(crate) const GPT_IMAGE_2_MODEL: &str = "gpt-image-2";
#[derive(Clone, Debug)]
pub(crate) struct OpenAiImageSettings {
pub base_url: String,
pub api_key: String,
pub request_timeout_ms: u64,
}
#[derive(Clone, Debug)]
pub(crate) struct OpenAiGeneratedImages {
pub task_id: String,
pub actual_prompt: Option<String>,
pub images: Vec<DownloadedOpenAiImage>,
}
#[derive(Clone, Debug)]
pub(crate) struct DownloadedOpenAiImage {
pub bytes: Vec<u8>,
pub mime_type: String,
pub extension: String,
}
// 中文注释RPG 图片资产与拼图一样走 APIMart 的 OpenAI 兼容图片入口,避免把密钥或供应商协议暴露到前端。
pub(crate) fn require_openai_image_settings(
state: &AppState,
) -> Result<OpenAiImageSettings, AppError> {
let base_url = state.config.apimart_base_url.trim().trim_end_matches('/');
if base_url.is_empty() {
return Err(
AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
"provider": "apimart",
"reason": "APIMART_BASE_URL 未配置",
})),
);
}
let api_key = state
.config
.apimart_api_key
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or_else(|| {
AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
"provider": "apimart",
"reason": "APIMART_API_KEY 未配置",
}))
})?;
Ok(OpenAiImageSettings {
base_url: base_url.to_string(),
api_key: api_key.to_string(),
request_timeout_ms: state.config.apimart_image_request_timeout_ms.max(1),
})
}
pub(crate) fn build_openai_image_http_client(
settings: &OpenAiImageSettings,
) -> Result<reqwest::Client, AppError> {
reqwest::Client::builder()
.timeout(Duration::from_millis(settings.request_timeout_ms))
.build()
.map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({
"provider": "apimart",
"message": format!("构造 APIMart 图片生成 HTTP 客户端失败:{error}"),
}))
})
}
pub(crate) async fn create_openai_image_generation(
http_client: &reqwest::Client,
settings: &OpenAiImageSettings,
prompt: &str,
negative_prompt: Option<&str>,
size: &str,
candidate_count: u32,
reference_images: &[String],
failure_context: &str,
) -> Result<OpenAiGeneratedImages, AppError> {
let request_body = build_openai_image_request_body(
prompt,
negative_prompt,
size,
candidate_count,
reference_images,
);
let response = http_client
.post(format!("{}/images/generations", settings.base_url))
.header(
header::AUTHORIZATION,
format!("Bearer {}", settings.api_key),
)
.header(header::CONTENT_TYPE, "application/json")
.json(&request_body)
.send()
.await
.map_err(|error| {
map_openai_image_request_error(format!(
"{failure_context}:创建图片生成任务失败:{error}"
))
})?;
let response_status = response.status();
let response_text = response.text().await.map_err(|error| {
map_openai_image_request_error(format!("{failure_context}:读取图片生成响应失败:{error}"))
})?;
if !response_status.is_success() {
return Err(map_openai_image_upstream_error(
response_status.as_u16(),
response_text.as_str(),
failure_context,
));
}
let response_json = parse_json_payload(response_text.as_str(), failure_context)?;
let image_urls = extract_image_urls(&response_json.payload);
if !image_urls.is_empty() {
return download_images_from_urls(
http_client,
format!("apimart-{}", current_utc_micros()),
image_urls,
candidate_count,
)
.await;
}
let b64_images = extract_b64_images(&response_json.payload);
if !b64_images.is_empty() {
return Ok(images_from_base64(
format!("apimart-{}", current_utc_micros()),
b64_images,
candidate_count,
));
}
let task_id = extract_task_id(&response_json.payload).ok_or_else(|| {
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": "apimart",
"message": format!("{failure_context}:上游未返回 task_id 或图片"),
}))
})?;
wait_openai_generated_images(
http_client,
settings,
task_id.as_str(),
candidate_count,
failure_context,
)
.await
}
pub(crate) fn build_openai_image_request_body(
prompt: &str,
negative_prompt: Option<&str>,
size: &str,
candidate_count: u32,
reference_images: &[String],
) -> Value {
let mut body = Map::from_iter([
(
"model".to_string(),
Value::String(GPT_IMAGE_2_MODEL.to_string()),
),
(
"prompt".to_string(),
Value::String(build_prompt_with_negative(prompt, negative_prompt)),
),
("n".to_string(), json!(candidate_count.clamp(1, 4))),
(
"size".to_string(),
Value::String(normalize_image_size(size)),
),
]);
if !reference_images.is_empty() {
body.insert("image_urls".to_string(), json!(reference_images));
}
Value::Object(body)
}
fn build_prompt_with_negative(prompt: &str, negative_prompt: Option<&str>) -> String {
let prompt = prompt.trim();
let Some(negative_prompt) = negative_prompt
.map(str::trim)
.filter(|value| !value.is_empty())
else {
return prompt.to_string();
};
format!("{prompt}\n避免:{negative_prompt}")
}
fn normalize_image_size(size: &str) -> String {
match size.trim() {
"1024*1024" | "1024x1024" | "1:1" => "1:1",
"1280*720" | "1280x720" | "1600*900" | "1600x900" | "16:9" => "16:9",
value if !value.is_empty() => value,
_ => "1:1",
}
.to_string()
}
async fn wait_openai_generated_images(
http_client: &reqwest::Client,
settings: &OpenAiImageSettings,
task_id: &str,
candidate_count: u32,
failure_context: &str,
) -> Result<OpenAiGeneratedImages, AppError> {
let deadline = Instant::now() + Duration::from_millis(settings.request_timeout_ms);
sleep(Duration::from_secs(10)).await;
while Instant::now() < deadline {
let poll_response = http_client
.get(format!("{}/tasks/{}", settings.base_url, task_id))
.header(
header::AUTHORIZATION,
format!("Bearer {}", settings.api_key),
)
.send()
.await
.map_err(|error| {
map_openai_image_request_error(format!(
"{failure_context}:查询图片生成任务失败:{error}"
))
})?;
let poll_status = poll_response.status();
let poll_text = poll_response.text().await.map_err(|error| {
map_openai_image_request_error(format!(
"{failure_context}:读取图片生成任务响应失败:{error}"
))
})?;
if !poll_status.is_success() {
return Err(map_openai_image_upstream_error(
poll_status.as_u16(),
poll_text.as_str(),
failure_context,
));
}
let poll_json = parse_json_payload(poll_text.as_str(), failure_context)?;
let task_status = find_first_string_by_key(&poll_json.payload, "status")
.or_else(|| find_first_string_by_key(&poll_json.payload, "task_status"))
.unwrap_or_default()
.trim()
.to_ascii_lowercase();
if matches!(task_status.as_str(), "completed" | "succeeded" | "success") {
let image_urls = extract_image_urls(&poll_json.payload);
if image_urls.is_empty() {
let b64_images = extract_b64_images(&poll_json.payload);
if b64_images.is_empty() {
return Err(AppError::from_status(StatusCode::BAD_GATEWAY).with_details(
json!({
"provider": "apimart",
"message": format!("{failure_context}:任务成功但未返回图片"),
}),
));
}
let mut generated =
images_from_base64(task_id.to_string(), b64_images, candidate_count);
generated.actual_prompt =
find_first_string_by_key(&poll_json.payload, "actual_prompt");
return Ok(generated);
}
let mut generated = download_images_from_urls(
http_client,
task_id.to_string(),
image_urls,
candidate_count,
)
.await?;
generated.actual_prompt = find_first_string_by_key(&poll_json.payload, "actual_prompt");
return Ok(generated);
}
if matches!(
task_status.as_str(),
"failed" | "error" | "canceled" | "cancelled" | "unknown"
) {
return Err(map_openai_image_upstream_error(
poll_status.as_u16(),
poll_text.as_str(),
failure_context,
));
}
sleep(Duration::from_secs(3)).await;
}
Err(
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": "apimart",
"message": format!("{failure_context}:图片生成超时或未返回图片地址"),
})),
)
}
async fn download_images_from_urls(
http_client: &reqwest::Client,
task_id: String,
image_urls: Vec<String>,
candidate_count: u32,
) -> Result<OpenAiGeneratedImages, AppError> {
let mut images = Vec::with_capacity(candidate_count.clamp(1, 4) as usize);
for image_url in image_urls
.into_iter()
.take(candidate_count.clamp(1, 4) as usize)
{
images.push(download_remote_image(http_client, image_url.as_str()).await?);
}
Ok(OpenAiGeneratedImages {
task_id,
actual_prompt: None,
images,
})
}
fn images_from_base64(
task_id: String,
b64_images: Vec<String>,
candidate_count: u32,
) -> OpenAiGeneratedImages {
let images = b64_images
.into_iter()
.take(candidate_count.clamp(1, 4) as usize)
.filter_map(|raw| decode_generated_image_base64(raw.as_str()))
.collect();
OpenAiGeneratedImages {
task_id,
actual_prompt: None,
images,
}
}
fn decode_generated_image_base64(raw: &str) -> Option<DownloadedOpenAiImage> {
let bytes = BASE64_STANDARD.decode(raw.trim()).ok()?;
let mime_type = infer_image_mime_type(bytes.as_slice());
Some(DownloadedOpenAiImage {
extension: mime_to_extension(mime_type.as_str()).to_string(),
mime_type,
bytes,
})
}
pub(crate) async fn download_remote_image(
http_client: &reqwest::Client,
image_url: &str,
) -> Result<DownloadedOpenAiImage, AppError> {
let response =
http_client.get(image_url).send().await.map_err(|error| {
map_openai_image_request_error(format!("下载生成图片失败:{error}"))
})?;
let status = response.status();
let content_type = response
.headers()
.get(header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.unwrap_or("image/jpeg")
.to_string();
let body = response.bytes().await.map_err(|error| {
map_openai_image_request_error(format!("读取生成图片内容失败:{error}"))
})?;
if !status.is_success() {
return Err(
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": "apimart",
"message": "下载生成图片失败",
"status": status.as_u16(),
})),
);
}
let normalized_mime_type = normalize_downloaded_image_mime_type(content_type.as_str());
Ok(DownloadedOpenAiImage {
extension: mime_to_extension(normalized_mime_type.as_str()).to_string(),
mime_type: normalized_mime_type,
bytes: body.to_vec(),
})
}
fn parse_json_payload(
raw_text: &str,
failure_context: &str,
) -> Result<ParsedJsonPayload, AppError> {
serde_json::from_str::<Value>(raw_text)
.map(|payload| ParsedJsonPayload { payload })
.map_err(|error| {
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": "apimart",
"message": format!("{failure_context}:解析响应失败:{error}"),
"rawExcerpt": truncate_raw(raw_text),
}))
})
}
fn map_openai_image_request_error(message: String) -> AppError {
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": "apimart",
"message": message,
}))
}
fn map_openai_image_upstream_error(
upstream_status: u16,
raw_text: &str,
failure_context: &str,
) -> AppError {
let message = parse_api_error_message(raw_text, failure_context);
tracing::warn!(
provider = "apimart",
upstream_status,
raw_excerpt = %truncate_raw(raw_text),
message,
"APIMart 图片生成上游错误"
);
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": "apimart",
"message": message,
"upstreamStatus": upstream_status,
"rawExcerpt": truncate_raw(raw_text),
}))
}
fn parse_api_error_message(raw_text: &str, fallback_message: &str) -> String {
if raw_text.trim().is_empty() {
return fallback_message.to_string();
}
if let Ok(parsed) = serde_json::from_str::<Value>(raw_text) {
for pointer in [
"/error/message",
"/message",
"/output/message",
"/data/message",
] {
if let Some(message) = parsed
.pointer(pointer)
.and_then(Value::as_str)
.map(str::trim)
.filter(|value| !value.is_empty())
{
return message.to_string();
}
}
for pointer in ["/error/code", "/code", "/output/code", "/data/code"] {
if let Some(code) = parsed
.pointer(pointer)
.and_then(Value::as_str)
.map(str::trim)
.filter(|value| !value.is_empty())
{
return format!("{fallback_message}{code}");
}
}
}
raw_text.trim().to_string()
}
fn collect_strings_by_key(value: &Value, target_key: &str, results: &mut Vec<String>) {
match value {
Value::Array(entries) => {
for entry in entries {
collect_strings_by_key(entry, target_key, results);
}
}
Value::Object(object) => {
for (key, nested_value) in object {
if key == target_key {
match nested_value {
Value::String(text) => {
let text = text.trim();
if !text.is_empty() {
results.push(text.to_string());
continue;
}
}
Value::Array(entries) => {
for entry in entries {
if let Some(text) = entry
.as_str()
.map(str::trim)
.filter(|value| !value.is_empty())
{
results.push(text.to_string());
}
}
}
_ => {}
}
}
collect_strings_by_key(nested_value, target_key, results);
}
}
_ => {}
}
}
fn find_first_string_by_key(value: &Value, target_key: &str) -> Option<String> {
let mut results = Vec::new();
collect_strings_by_key(value, target_key, &mut results);
results.into_iter().next()
}
fn extract_task_id(payload: &Value) -> Option<String> {
find_first_string_by_key(payload, "task_id")
.or_else(|| find_first_string_by_key(payload, "taskId"))
.or_else(|| find_first_string_by_key(payload, "id"))
}
fn extract_image_urls(payload: &Value) -> Vec<String> {
let mut urls = Vec::new();
collect_strings_by_key(payload, "url", &mut urls);
collect_strings_by_key(payload, "image", &mut urls);
collect_strings_by_key(payload, "image_url", &mut urls);
let mut deduped = Vec::new();
for url in urls {
if (url.starts_with("http://") || url.starts_with("https://")) && !deduped.contains(&url) {
deduped.push(url);
}
}
deduped
}
fn extract_b64_images(payload: &Value) -> Vec<String> {
let mut values = Vec::new();
collect_strings_by_key(payload, "b64_json", &mut values);
values
}
fn normalize_downloaded_image_mime_type(content_type: &str) -> String {
let mime_type = content_type
.split(';')
.next()
.map(str::trim)
.unwrap_or("image/jpeg");
match mime_type {
"image/png" | "image/webp" | "image/jpeg" | "image/jpg" | "image/gif" => {
mime_type.to_string()
}
_ => "image/jpeg".to_string(),
}
}
fn mime_to_extension(mime_type: &str) -> &str {
match mime_type {
"image/png" => "png",
"image/webp" => "webp",
"image/gif" => "gif",
_ => "jpg",
}
}
fn infer_image_mime_type(bytes: &[u8]) -> String {
if bytes.starts_with(b"\x89PNG\r\n\x1A\n") {
return "image/png".to_string();
}
if bytes.starts_with(b"\xFF\xD8\xFF") {
return "image/jpeg".to_string();
}
if bytes.starts_with(b"RIFF") && bytes.get(8..12) == Some(b"WEBP") {
return "image/webp".to_string();
}
if bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a") {
return "image/gif".to_string();
}
"image/png".to_string()
}
fn truncate_raw(raw_text: &str) -> String {
raw_text.chars().take(800).collect()
}
fn current_utc_micros() -> i64 {
use std::time::{SystemTime, UNIX_EPOCH};
let duration = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after unix epoch");
i64::try_from(duration.as_micros()).expect("current unix micros should fit in i64")
}
struct ParsedJsonPayload {
payload: Value,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn gpt_image_2_request_normalizes_legacy_sizes_and_reference_images() {
let body = build_openai_image_request_body(
"雾海神殿",
Some("文字,水印"),
"1280*720",
2,
&["data:image/png;base64,abcd".to_string()],
);
assert_eq!(body["model"], GPT_IMAGE_2_MODEL);
assert_eq!(body["size"], "16:9");
assert_eq!(body["n"], 2);
assert_eq!(body["image_urls"][0], "data:image/png;base64,abcd");
assert!(body["prompt"].as_str().unwrap_or_default().contains("避免"));
}
#[test]
fn b64_json_response_decodes_png_image() {
let images = images_from_base64(
"task-1".to_string(),
vec![BASE64_STANDARD.encode(b"\x89PNG\r\n\x1A\nrest")],
1,
);
assert_eq!(images.images.len(), 1);
assert_eq!(images.images[0].mime_type, "image/png");
assert_eq!(images.images[0].extension, "png");
}
}