Files
Genarrative/server-rs/crates/api-server/src/openai_image_generation.rs
2026-05-09 18:24:08 +08:00

546 lines
17 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::time::Duration;
use axum::http::StatusCode;
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use reqwest::header;
use serde_json::{Map, Value, json};
use crate::{http_error::AppError, state::AppState};
/// The model name `"gpt-image-2"`.
/// NOTE(review): not referenced by the code visible in this file; presumably consumed by callers — confirm.
pub(crate) const GPT_IMAGE_2_MODEL: &str = "gpt-image-2";
/// Model name placed in the `model` field of the image generation request body.
pub(crate) const VECTOR_ENGINE_GPT_IMAGE_2_MODEL: &str = "gpt-image-2-all";
/// Value of the `provider` field attached to error details and log events emitted by this module.
const VECTOR_ENGINE_PROVIDER: &str = "vector-engine";
/// Connection settings for the VectorEngine image generation endpoint.
///
/// Values produced by `require_openai_image_settings` are already trimmed and validated.
#[derive(Clone, Debug)]
pub(crate) struct OpenAiImageSettings {
/// Base URL of the service; `require_openai_image_settings` stores it without trailing `/`.
pub base_url: String,
/// API key sent as the `Authorization: Bearer …` header value.
pub api_key: String,
/// Timeout, in milliseconds, applied to the HTTP client built from these settings.
pub request_timeout_ms: u64,
}
/// Result of one image generation request: an identifier plus the image payloads.
#[derive(Clone, Debug)]
pub(crate) struct OpenAiGeneratedImages {
/// Identifier of the generation; taken from the response when available, otherwise
/// synthesized by `create_openai_image_generation` as `vector-engine-<unix micros>`.
pub task_id: String,
/// Prompt reported back by the provider (`revised_prompt` or `actual_prompt`), if any.
pub actual_prompt: Option<String>,
/// Downloaded or base64-decoded images.
pub images: Vec<DownloadedOpenAiImage>,
}
/// A single image held in memory together with its detected type.
#[derive(Clone, Debug)]
pub(crate) struct DownloadedOpenAiImage {
/// Raw image bytes.
pub bytes: Vec<u8>,
/// MIME type, e.g. `image/png`.
pub mime_type: String,
/// File extension without a leading dot (`png`, `webp`, `gif` or `jpg`), as produced by `mime_to_extension`.
pub extension: String,
}
// Image assets for RPG, 方洞 and similar features are all generated through VectorEngine
// gpt-image-2-all on the server side, so that neither the API key nor the provider
// protocol is exposed to the frontend.
/// Reads the VectorEngine image settings from the application configuration.
///
/// The base URL is trimmed and stripped of trailing `/`; the API key is trimmed; the
/// request timeout is raised to at least 1 ms.
///
/// # Errors
/// Returns a `SERVICE_UNAVAILABLE` error (with a `reason` detail) when the base URL or
/// the API key is missing or blank.
pub(crate) fn require_openai_image_settings(
    state: &AppState,
) -> Result<OpenAiImageSettings, AppError> {
    // Both validation failures share the same error shape and differ only in `reason`.
    let not_configured = |reason: &str| {
        AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
            "provider": VECTOR_ENGINE_PROVIDER,
            "reason": reason,
        }))
    };

    let endpoint = state
        .config
        .vector_engine_base_url
        .trim()
        .trim_end_matches('/');
    if endpoint.is_empty() {
        return Err(not_configured("VECTOR_ENGINE_BASE_URL 未配置"));
    }

    let secret = match state.config.vector_engine_api_key.as_deref().map(str::trim) {
        Some(key) if !key.is_empty() => key,
        _ => return Err(not_configured("VECTOR_ENGINE_API_KEY 未配置")),
    };

    let request_timeout_ms = state.config.vector_engine_image_request_timeout_ms.max(1);
    Ok(OpenAiImageSettings {
        base_url: endpoint.to_string(),
        api_key: secret.to_string(),
        request_timeout_ms,
    })
}
/// Builds a `reqwest::Client` whose overall request timeout is `settings.request_timeout_ms`.
///
/// # Errors
/// Returns an `INTERNAL_SERVER_ERROR` error when the client cannot be constructed.
pub(crate) fn build_openai_image_http_client(
    settings: &OpenAiImageSettings,
) -> Result<reqwest::Client, AppError> {
    let timeout = Duration::from_millis(settings.request_timeout_ms);
    match reqwest::Client::builder().timeout(timeout).build() {
        Ok(client) => Ok(client),
        Err(error) => Err(
            AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({
                "provider": VECTOR_ENGINE_PROVIDER,
                "message": format!("构造 VectorEngine 图片生成 HTTP 客户端失败:{error}"),
            })),
        ),
    }
}
/// Sends an image generation request to VectorEngine and returns the generated images.
///
/// The request body is produced by `build_openai_image_request_body`. The response is
/// searched first for HTTP(S) image URLs (which are downloaded) and, if none are found,
/// for `b64_json` payloads (which are decoded in place). At most `candidate_count`
/// images, clamped to `1..=4`, are returned.
///
/// `failure_context` is a caller-supplied prefix used in error messages.
///
/// # Errors
/// Returns a `BAD_GATEWAY` error when the request cannot be sent, the response cannot be
/// read or parsed, the upstream answers with a non-success status, an image download
/// fails, or the response contains no usable image (no URL and no decodable base64).
pub(crate) async fn create_openai_image_generation(
    http_client: &reqwest::Client,
    settings: &OpenAiImageSettings,
    prompt: &str,
    negative_prompt: Option<&str>,
    size: &str,
    candidate_count: u32,
    reference_images: &[String],
    failure_context: &str,
) -> Result<OpenAiGeneratedImages, AppError> {
    let request_body = build_openai_image_request_body(
        prompt,
        negative_prompt,
        size,
        candidate_count,
        reference_images,
    );
    let response = http_client
        .post(vector_engine_images_generation_url(settings))
        .header(
            header::AUTHORIZATION,
            format!("Bearer {}", settings.api_key),
        )
        .header(header::ACCEPT, "application/json")
        .header(header::CONTENT_TYPE, "application/json")
        .json(&request_body)
        .send()
        .await
        .map_err(|error| {
            map_openai_image_request_error(format!(
                "{failure_context}:创建图片生成任务失败:{error}"
            ))
        })?;
    let response_status = response.status();
    // The body is read even for error statuses: it carries the upstream error message.
    let response_text = response.text().await.map_err(|error| {
        map_openai_image_request_error(format!("{failure_context}:读取图片生成响应失败:{error}"))
    })?;
    if !response_status.is_success() {
        return Err(map_openai_image_upstream_error(
            response_status.as_u16(),
            response_text.as_str(),
            failure_context,
        ));
    }
    let response_json = parse_json_payload(response_text.as_str(), failure_context)?;
    // Fall back to a locally generated id when the response does not carry one.
    let generation_id = extract_generation_id(&response_json.payload)
        .unwrap_or_else(|| format!("vector-engine-{}", current_utc_micros()));
    let actual_prompt = find_first_string_by_key(&response_json.payload, "revised_prompt")
        .or_else(|| find_first_string_by_key(&response_json.payload, "actual_prompt"));

    // Remote URLs take precedence over inline base64 payloads.
    let image_urls = extract_image_urls(&response_json.payload);
    if !image_urls.is_empty() {
        let mut generated =
            download_images_from_urls(http_client, generation_id, image_urls, candidate_count)
                .await?;
        generated.actual_prompt = actual_prompt;
        return Ok(generated);
    }

    let b64_images = extract_b64_images(&response_json.payload);
    let received_b64 = !b64_images.is_empty();
    if received_b64 {
        let mut generated = images_from_base64(generation_id, b64_images, candidate_count);
        // Undecodable entries are dropped during decoding; never report success with
        // an empty image list.
        if !generated.images.is_empty() {
            generated.actual_prompt = actual_prompt;
            return Ok(generated);
        }
    }

    let message = if received_b64 {
        format!("{failure_context}:VectorEngine 返回的图片数据无法解码")
    } else {
        format!("{failure_context}:VectorEngine 未返回图片地址")
    };
    Err(
        AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
            "provider": VECTOR_ENGINE_PROVIDER,
            "message": message,
        })),
    )
}
/// Assembles the JSON body of an image generation request.
///
/// The body always contains `model`, `prompt` (with the negative prompt appended when
/// present), `n` (clamped to `1..=4`) and `size` (normalized). The `image` array is
/// added only when `reference_images` is non-empty.
pub(crate) fn build_openai_image_request_body(
    prompt: &str,
    negative_prompt: Option<&str>,
    size: &str,
    candidate_count: u32,
    reference_images: &[String],
) -> Value {
    let mut fields = Map::new();
    fields.insert(
        "model".to_string(),
        Value::String(VECTOR_ENGINE_GPT_IMAGE_2_MODEL.to_string()),
    );
    fields.insert(
        "prompt".to_string(),
        Value::String(build_prompt_with_negative(prompt, negative_prompt)),
    );
    fields.insert("n".to_string(), json!(candidate_count.clamp(1, 4)));
    fields.insert(
        "size".to_string(),
        Value::String(normalize_image_size(size)),
    );
    if !reference_images.is_empty() {
        fields.insert("image".to_string(), json!(reference_images));
    }
    Value::Object(fields)
}
/// Returns the trimmed prompt, followed by an "avoid" line when a non-blank negative
/// prompt is supplied.
fn build_prompt_with_negative(prompt: &str, negative_prompt: Option<&str>) -> String {
    let prompt = prompt.trim();
    match negative_prompt.map(str::trim) {
        Some(negative_prompt) if !negative_prompt.is_empty() => {
            format!("{prompt}\n避免:{negative_prompt}")
        }
        // Missing or blank negative prompt: the prompt stands alone.
        _ => prompt.to_string(),
    }
}
/// Maps a requested size or aspect-ratio alias onto one of the three canonical sizes.
///
/// Known square, landscape and portrait aliases become `1024x1024`, `1536x1024` and
/// `1024x1536` respectively. A blank input becomes `1024x1024`; any other non-blank
/// value is passed through trimmed.
fn normalize_image_size(size: &str) -> String {
    const SQUARE_ALIASES: [&str; 3] = ["1024*1024", "1024x1024", "1:1"];
    const LANDSCAPE_ALIASES: [&str; 8] = [
        "1280*720",
        "1280x720",
        "1600*900",
        "1600x900",
        "16:9",
        "1536x1024",
        "2048x1152",
        "2k",
    ];
    const PORTRAIT_ALIASES: [&str; 3] = ["1024*1536", "1024x1536", "9:16"];

    let requested = size.trim();
    let resolved = if requested.is_empty() || SQUARE_ALIASES.contains(&requested) {
        "1024x1024"
    } else if LANDSCAPE_ALIASES.contains(&requested) {
        "1536x1024"
    } else if PORTRAIT_ALIASES.contains(&requested) {
        "1024x1536"
    } else {
        requested
    };
    resolved.to_string()
}
/// Downloads up to `candidate_count` (clamped to `1..=4`) images, one after another.
///
/// The first failing download aborts the whole operation. `actual_prompt` is left
/// `None` for the caller to fill in.
async fn download_images_from_urls(
    http_client: &reqwest::Client,
    task_id: String,
    image_urls: Vec<String>,
    candidate_count: u32,
) -> Result<OpenAiGeneratedImages, AppError> {
    let limit = candidate_count.clamp(1, 4) as usize;
    let mut images = Vec::with_capacity(limit);
    for url in image_urls.iter().take(limit) {
        let image = download_remote_image(http_client, url).await?;
        images.push(image);
    }
    Ok(OpenAiGeneratedImages {
        task_id,
        actual_prompt: None,
        images,
    })
}
/// Decodes the first `candidate_count` (clamped to `1..=4`) base64 payloads into images.
///
/// Entries that fail to decode are skipped, so the result may hold fewer images than
/// were considered. `actual_prompt` is left `None` for the caller to fill in.
fn images_from_base64(
    task_id: String,
    b64_images: Vec<String>,
    candidate_count: u32,
) -> OpenAiGeneratedImages {
    let limit = candidate_count.clamp(1, 4) as usize;
    let mut images = Vec::new();
    for encoded in b64_images.iter().take(limit) {
        if let Some(image) = decode_generated_image_base64(encoded) {
            images.push(image);
        }
    }
    OpenAiGeneratedImages {
        task_id,
        actual_prompt: None,
        images,
    }
}
/// Decodes one standard-alphabet base64 string (surrounding whitespace ignored) into an
/// image, sniffing its MIME type from the decoded bytes. Returns `None` when decoding fails.
fn decode_generated_image_base64(raw: &str) -> Option<DownloadedOpenAiImage> {
    match BASE64_STANDARD.decode(raw.trim()) {
        Ok(bytes) => {
            let mime_type = infer_image_mime_type(&bytes);
            let extension = mime_to_extension(&mime_type).to_string();
            Some(DownloadedOpenAiImage {
                bytes,
                mime_type,
                extension,
            })
        }
        Err(_) => None,
    }
}
/// Downloads a single image and classifies it by the response `Content-Type` header.
///
/// A missing or unreadable `Content-Type` is treated as `image/jpeg`; the header value
/// is further normalized by `normalize_downloaded_image_mime_type`.
///
/// # Errors
/// Returns a `BAD_GATEWAY` error when the request fails, the server answers with a
/// non-success status (reported in the `status` detail), or the body cannot be read.
pub(crate) async fn download_remote_image(
    http_client: &reqwest::Client,
    image_url: &str,
) -> Result<DownloadedOpenAiImage, AppError> {
    let response =
        http_client.get(image_url).send().await.map_err(|error| {
            map_openai_image_request_error(format!("下载生成图片失败:{error}"))
        })?;

    // Check the status before touching the body: an error response body is never used,
    // and a failure while reading it must not mask the upstream status.
    let status = response.status();
    if !status.is_success() {
        return Err(
            AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
                "provider": VECTOR_ENGINE_PROVIDER,
                "message": "下载生成图片失败",
                "status": status.as_u16(),
            })),
        );
    }

    // The header must be copied out before `bytes()` consumes the response.
    let content_type = response
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("image/jpeg")
        .to_string();
    let body = response.bytes().await.map_err(|error| {
        map_openai_image_request_error(format!("读取生成图片内容失败:{error}"))
    })?;

    let normalized_mime_type = normalize_downloaded_image_mime_type(content_type.as_str());
    Ok(DownloadedOpenAiImage {
        extension: mime_to_extension(normalized_mime_type.as_str()).to_string(),
        mime_type: normalized_mime_type,
        bytes: body.to_vec(),
    })
}
/// Parses the response text as JSON.
///
/// # Errors
/// Returns a `BAD_GATEWAY` error carrying the parser message and a truncated excerpt of
/// the raw text when the text is not valid JSON.
fn parse_json_payload(
    raw_text: &str,
    failure_context: &str,
) -> Result<ParsedJsonPayload, AppError> {
    match serde_json::from_str::<Value>(raw_text) {
        Ok(payload) => Ok(ParsedJsonPayload { payload }),
        Err(error) => Err(
            AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
                "provider": VECTOR_ENGINE_PROVIDER,
                "message": format!("{failure_context}:解析响应失败:{error}"),
                "rawExcerpt": truncate_raw(raw_text),
            })),
        ),
    }
}
/// Wraps a transport-level failure message in a `BAD_GATEWAY` error tagged with the provider.
fn map_openai_image_request_error(message: String) -> AppError {
    let details = json!({
        "provider": VECTOR_ENGINE_PROVIDER,
        "message": message,
    });
    AppError::from_status(StatusCode::BAD_GATEWAY).with_details(details)
}
fn map_openai_image_upstream_error(
upstream_status: u16,
raw_text: &str,
failure_context: &str,
) -> AppError {
let message = parse_api_error_message(raw_text, failure_context);
tracing::warn!(
provider = VECTOR_ENGINE_PROVIDER,
upstream_status,
raw_excerpt = %truncate_raw(raw_text),
message,
"VectorEngine 图片生成上游错误"
);
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": VECTOR_ENGINE_PROVIDER,
"message": message,
"upstreamStatus": upstream_status,
"rawExcerpt": truncate_raw(raw_text),
}))
}
/// Extracts a human-readable message from an upstream error body.
///
/// Resolution order:
/// 1. blank body → `fallback_message`;
/// 2. JSON body with a non-blank string at `/error/message`, `/message`,
///    `/output/message` or `/data/message` → that message;
/// 3. JSON body with a non-blank string at `/error/code`, `/code`, `/output/code` or
///    `/data/code` → `fallback_message`, a full-width colon, then the code;
/// 4. otherwise → the trimmed raw body.
fn parse_api_error_message(raw_text: &str, fallback_message: &str) -> String {
    /// Returns the trimmed, non-blank string located at `pointer`, if any.
    fn non_blank_str_at<'v>(value: &'v Value, pointer: &str) -> Option<&'v str> {
        value
            .pointer(pointer)
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|text| !text.is_empty())
    }

    const MESSAGE_POINTERS: [&str; 4] = [
        "/error/message",
        "/message",
        "/output/message",
        "/data/message",
    ];
    const CODE_POINTERS: [&str; 4] = ["/error/code", "/code", "/output/code", "/data/code"];

    let raw_text = raw_text.trim();
    if raw_text.is_empty() {
        return fallback_message.to_string();
    }
    if let Ok(parsed) = serde_json::from_str::<Value>(raw_text) {
        if let Some(message) = MESSAGE_POINTERS
            .iter()
            .find_map(|pointer| non_blank_str_at(&parsed, pointer))
        {
            return message.to_string();
        }
        if let Some(code) = CODE_POINTERS
            .iter()
            .find_map(|pointer| non_blank_str_at(&parsed, pointer))
        {
            // Separate the context from the code the same way the other messages in
            // this module separate their parts.
            return format!("{fallback_message}:{code}");
        }
    }
    raw_text.to_string()
}
/// Walks `value` depth-first and appends every non-blank string stored under `target_key`.
///
/// A matching key contributes its own value when that is a string, or each string
/// element when that is an array. Values are trimmed before being stored. Nested
/// objects and arrays are always searched further, whether or not their key matched.
fn collect_strings_by_key(value: &Value, target_key: &str, results: &mut Vec<String>) {
    /// Trimmed text of a JSON string value, or `None` for blanks and non-strings.
    fn non_blank_text(candidate: &Value) -> Option<String> {
        let text = candidate.as_str()?.trim();
        if text.is_empty() {
            None
        } else {
            Some(text.to_string())
        }
    }

    match value {
        Value::Object(fields) => {
            for (name, child) in fields {
                if name == target_key {
                    if let Some(items) = child.as_array() {
                        results.extend(items.iter().filter_map(non_blank_text));
                    } else {
                        results.extend(non_blank_text(child));
                    }
                }
                // Recursing into a scalar is a no-op, so this is safe for every child.
                collect_strings_by_key(child, target_key, results);
            }
        }
        Value::Array(items) => {
            for item in items {
                collect_strings_by_key(item, target_key, results);
            }
        }
        _ => {}
    }
}
/// Returns the first non-blank string found under `target_key`, in traversal order.
fn find_first_string_by_key(value: &Value, target_key: &str) -> Option<String> {
    let mut matches = Vec::new();
    collect_strings_by_key(value, target_key, &mut matches);
    if matches.is_empty() {
        None
    } else {
        Some(matches.swap_remove(0))
    }
}
/// Picks an identifier for the generation from the response payload.
///
/// The keys `id`, `created` and `request_id` are tried in that order, anywhere in the
/// payload. Only string values are considered, so a numeric `created` is not matched.
fn extract_generation_id(payload: &Value) -> Option<String> {
    ["id", "created", "request_id"]
        .iter()
        .find_map(|key| find_first_string_by_key(payload, key))
}
fn extract_image_urls(payload: &Value) -> Vec<String> {
let mut urls = Vec::new();
collect_strings_by_key(payload, "url", &mut urls);
collect_strings_by_key(payload, "image", &mut urls);
collect_strings_by_key(payload, "image_url", &mut urls);
let mut deduped = Vec::new();
for url in urls {
if (url.starts_with("http://") || url.starts_with("https://")) && !deduped.contains(&url) {
deduped.push(url);
}
}
deduped
}
/// Collects every non-blank string stored under a `b64_json` key in the payload.
fn extract_b64_images(payload: &Value) -> Vec<String> {
    let mut encoded_images = Vec::new();
    collect_strings_by_key(payload, "b64_json", &mut encoded_images);
    encoded_images
}
/// Builds the image generation endpoint URL, adding the `/v1` path segment unless the
/// configured base URL already ends with it.
fn vector_engine_images_generation_url(settings: &OpenAiImageSettings) -> String {
    match settings.base_url.strip_suffix("/v1") {
        Some(_) => format!("{}/images/generations", settings.base_url),
        None => format!("{}/v1/images/generations", settings.base_url),
    }
}
/// Reduces a `Content-Type` header value to one of the supported image MIME types.
///
/// Parameters after `;` are discarded and the type is compared case-insensitively
/// (media types are case-insensitive), returning the lowercase form. Anything other
/// than PNG, WebP, JPEG (`image/jpeg` or `image/jpg`) or GIF becomes `image/jpeg`.
fn normalize_downloaded_image_mime_type(content_type: &str) -> String {
    let essence = content_type
        .split(';')
        .next()
        .unwrap_or("image/jpeg")
        .trim()
        .to_ascii_lowercase();
    let supported = matches!(
        essence.as_str(),
        "image/png" | "image/webp" | "image/jpeg" | "image/jpg" | "image/gif"
    );
    if supported {
        essence
    } else {
        "image/jpeg".to_string()
    }
}
/// Maps a MIME type to a file extension without a leading dot.
///
/// PNG, WebP and GIF map to their own extensions; every other value maps to `jpg`.
/// The result is a string literal, so it does not borrow from `mime_type`.
fn mime_to_extension(mime_type: &str) -> &'static str {
    match mime_type {
        "image/png" => "png",
        "image/webp" => "webp",
        "image/gif" => "gif",
        _ => "jpg",
    }
}
/// Guesses the MIME type of an image from its leading magic bytes.
///
/// PNG, JPEG, WebP (a `RIFF` container whose bytes 8..12 are `WEBP`) and GIF are
/// recognized; anything else is reported as `image/png`.
fn infer_image_mime_type(bytes: &[u8]) -> String {
    let is_webp = bytes.starts_with(b"RIFF") && bytes.get(8..12) == Some(&b"WEBP"[..]);
    let is_gif = bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a");

    let detected = if bytes.starts_with(b"\x89PNG\r\n\x1A\n") {
        "image/png"
    } else if bytes.starts_with(b"\xFF\xD8\xFF") {
        "image/jpeg"
    } else if is_webp {
        "image/webp"
    } else if is_gif {
        "image/gif"
    } else {
        // Unrecognized content is treated as PNG.
        "image/png"
    };
    detected.to_string()
}
/// Returns at most the first 800 characters (not bytes) of `raw_text`.
fn truncate_raw(raw_text: &str) -> String {
    const MAX_CHARS: usize = 800;
    // The byte offset of character number MAX_CHARS is where the kept prefix ends.
    match raw_text.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => raw_text[..cut].to_string(),
        None => raw_text.to_string(),
    }
}
/// Returns the current time as microseconds since the Unix epoch.
///
/// # Panics
/// Panics when the system clock is set before the Unix epoch or the microsecond count
/// does not fit in an `i64`.
fn current_utc_micros() -> i64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system time should be after unix epoch");
    let micros: u128 = since_epoch.as_micros();
    i64::try_from(micros).expect("current unix micros should fit in i64")
}
/// Wrapper around a successfully parsed JSON response body, produced by `parse_json_payload`.
struct ParsedJsonPayload {
/// The parsed JSON document.
payload: Value,
}
// Unit tests for the pure helpers of this module (no network access involved).
#[cfg(test)]
mod tests {
use super::*;
// The request body must carry the VectorEngine model name, the normalized size, the
// requested count, the reference image and a prompt that includes the negative part.
#[test]
fn gpt_image_2_request_uses_vector_engine_contract() {
let body = build_openai_image_request_body(
"雾海神殿",
Some("文字,水印"),
"1280*720",
2,
&["data:image/png;base64,abcd".to_string()],
);
assert_eq!(body["model"], VECTOR_ENGINE_GPT_IMAGE_2_MODEL);
// "1280*720" is a landscape alias and is normalized to 1536x1024.
assert_eq!(body["size"], "1536x1024");
assert_eq!(body["n"], 2);
// The body must not contain an `official_fallback` field.
assert!(body.get("official_fallback").is_none());
assert_eq!(body["image"][0], "data:image/png;base64,abcd");
assert!(body["prompt"].as_str().unwrap_or_default().contains("避免"));
}
// A base64 payload starting with the PNG signature decodes to one PNG image.
#[test]
fn b64_json_response_decodes_png_image() {
let images = images_from_base64(
"task-1".to_string(),
vec![BASE64_STANDARD.encode(b"\x89PNG\r\n\x1A\nrest")],
1,
);
assert_eq!(images.images.len(), 1);
assert_eq!(images.images[0].mime_type, "image/png");
assert_eq!(images.images[0].extension, "png");
}
}