Files
Genarrative/server-rs/crates/platform-image/src/lib.rs

1363 lines
44 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::{error::Error, fmt, time::Duration};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use reqwest::header;
use serde_json::{Map, Value, json};
/// Model identifier sent in the `model` field of image generation and edit requests.
pub const GPT_IMAGE_2_MODEL: &str = "gpt-image-2";
/// Alias of [`GPT_IMAGE_2_MODEL`]; this is the value recorded as `image_model` in failure audits.
pub const VECTOR_ENGINE_GPT_IMAGE_2_MODEL: &str = GPT_IMAGE_2_MODEL;
/// Provider label attached to every error, audit record and log event in this module.
pub const VECTOR_ENGINE_PROVIDER: &str = "vector-engine";
/// Connection settings for the VectorEngine image API.
#[derive(Clone, Debug)]
pub struct VectorEngineImageSettings {
/// API root; the URL builders accept it with or without a trailing `/v1`.
pub base_url: String,
/// Secret sent as `Authorization: Bearer <api_key>`.
pub api_key: String,
/// Whole-request timeout in milliseconds (the client builder raises 0 to 1).
pub request_timeout_ms: u64,
}
/// Result of one generation or edit call.
#[derive(Clone, Debug)]
pub struct GeneratedImages {
/// Identifier extracted from the response, or a locally generated `<prefix>-<micros>` fallback.
pub task_id: String,
/// Prompt reported back by the provider (`revised_prompt` or `actual_prompt`), if any.
pub actual_prompt: Option<String>,
/// Image payloads, either downloaded from returned URLs or decoded from base64.
pub images: Vec<DownloadedImage>,
}
/// Raw image bytes together with the MIME type and file extension derived for them.
#[derive(Clone, Debug)]
pub struct DownloadedImage {
/// Encoded image content.
pub bytes: Vec<u8>,
/// Normalized MIME type such as `image/png`.
pub mime_type: String,
/// File extension (without dot) matching `mime_type`.
pub extension: String,
}
/// Reference image uploaded as an `image` multipart part of an edit request.
#[derive(Clone, Debug)]
pub struct ReferenceImage {
/// Encoded image content.
pub bytes: Vec<u8>,
/// MIME type set on the multipart part.
pub mime_type: String,
/// File name set on the multipart part.
pub file_name: String,
}
/// Structured description of a failed provider call, attached to errors for auditing.
#[derive(Clone, Debug)]
pub struct PlatformImageFailureAudit {
/// Provider label (always [`VECTOR_ENGINE_PROVIDER`] when built by this module).
pub provider: &'static str,
/// URL of the request that failed.
pub endpoint: String,
/// Caller-supplied context string describing the operation.
pub operation: String,
/// Stage at which the failure occurred, e.g. `request_send`, `response_body`,
/// `upstream_status`, `response_parse`, `image_download` or `missing_image`.
pub failure_stage: &'static str,
/// HTTP status code, when one is known.
pub status_code: Option<u16>,
/// Optional coarse status class such as `5xx`.
pub status_class: Option<&'static str>,
/// Whether the failure was a timeout.
pub timeout: bool,
/// Whether the failure is considered retryable (timeout, connect error, 408, 429 or >= 500).
pub retryable: bool,
/// Human-readable error message.
pub error_message: String,
/// Text of the underlying error source, if any.
pub error_source: Option<String>,
/// Truncated excerpt of the raw response body, if any.
pub raw_excerpt: Option<String>,
/// Elapsed time in milliseconds for the failing stage.
pub latency_ms: Option<u64>,
/// Number of characters in the prompt.
pub prompt_chars: Option<usize>,
/// Number of reference images sent with the request.
pub reference_image_count: Option<usize>,
/// Model identifier used for the request.
pub image_model: Option<&'static str>,
}
/// Errors produced by the image platform client.
///
/// Every variant carries the provider label and a display message; the variants that
/// describe a provider interaction may also carry a [`PlatformImageFailureAudit`].
#[derive(Clone, Debug)]
pub enum PlatformImageError {
/// Local configuration is unusable (e.g. the HTTP client could not be built).
InvalidConfig {
provider: &'static str,
message: String,
},
/// The caller supplied an invalid request (e.g. missing or malformed reference image).
InvalidRequest {
provider: &'static str,
message: String,
},
/// The HTTP request failed before a usable response was obtained.
Request {
provider: &'static str,
message: String,
// URL that was being requested, when known.
endpoint: Option<String>,
// Transport-level classification flags of the failure.
timeout: bool,
connect: bool,
request: bool,
body: bool,
status_code: Option<u16>,
// Text of the underlying error source, if any.
source: Option<String>,
audit: Option<PlatformImageFailureAudit>,
},
/// The provider answered with a non-2xx status.
Upstream {
provider: &'static str,
message: String,
upstream_status: u16,
// Truncated raw response body.
raw_excerpt: String,
audit: Option<PlatformImageFailureAudit>,
},
/// The provider response could not be parsed as JSON.
ResponseParse {
provider: &'static str,
message: String,
// Truncated raw response body.
raw_excerpt: String,
audit: Option<PlatformImageFailureAudit>,
},
/// The provider response parsed but contained no usable image.
MissingImage {
provider: &'static str,
message: String,
audit: Option<PlatformImageFailureAudit>,
},
}
impl PlatformImageError {
/// Returns the provider label stored in the error.
pub fn provider(&self) -> &'static str {
match self {
Self::InvalidConfig { provider, .. }
| Self::InvalidRequest { provider, .. }
| Self::Request { provider, .. }
| Self::Upstream { provider, .. }
| Self::ResponseParse { provider, .. }
| Self::MissingImage { provider, .. } => provider,
}
}
/// Returns the human-readable message stored in the error.
pub fn message(&self) -> &str {
match self {
Self::InvalidConfig { message, .. }
| Self::InvalidRequest { message, .. }
| Self::Request { message, .. }
| Self::Upstream { message, .. }
| Self::ResponseParse { message, .. }
| Self::MissingImage { message, .. } => message,
}
}
/// Returns the attached failure audit, if any.
///
/// `InvalidConfig` and `InvalidRequest` never carry an audit.
pub fn audit(&self) -> Option<&PlatformImageFailureAudit> {
match self {
Self::Request { audit, .. }
| Self::Upstream { audit, .. }
| Self::ResponseParse { audit, .. }
| Self::MissingImage { audit, .. } => audit.as_ref(),
Self::InvalidConfig { .. } | Self::InvalidRequest { .. } => None,
}
}
/// Maps the error to a coarse HTTP-style status hint.
///
/// Timeouts (a `Request` with `timeout` set, or an `Upstream` whose message or raw
/// excerpt looks like a timeout) map to `GatewayTimeout`; every other provider-side
/// failure maps to `BadGateway`.
pub fn status_hint(&self) -> PlatformImageStatusHint {
match self {
Self::InvalidConfig { .. } => PlatformImageStatusHint::ServiceUnavailable,
Self::InvalidRequest { .. } => PlatformImageStatusHint::BadRequest,
Self::Request { timeout, .. } if *timeout => PlatformImageStatusHint::GatewayTimeout,
Self::Upstream { message, raw_excerpt, .. }
if is_timeout_message(message) || is_timeout_message(raw_excerpt) =>
{
PlatformImageStatusHint::GatewayTimeout
}
// The guarded arms above take precedence; everything left is a generic gateway failure.
Self::Request { .. }
| Self::Upstream { .. }
| Self::ResponseParse { .. }
| Self::MissingImage { .. } => PlatformImageStatusHint::BadGateway,
}
}
}
/// Displays the error as its stored message only (no provider or variant prefix).
impl fmt::Display for PlatformImageError {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(self.message())
}
}
// Marker impl: underlying error sources are stored as strings in the variants, so
// the default `source()` (returning `None`) is used.
impl Error for PlatformImageError {}
/// Coarse HTTP-style classification returned by [`PlatformImageError::status_hint`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PlatformImageStatusHint {
/// The caller's request was invalid.
BadRequest,
/// Local configuration prevents the service from working.
ServiceUnavailable,
/// The provider failed or returned an unusable response.
BadGateway,
/// The provider call timed out.
GatewayTimeout,
}
pub fn build_vector_engine_image_http_client(
settings: &VectorEngineImageSettings,
) -> Result<reqwest::Client, PlatformImageError> {
reqwest::Client::builder()
.timeout(Duration::from_millis(settings.request_timeout_ms.max(1)))
.http1_only()
.build()
.map_err(|error| PlatformImageError::InvalidConfig {
provider: VECTOR_ENGINE_PROVIDER,
message: format!("构造 VectorEngine 图片生成 HTTP 客户端失败:{error}"),
})
}
/// Generates images from a text prompt, switching to the edit endpoint when reference
/// images are supplied.
///
/// * With a non-empty `reference_images` (data URLs or HTTP(S) URLs) the references are
///   resolved and the call is delegated to
///   [`create_vector_engine_image_edit_with_references`].
/// * Otherwise a JSON body is POSTed to the generation endpoint with bearer
///   authentication and the response is handed to `handle_vector_engine_response`.
///
/// `failure_context` is a caller-supplied prefix used in error messages and logs.
///
/// # Errors
///
/// Returns a [`PlatformImageError`] when reference resolution, the HTTP request,
/// reading the body, or interpreting the response fails.
pub async fn create_vector_engine_image_generation(
http_client: &reqwest::Client,
settings: &VectorEngineImageSettings,
prompt: &str,
negative_prompt: Option<&str>,
size: &str,
candidate_count: u32,
reference_images: &[String],
failure_context: &str,
) -> Result<GeneratedImages, PlatformImageError> {
// Reference images turn the request into an edit (multipart) call.
if !reference_images.is_empty() {
let resolved_references =
resolve_reference_images(http_client, reference_images, failure_context).await?;
return create_vector_engine_image_edit_with_references(
http_client,
settings,
prompt,
negative_prompt,
size,
candidate_count,
resolved_references.as_slice(),
failure_context,
)
.await;
}
// From here on `reference_images` is known to be empty.
let request_url = vector_engine_images_generation_url(settings);
let normalized_size = normalize_image_size(size);
let request_body = build_vector_engine_image_request_body(
prompt,
negative_prompt,
normalized_size.as_str(),
candidate_count,
reference_images,
);
let started_at = std::time::Instant::now();
let response = match http_client
.post(request_url.as_str())
.header(
header::AUTHORIZATION,
format!("Bearer {}", settings.api_key),
)
.header(header::ACCEPT, "application/json")
.header(header::CONTENT_TYPE, "application/json")
.json(&request_body)
.send()
.await
{
Ok(response) => response,
Err(error) => {
// Sending failed: no HTTP response was obtained.
return Err(map_reqwest_error(
format!("{failure_context}:创建图片生成任务失败").as_str(),
request_url.as_str(),
"request_send",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_images.len()),
));
}
};
let response_status = response.status();
tracing::info!(
provider = VECTOR_ENGINE_PROVIDER,
endpoint = %request_url,
status = response_status.as_u16(),
prompt_chars = prompt.chars().count(),
size = %normalized_size,
reference_image_count = reference_images.len(),
elapsed_ms = started_at.elapsed().as_millis() as u64,
failure_context,
"VectorEngine 图片生成 HTTP 返回"
);
let response_text = match response.text().await {
Ok(response_text) => response_text,
Err(error) => {
// Headers arrived but the body could not be read.
return Err(map_reqwest_error(
format!("{failure_context}:读取图片生成响应失败").as_str(),
request_url.as_str(),
"response_body",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_images.len()),
));
}
};
// Status handling, JSON parsing and image extraction are shared with the edit path.
handle_vector_engine_response(
http_client,
request_url.as_str(),
response_status.as_u16(),
response_text.as_str(),
failure_context,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_images.len()),
candidate_count,
"vector-engine",
)
.await
}
/// Edits a single reference image according to `prompt`.
///
/// Convenience wrapper around [`create_vector_engine_image_edit_with_references`] that
/// requests exactly one candidate and passes `reference_image` as the only reference.
pub async fn create_vector_engine_image_edit(
http_client: &reqwest::Client,
settings: &VectorEngineImageSettings,
prompt: &str,
negative_prompt: Option<&str>,
size: &str,
reference_image: &ReferenceImage,
failure_context: &str,
) -> Result<GeneratedImages, PlatformImageError> {
create_vector_engine_image_edit_with_references(
http_client,
settings,
prompt,
negative_prompt,
size,
1,
// Borrow the single image as a one-element slice without cloning it.
std::slice::from_ref(reference_image),
failure_context,
)
.await
}
/// Sends an image edit request with one or more reference images.
///
/// The request is a multipart form carrying `model`, `prompt` (with the negative prompt
/// appended), `n` (clamped to 1..=4), `size` (normalized) and up to five `image` parts.
/// The response is handed to `handle_vector_engine_response`.
///
/// # Errors
///
/// Returns [`PlatformImageError::InvalidRequest`] when `reference_images` is empty or a
/// reference has an unparsable MIME type, and other [`PlatformImageError`] variants when
/// the HTTP call or response handling fails.
pub async fn create_vector_engine_image_edit_with_references(
http_client: &reqwest::Client,
settings: &VectorEngineImageSettings,
prompt: &str,
negative_prompt: Option<&str>,
size: &str,
candidate_count: u32,
reference_images: &[ReferenceImage],
failure_context: &str,
) -> Result<GeneratedImages, PlatformImageError> {
if reference_images.is_empty() {
return Err(PlatformImageError::InvalidRequest {
provider: VECTOR_ENGINE_PROVIDER,
message: format!("{failure_context}:缺少参考图,图片编辑需要至少一张参考图。"),
});
}
let request_url = vector_engine_images_edit_url(settings);
let normalized_size = normalize_image_size(size);
let mut form = reqwest::multipart::Form::new()
.text("model", GPT_IMAGE_2_MODEL.to_string())
.text("prompt", build_prompt_with_negative(prompt, negative_prompt))
.text("n", candidate_count.clamp(1, 4).to_string())
.text("size", normalized_size.clone());
// Only the first five references are uploaded; each becomes its own `image` part.
for reference_image in reference_images.iter().take(5) {
let image_part = reqwest::multipart::Part::bytes(reference_image.bytes.clone())
.file_name(reference_image.file_name.clone())
.mime_str(reference_image.mime_type.as_str())
.map_err(|error| PlatformImageError::InvalidRequest {
provider: VECTOR_ENGINE_PROVIDER,
message: format!("{failure_context}:构造参考图失败:{error}"),
})?;
form = form.part("image", image_part);
}
// Number of references actually attached (at most five), used for logs and audits.
let reference_image_count = reference_images.iter().take(5).count();
let started_at = std::time::Instant::now();
let response = match http_client
.post(request_url.as_str())
.header(
header::AUTHORIZATION,
format!("Bearer {}", settings.api_key),
)
.header(header::ACCEPT, "application/json")
.multipart(form)
.send()
.await
{
Ok(response) => response,
Err(error) => {
// Sending failed: no HTTP response was obtained.
return Err(map_reqwest_error(
format!("{failure_context}:创建图片编辑任务失败").as_str(),
request_url.as_str(),
"request_send",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_image_count),
));
}
};
let response_status = response.status();
tracing::info!(
provider = VECTOR_ENGINE_PROVIDER,
endpoint = %request_url,
status = response_status.as_u16(),
prompt_chars = prompt.chars().count(),
size = %normalized_size,
reference_image_count,
elapsed_ms = started_at.elapsed().as_millis() as u64,
failure_context,
"VectorEngine 图片编辑 HTTP 返回"
);
let response_text = match response.text().await {
Ok(response_text) => response_text,
Err(error) => {
// Headers arrived but the body could not be read.
return Err(map_reqwest_error(
format!("{failure_context}:读取图片编辑响应失败").as_str(),
request_url.as_str(),
"response_body",
error,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_image_count),
));
}
};
// Status handling, JSON parsing and image extraction are shared with the generation path.
handle_vector_engine_response(
http_client,
request_url.as_str(),
response_status.as_u16(),
response_text.as_str(),
failure_context,
started_at.elapsed().as_millis() as u64,
Some(prompt.chars().count()),
Some(reference_image_count),
candidate_count,
"vector-engine-edit",
)
.await
}
/// Interprets a raw VectorEngine HTTP response and produces the generated images.
///
/// Processing order:
/// 1. A non-2xx status becomes [`PlatformImageError::Upstream`].
/// 2. The body is parsed as JSON; failure becomes `ResponseParse`.
/// 3. HTTP(S) image URLs found in the payload are downloaded.
/// 4. Otherwise `b64_json` payloads are decoded.
/// 5. If neither path yields at least one image, `MissingImage` is returned.
///
/// Every error produced here carries a [`PlatformImageFailureAudit`].
/// `task_prefix` is used to build a fallback task id when the payload has none.
// The argument list mirrors the audit/log fields shared by the two public entry points.
#[allow(clippy::too_many_arguments)]
async fn handle_vector_engine_response(
    http_client: &reqwest::Client,
    request_url: &str,
    response_status: u16,
    response_text: &str,
    failure_context: &str,
    latency_ms: u64,
    prompt_chars: Option<usize>,
    reference_image_count: Option<usize>,
    candidate_count: u32,
    task_prefix: &str,
) -> Result<GeneratedImages, PlatformImageError> {
    if !(200..=299).contains(&response_status) {
        let message = parse_api_error_message(response_text, failure_context);
        let raw_excerpt = truncate_raw(response_text);
        let audit = build_failure_audit(
            request_url,
            failure_context,
            "upstream_status",
            Some(response_status),
            None,
            false,
            false,
            message.as_str(),
            None,
            Some(raw_excerpt.clone()),
            Some(latency_ms),
            prompt_chars,
            reference_image_count,
        );
        tracing::warn!(
            provider = VECTOR_ENGINE_PROVIDER,
            endpoint = %request_url,
            upstream_status = response_status,
            timeout = is_timeout_message(message.as_str()) || is_timeout_message(raw_excerpt.as_str()),
            retryable = audit.retryable,
            message = %message,
            raw_excerpt = %raw_excerpt,
            "VectorEngine 图片生成上游错误"
        );
        return Err(PlatformImageError::Upstream {
            provider: VECTOR_ENGINE_PROVIDER,
            message,
            upstream_status: response_status,
            raw_excerpt,
            audit: Some(audit),
        });
    }

    let response_json = match parse_json_payload(response_text, failure_context) {
        Ok(response_json) => response_json,
        Err(error) => {
            let audit = build_failure_audit(
                request_url,
                failure_context,
                "response_parse",
                Some(response_status),
                None,
                false,
                false,
                error.message(),
                None,
                Some(truncate_raw(response_text)),
                Some(latency_ms),
                prompt_chars,
                reference_image_count,
            );
            tracing::warn!(
                provider = VECTOR_ENGINE_PROVIDER,
                endpoint = %request_url,
                status = response_status,
                raw_excerpt = %truncate_raw(response_text),
                message = %error.message(),
                "VectorEngine 图片响应解析失败"
            );
            return Err(error.with_audit(audit));
        }
    };

    // Fall back to a locally generated id when the payload carries none.
    let task_id = extract_generation_id(&response_json.payload)
        .unwrap_or_else(|| format!("{task_prefix}-{}", current_utc_micros()));
    let actual_prompt = find_first_string_by_key(&response_json.payload, "revised_prompt")
        .or_else(|| find_first_string_by_key(&response_json.payload, "actual_prompt"));

    // Preferred path: the provider returned downloadable URLs.
    let image_urls = extract_image_urls(&response_json.payload);
    if !image_urls.is_empty() {
        let download_started_at = std::time::Instant::now();
        let mut generated = match download_images_from_urls(
            http_client,
            task_id,
            image_urls,
            candidate_count,
        )
        .await
        {
            Ok(generated) => generated,
            Err(error) => {
                let audit = build_failure_audit(
                    request_url,
                    failure_context,
                    "image_download",
                    Some(response_status),
                    Some("5xx"),
                    false,
                    false,
                    error.message(),
                    None,
                    None,
                    Some(download_started_at.elapsed().as_millis() as u64),
                    prompt_chars,
                    reference_image_count,
                );
                return Err(error.with_audit(audit));
            }
        };
        generated.actual_prompt = actual_prompt;
        tracing::info!(
            provider = VECTOR_ENGINE_PROVIDER,
            endpoint = %request_url,
            image_count = generated.images.len(),
            elapsed_ms = download_started_at.elapsed().as_millis() as u64,
            failure_context,
            "VectorEngine 图片下载完成"
        );
        return Ok(generated);
    }

    // Fallback path: inline base64 payloads.
    let b64_images = extract_b64_images(&response_json.payload);
    if !b64_images.is_empty() {
        let mut generated = images_from_base64(task_id, b64_images, candidate_count);
        // Undecodable entries are dropped by `images_from_base64`; only report success
        // when at least one image survived, otherwise fall through to `MissingImage`
        // instead of returning an empty "successful" result.
        if !generated.images.is_empty() {
            generated.actual_prompt = actual_prompt;
            tracing::info!(
                provider = VECTOR_ENGINE_PROVIDER,
                endpoint = %request_url,
                image_count = generated.images.len(),
                failure_context,
                "VectorEngine 图片 base64 解码完成"
            );
            return Ok(generated);
        }
        tracing::warn!(
            provider = VECTOR_ENGINE_PROVIDER,
            endpoint = %request_url,
            status = response_status,
            failure_context,
            "VectorEngine 图片 base64 解码失败"
        );
    }

    // Same `context:detail` separator as the other messages in this module.
    let message = format!("{failure_context}:VectorEngine 未返回图片地址");
    let audit = build_failure_audit(
        request_url,
        failure_context,
        "missing_image",
        Some(response_status),
        None,
        false,
        false,
        message.as_str(),
        None,
        Some(truncate_raw(response_text)),
        Some(latency_ms),
        prompt_chars,
        reference_image_count,
    );
    tracing::warn!(
        provider = VECTOR_ENGINE_PROVIDER,
        endpoint = %request_url,
        status = response_status,
        raw_excerpt = %truncate_raw(response_text),
        "VectorEngine 图片响应未返回图片"
    );
    Err(PlatformImageError::MissingImage {
        provider: VECTOR_ENGINE_PROVIDER,
        message,
        audit: Some(audit),
    })
}
/// Builds the JSON body for the image generation endpoint.
///
/// The body contains `model`, `prompt` (trimmed, with the negative prompt appended),
/// `n` (clamped to 1..=4) and `size` (normalized). `_reference_images` is accepted for
/// signature compatibility and is not written into the body.
pub fn build_vector_engine_image_request_body(
    prompt: &str,
    negative_prompt: Option<&str>,
    size: &str,
    candidate_count: u32,
    _reference_images: &[String],
) -> Value {
    let full_prompt = build_prompt_with_negative(prompt, negative_prompt);
    let image_count = candidate_count.clamp(1, 4);
    let image_size = normalize_image_size(size);

    let mut fields = Map::new();
    fields.insert(
        "model".to_string(),
        Value::String(GPT_IMAGE_2_MODEL.to_string()),
    );
    fields.insert("prompt".to_string(), Value::String(full_prompt));
    fields.insert("n".to_string(), json!(image_count));
    fields.insert("size".to_string(), Value::String(image_size));
    Value::Object(fields)
}
/// Maps a requested size or aspect-ratio alias to a size string for the provider.
///
/// The input is trimmed, ASCII-lowercased and `*` separators are rewritten to `x`
/// before matching, so `1280*720`, `1280x720` and `1280X720` are all treated alike.
///
/// * square aliases (`1024x1024`, `1:1`) -> `1024x1024`
/// * landscape aliases (`1280x720`, `1600x900`, `16:9`, `1536x1024`, `2048x1152`, `2k`)
///   -> `1536x1024`
/// * portrait aliases (`1024x1536`, `9:16`) -> `1024x1536`
/// * empty input -> `1024x1024`
/// * anything else is passed through in its canonical (lowercased, `x`-separated) form.
pub fn normalize_image_size(size: &str) -> String {
    let canonical: String = size
        .trim()
        .chars()
        .map(|character| match character {
            '*' => 'x',
            other => other.to_ascii_lowercase(),
        })
        .collect();

    // `preset` only ever holds string literals, so it does not borrow `canonical`.
    let preset = match canonical.as_str() {
        "" | "1024x1024" | "1:1" => Some("1024x1024"),
        "1280x720" | "1600x900" | "16:9" | "1536x1024" | "2048x1152" | "2k" => Some("1536x1024"),
        "1024x1536" | "9:16" => Some("1024x1536"),
        _ => None,
    };

    match preset {
        Some(value) => value.to_string(),
        None => canonical,
    }
}
/// Returns the image generation endpoint URL for the configured base URL.
///
/// Surrounding whitespace and trailing slashes on `base_url` are ignored, and the
/// `/v1` segment is added only when the base URL does not already end with it, so
/// `https://host`, `https://host/`, `https://host/v1` and `https://host/v1/` all
/// produce `https://host/v1/images/generations`.
pub fn vector_engine_images_generation_url(settings: &VectorEngineImageSettings) -> String {
    // Without this, a trailing slash would yield `.../v1//v1/images/generations`.
    let base_url = settings.base_url.trim().trim_end_matches('/');
    if base_url.ends_with("/v1") {
        format!("{base_url}/images/generations")
    } else {
        format!("{base_url}/v1/images/generations")
    }
}
/// Returns the image edit endpoint URL for the configured base URL.
///
/// Surrounding whitespace and trailing slashes on `base_url` are ignored, and the
/// `/v1` segment is added only when the base URL does not already end with it, so
/// `https://host`, `https://host/`, `https://host/v1` and `https://host/v1/` all
/// produce `https://host/v1/images/edits`.
pub fn vector_engine_images_edit_url(settings: &VectorEngineImageSettings) -> String {
    // Without this, a trailing slash would yield `.../v1//v1/images/edits`.
    let base_url = settings.base_url.trim().trim_end_matches('/');
    if base_url.ends_with("/v1") {
        format!("{base_url}/images/edits")
    } else {
        format!("{base_url}/v1/images/edits")
    }
}
pub async fn download_remote_image(
http_client: &reqwest::Client,
image_url: &str,
) -> Result<DownloadedImage, PlatformImageError> {
let response = http_client.get(image_url).send().await.map_err(|error| {
map_simple_request_error(format!("下载生成图片失败:{error}"), Some(image_url.to_string()))
})?;
let status = response.status();
let content_type = response
.headers()
.get(header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.unwrap_or("image/jpeg")
.to_string();
let body = response.bytes().await.map_err(|error| {
map_simple_request_error(format!("读取生成图片内容失败:{error}"), Some(image_url.to_string()))
})?;
if !status.is_success() {
return Err(PlatformImageError::Request {
provider: VECTOR_ENGINE_PROVIDER,
message: "下载生成图片失败".to_string(),
endpoint: Some(image_url.to_string()),
timeout: false,
connect: false,
request: false,
body: false,
status_code: Some(status.as_u16()),
source: None,
audit: None,
});
}
let normalized_mime_type = normalize_downloaded_image_mime_type(content_type.as_str());
Ok(DownloadedImage {
extension: mime_to_extension(normalized_mime_type.as_str()).to_string(),
mime_type: normalized_mime_type,
bytes: body.to_vec(),
})
}
/// Downloads the first `candidate_count` (clamped to 1..=4) images, one after another.
///
/// The first failed download aborts the whole operation. `actual_prompt` is left as
/// `None` for the caller to fill in.
async fn download_images_from_urls(
    http_client: &reqwest::Client,
    task_id: String,
    image_urls: Vec<String>,
    candidate_count: u32,
) -> Result<GeneratedImages, PlatformImageError> {
    let limit = candidate_count.clamp(1, 4) as usize;
    let mut images = Vec::with_capacity(limit);
    for url in image_urls.iter().take(limit) {
        let image = download_remote_image(http_client, url.as_str()).await?;
        images.push(image);
    }
    Ok(GeneratedImages {
        task_id,
        actual_prompt: None,
        images,
    })
}
/// Resolves up to five reference image sources into uploadable [`ReferenceImage`]s.
///
/// Each source is trimmed; blank entries are skipped. `data:` URLs are decoded in
/// place, `http://` / `https://` URLs are downloaded, and anything else is rejected.
/// File names are `reference-<index>.<ext>`, where `<index>` is the position within
/// the first five entries (blank entries included).
///
/// # Errors
///
/// * [`PlatformImageError::InvalidRequest`] for a malformed data URL, an unsupported
///   source, or when no usable reference remains.
/// * [`PlatformImageError::Request`] when a download fails; the timeout, connect,
///   request, body, status and source details of the underlying failure are kept.
async fn resolve_reference_images(
    http_client: &reqwest::Client,
    reference_images: &[String],
    failure_context: &str,
) -> Result<Vec<ReferenceImage>, PlatformImageError> {
    let mut resolved = Vec::with_capacity(reference_images.len().min(5));
    for (index, source) in reference_images.iter().take(5).enumerate() {
        let source = source.trim();
        if source.is_empty() {
            continue;
        }
        if let Some(reference_image) = parse_reference_image_data_url(source, index)? {
            resolved.push(reference_image);
            continue;
        }
        if !(source.starts_with("http://") || source.starts_with("https://")) {
            return Err(PlatformImageError::InvalidRequest {
                provider: VECTOR_ENGINE_PROVIDER,
                message: format!("{failure_context}:参考图必须是图片 Data URL 或 HTTP(S) URL。"),
            });
        }

        let downloaded = download_remote_image(http_client, source)
            .await
            .map_err(|error| {
                let message = format!("{failure_context}:下载参考图失败:{error}");
                let endpoint = Some(source.to_string());
                match error {
                    // Keep the transport classification so that a timed-out reference
                    // download is still reported as a timeout to the caller.
                    PlatformImageError::Request {
                        timeout,
                        connect,
                        request,
                        body,
                        status_code,
                        source: error_source,
                        ..
                    } => PlatformImageError::Request {
                        provider: VECTOR_ENGINE_PROVIDER,
                        message,
                        endpoint,
                        timeout,
                        connect,
                        request,
                        body,
                        status_code,
                        source: error_source,
                        audit: None,
                    },
                    PlatformImageError::InvalidConfig { .. }
                    | PlatformImageError::InvalidRequest { .. }
                    | PlatformImageError::Upstream { .. }
                    | PlatformImageError::ResponseParse { .. }
                    | PlatformImageError::MissingImage { .. } => PlatformImageError::Request {
                        provider: VECTOR_ENGINE_PROVIDER,
                        message,
                        endpoint,
                        timeout: false,
                        connect: false,
                        request: false,
                        body: false,
                        status_code: None,
                        source: None,
                        audit: None,
                    },
                }
            })?;

        let file_name = format!(
            "reference-{index}.{}",
            mime_to_extension(downloaded.mime_type.as_str())
        );
        resolved.push(ReferenceImage {
            bytes: downloaded.bytes,
            mime_type: downloaded.mime_type,
            file_name,
        });
    }

    if resolved.is_empty() {
        return Err(PlatformImageError::InvalidRequest {
            provider: VECTOR_ENGINE_PROVIDER,
            message: format!("{failure_context}:图片编辑需要至少一张参考图。"),
        });
    }
    Ok(resolved)
}
/// Parses a `data:image/...;base64,...` URL into a [`ReferenceImage`].
///
/// Returns `Ok(None)` when `source` is not a `data:` URL at all, so the caller can try
/// other source kinds. The file name is `reference-<index>.<ext>`.
///
/// # Errors
///
/// Returns [`PlatformImageError::InvalidRequest`] when the URL has no `;base64,`
/// marker, its media type does not start with `image/`, or the payload is not valid
/// base64.
fn parse_reference_image_data_url(
    source: &str,
    index: usize,
) -> Result<Option<ReferenceImage>, PlatformImageError> {
    let invalid = |message: String| PlatformImageError::InvalidRequest {
        provider: VECTOR_ENGINE_PROVIDER,
        message,
    };

    let payload = match source.strip_prefix("data:") {
        Some(payload) => payload,
        None => return Ok(None),
    };
    let (declared_mime, encoded) = match payload.split_once(";base64,") {
        Some(parts) => parts,
        None => return Err(invalid("参考图 Data URL 必须是 base64 图片。".to_string())),
    };
    if !declared_mime.starts_with("image/") {
        return Err(invalid("参考图 Data URL 必须是图片类型。".to_string()));
    }

    let bytes = match BASE64_STANDARD.decode(encoded.trim()) {
        Ok(bytes) => bytes,
        Err(error) => return Err(invalid(format!("参考图 Data URL 解码失败:{error}"))),
    };

    let mime_type = normalize_downloaded_image_mime_type(declared_mime);
    let file_name = format!(
        "reference-{index}.{}",
        mime_to_extension(mime_type.as_str())
    );
    Ok(Some(ReferenceImage {
        bytes,
        mime_type,
        file_name,
    }))
}
/// Decodes base64 image payloads into [`GeneratedImages`].
///
/// Entries that fail to decode are skipped, and at most `candidate_count` (clamped to
/// 1..=4) successfully decoded images are kept. Decoding happens before the limit is
/// applied, so an undecodable early entry does not reduce the number of images when
/// later entries are valid. The iterator is lazy, so no entry beyond the limit is
/// decoded. `actual_prompt` is left as `None` for the caller to fill in.
fn images_from_base64(
    task_id: String,
    b64_images: Vec<String>,
    candidate_count: u32,
) -> GeneratedImages {
    let limit = candidate_count.clamp(1, 4) as usize;
    let images = b64_images
        .into_iter()
        .filter_map(|raw| decode_generated_image_base64(raw.as_str()))
        .take(limit)
        .collect();
    GeneratedImages {
        task_id,
        actual_prompt: None,
        images,
    }
}
/// Decodes one standard-base64 image payload.
///
/// Returns `None` when the (trimmed) text is not valid base64. The MIME type is
/// inferred from the decoded bytes and the extension is derived from it.
fn decode_generated_image_base64(raw: &str) -> Option<DownloadedImage> {
    match BASE64_STANDARD.decode(raw.trim()) {
        Ok(bytes) => {
            let mime_type = infer_image_mime_type(&bytes);
            let extension = mime_to_extension(&mime_type).to_string();
            Some(DownloadedImage {
                bytes,
                mime_type,
                extension,
            })
        }
        Err(_) => None,
    }
}
fn parse_json_payload(
raw_text: &str,
failure_context: &str,
) -> Result<ParsedJsonPayload, PlatformImageError> {
serde_json::from_str::<Value>(raw_text)
.map(|payload| ParsedJsonPayload { payload })
.map_err(|error| PlatformImageError::ResponseParse {
provider: VECTOR_ENGINE_PROVIDER,
message: format!("{failure_context}:解析响应失败:{error}"),
raw_excerpt: truncate_raw(raw_text),
audit: None,
})
}
/// Converts a `reqwest` transport error into [`PlatformImageError::Request`].
///
/// The resulting message is `"<context>:<error>"`. The error's timeout / connect /
/// request / body classification, optional status and source text are copied into the
/// error, a failure audit is attached, and a warning is logged.
///
/// `failure_stage` names the step that failed (e.g. `request_send`, `response_body`).
fn map_reqwest_error(
    context: &str,
    request_url: &str,
    failure_stage: &'static str,
    error: reqwest::Error,
    latency_ms: u64,
    prompt_chars: Option<usize>,
    reference_image_count: Option<usize>,
) -> PlatformImageError {
    // Read every classification once so the audit, the log and the error always agree.
    let is_timeout = error.is_timeout();
    let is_connect = error.is_connect();
    let is_request = error.is_request();
    let is_body = error.is_body();
    let status_code = error.status().map(|status| status.as_u16());
    let source = error.source().map(ToString::to_string);
    // Same `context:detail` separator as the other messages in this module.
    let message = format!("{context}:{error}");

    let audit = build_failure_audit(
        request_url,
        context,
        failure_stage,
        status_code,
        None,
        is_timeout,
        is_connect,
        message.as_str(),
        source.clone(),
        None,
        Some(latency_ms),
        prompt_chars,
        reference_image_count,
    );
    tracing::warn!(
        provider = VECTOR_ENGINE_PROVIDER,
        endpoint = %request_url,
        failure_stage,
        timeout = is_timeout,
        connect = is_connect,
        request = is_request,
        body = is_body,
        status = status_code.unwrap_or_default(),
        source = %source.as_deref().unwrap_or_default(),
        message = %message,
        elapsed_ms = latency_ms,
        prompt_chars,
        reference_image_count,
        "VectorEngine 图片请求发送失败"
    );
    PlatformImageError::Request {
        provider: VECTOR_ENGINE_PROVIDER,
        message,
        endpoint: Some(request_url.to_string()),
        timeout: is_timeout,
        connect: is_connect,
        request: is_request,
        body: is_body,
        status_code,
        source,
        audit: Some(audit),
    }
}
/// Builds a bare [`PlatformImageError::Request`] from a message and optional endpoint.
///
/// Only the `request` flag is set; timeout/connect/body flags, status, source and
/// audit are left empty.
fn map_simple_request_error(message: String, endpoint: Option<String>) -> PlatformImageError {
PlatformImageError::Request {
provider: VECTOR_ENGINE_PROVIDER,
message,
endpoint,
timeout: false,
connect: false,
request: true,
body: false,
status_code: None,
source: None,
audit: None,
}
}
/// Assembles a [`PlatformImageFailureAudit`] from the individual failure details.
///
/// `retryable` is derived from `status_code`, `timeout` and `connect`; `connect`
/// itself is not stored. The provider and image model are always the module constants.
// The flat argument list mirrors the audit fields one-to-one.
#[allow(clippy::too_many_arguments)]
fn build_failure_audit(
request_url: &str,
operation: &str,
failure_stage: &'static str,
status_code: Option<u16>,
status_class: Option<&'static str>,
timeout: bool,
connect: bool,
error_message: &str,
error_source: Option<String>,
raw_excerpt: Option<String>,
latency_ms: Option<u64>,
prompt_chars: Option<usize>,
reference_image_count: Option<usize>,
) -> PlatformImageFailureAudit {
PlatformImageFailureAudit {
provider: VECTOR_ENGINE_PROVIDER,
endpoint: request_url.to_string(),
operation: operation.to_string(),
failure_stage,
status_code,
status_class,
timeout,
retryable: is_retryable_external_api_failure(status_code, timeout, connect),
error_message: error_message.to_string(),
error_source,
raw_excerpt,
latency_ms,
prompt_chars,
reference_image_count,
image_model: Some(VECTOR_ENGINE_GPT_IMAGE_2_MODEL),
}
}
/// Decides whether a failed provider call is worth retrying.
///
/// Timeouts and connection failures are always retryable; otherwise only the HTTP
/// statuses 408, 429 and anything from 500 upwards are.
fn is_retryable_external_api_failure(
    status_code: Option<u16>,
    timeout: bool,
    connect: bool,
) -> bool {
    if timeout || connect {
        return true;
    }
    match status_code {
        Some(408) | Some(429) => true,
        Some(status) => status >= 500,
        None => false,
    }
}
/// Combines the prompt with an optional negative prompt.
///
/// Both parts are trimmed. When the negative prompt is absent or blank only the
/// trimmed prompt is returned; otherwise the negative prompt is appended on a new
/// line behind the `避免:` marker.
fn build_prompt_with_negative(prompt: &str, negative_prompt: Option<&str>) -> String {
    let positive = prompt.trim();
    match negative_prompt.map(str::trim) {
        Some(negative) if !negative.is_empty() => format!("{positive}\n避免:{negative}"),
        _ => positive.to_string(),
    }
}
/// Extracts a human-readable error message from an upstream error body.
///
/// Resolution order:
/// 1. blank body -> `fallback_message`;
/// 2. JSON body with a non-blank string at `/error/message`, `/message`,
///    `/output/message` or `/data/message` -> that message;
/// 3. JSON body with a non-blank string at `/error/code`, `/code`, `/output/code` or
///    `/data/code` -> `"<fallback_message>:<code>"`;
/// 4. anything else -> the trimmed raw body.
fn parse_api_error_message(raw_text: &str, fallback_message: &str) -> String {
    const MESSAGE_POINTERS: [&str; 4] = [
        "/error/message",
        "/message",
        "/output/message",
        "/data/message",
    ];
    const CODE_POINTERS: [&str; 4] = ["/error/code", "/code", "/output/code", "/data/code"];

    let trimmed = raw_text.trim();
    if trimmed.is_empty() {
        return fallback_message.to_string();
    }
    if let Ok(parsed) = serde_json::from_str::<Value>(raw_text) {
        if let Some(message) = first_non_blank_pointer_text(&parsed, &MESSAGE_POINTERS) {
            return message.to_string();
        }
        if let Some(code) = first_non_blank_pointer_text(&parsed, &CODE_POINTERS) {
            // Same `context:detail` separator as the other messages in this module.
            return format!("{fallback_message}:{code}");
        }
    }
    trimmed.to_string()
}

/// Returns the first non-blank, trimmed string found at any of the JSON pointers.
fn first_non_blank_pointer_text<'a>(parsed: &'a Value, pointers: &[&str]) -> Option<&'a str> {
    pointers.iter().find_map(|pointer| {
        parsed
            .pointer(pointer)
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|text| !text.is_empty())
    })
}
/// Recursively collects every non-blank string stored under `target_key`.
///
/// Arrays and objects are walked depth-first. When an object field named `target_key`
/// holds a string, its trimmed value is collected; when it holds an array, each
/// non-blank string element is collected. Nested containers are then searched as well.
/// Results are appended to `results` in traversal order.
fn collect_strings_by_key(value: &Value, target_key: &str, results: &mut Vec<String>) {
    match value {
        Value::Array(items) => {
            for item in items {
                collect_strings_by_key(item, target_key, results);
            }
        }
        Value::Object(fields) => {
            for (key, child) in fields {
                if key == target_key {
                    match child {
                        Value::String(text) => {
                            let trimmed = text.trim();
                            if !trimmed.is_empty() {
                                results.push(trimmed.to_string());
                            }
                        }
                        Value::Array(items) => {
                            let texts = items
                                .iter()
                                .filter_map(Value::as_str)
                                .map(str::trim)
                                .filter(|text| !text.is_empty())
                                .map(str::to_string);
                            results.extend(texts);
                        }
                        _ => {}
                    }
                }
                // Descending into a scalar is a no-op, so this is safe for every child.
                collect_strings_by_key(child, target_key, results);
            }
        }
        _ => {}
    }
}
/// Returns the first non-blank string stored under `target_key` anywhere in `value`,
/// in the traversal order of `collect_strings_by_key`.
fn find_first_string_by_key(value: &Value, target_key: &str) -> Option<String> {
    let mut matches = Vec::new();
    collect_strings_by_key(value, target_key, &mut matches);
    if matches.is_empty() {
        None
    } else {
        Some(matches.swap_remove(0))
    }
}
/// Looks for a task identifier in the payload, trying the string keys `id`, `created`
/// and `request_id` in that order.
fn extract_generation_id(payload: &Value) -> Option<String> {
    const ID_KEYS: [&str; 3] = ["id", "created", "request_id"];
    ID_KEYS
        .iter()
        .find_map(|key| find_first_string_by_key(payload, key))
}
/// Collects the distinct HTTP(S) image URLs found under the keys `url`, `image` and
/// `image_url`.
///
/// All `url` matches come first, then `image`, then `image_url`; values that do not
/// start with `http://` or `https://` are dropped and duplicates keep their first
/// position.
fn extract_image_urls(payload: &Value) -> Vec<String> {
    let mut candidates = Vec::new();
    for key in ["url", "image", "image_url"] {
        collect_strings_by_key(payload, key, &mut candidates);
    }

    let mut unique: Vec<String> = Vec::new();
    for candidate in candidates {
        let is_http = candidate.starts_with("http://") || candidate.starts_with("https://");
        if is_http && !unique.contains(&candidate) {
            unique.push(candidate);
        }
    }
    unique
}
/// Collects every non-blank `b64_json` string in the payload, in traversal order.
/// Duplicates are kept.
fn extract_b64_images(payload: &Value) -> Vec<String> {
    let mut encoded_images = Vec::new();
    collect_strings_by_key(payload, "b64_json", &mut encoded_images);
    encoded_images
}
/// Reduces a `Content-Type` value to one of the supported image MIME types.
///
/// Parameters after `;` are dropped and the media type is compared
/// case-insensitively (media type names are not case-sensitive). Supported types
/// (`image/png`, `image/webp`, `image/jpeg`, `image/jpg`, `image/gif`) are returned
/// in lowercase; anything else, including an empty value, becomes `image/jpeg`.
fn normalize_downloaded_image_mime_type(content_type: &str) -> String {
    let essence = content_type
        .split(';')
        .next()
        .unwrap_or_default()
        .trim()
        .to_ascii_lowercase();
    let supported = matches!(
        essence.as_str(),
        "image/png" | "image/webp" | "image/jpeg" | "image/jpg" | "image/gif"
    );
    if supported {
        essence
    } else {
        "image/jpeg".to_string()
    }
}
/// Maps a MIME type to a file extension (without the dot).
///
/// `image/png`, `image/webp` and `image/gif` map to their own extensions; every other
/// value maps to `jpg`.
fn mime_to_extension(mime_type: &str) -> &str {
    if mime_type == "image/png" {
        "png"
    } else if mime_type == "image/webp" {
        "webp"
    } else if mime_type == "image/gif" {
        "gif"
    } else {
        "jpg"
    }
}
/// Guesses the MIME type of encoded image bytes from their leading signature.
///
/// Recognizes PNG, JPEG, WebP (`RIFF` container with `WEBP` at offset 8) and GIF
/// (`GIF87a` / `GIF89a`). Unrecognized or too-short data is reported as `image/png`.
fn infer_image_mime_type(bytes: &[u8]) -> String {
    const PNG_SIGNATURE: &[u8] = b"\x89PNG\r\n\x1A\n";
    const JPEG_SIGNATURE: &[u8] = b"\xFF\xD8\xFF";

    let is_webp = bytes.starts_with(b"RIFF") && bytes.get(8..12) == Some(&b"WEBP"[..]);
    let is_gif = bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a");

    let detected = if bytes.starts_with(PNG_SIGNATURE) {
        "image/png"
    } else if bytes.starts_with(JPEG_SIGNATURE) {
        "image/jpeg"
    } else if is_webp {
        "image/webp"
    } else if is_gif {
        "image/gif"
    } else {
        "image/png"
    };
    detected.to_string()
}
/// Reports whether an error text looks like a timeout.
///
/// Matching is ASCII-case-insensitive and succeeds when the text contains `timed out`,
/// `timeout` or `deadline has elapsed`.
fn is_timeout_message(message: &str) -> bool {
    // "operation timed out" needs no entry of its own: it always contains "timed out".
    const TIMEOUT_MARKERS: [&str; 3] = ["timed out", "timeout", "deadline has elapsed"];
    let lowered = message.to_ascii_lowercase();
    TIMEOUT_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
}
/// Returns at most the first 800 characters (Unicode scalar values) of `raw_text`.
fn truncate_raw(raw_text: &str) -> String {
    const MAX_CHARS: usize = 800;
    // The byte offset of character number 800 (zero-based) is where the kept prefix
    // ends; cutting there always lands on a character boundary.
    match raw_text.char_indices().nth(MAX_CHARS) {
        Some((cutoff, _)) => raw_text[..cutoff].to_string(),
        None => raw_text.to_string(),
    }
}
/// Returns the current wall-clock time as microseconds since the Unix epoch.
///
/// # Panics
///
/// Panics when the system clock is set before the Unix epoch or when the microsecond
/// count does not fit in an `i64`.
fn current_utc_micros() -> i64 {
    let now = std::time::SystemTime::now();
    let since_epoch = now
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system time should be after unix epoch");
    let micros = since_epoch.as_micros();
    i64::try_from(micros).expect("current unix micros should fit in i64")
}
impl PlatformImageError {
    /// Returns the error with its audit replaced by `audit`.
    ///
    /// `InvalidConfig` and `InvalidRequest` have no audit slot and are returned
    /// unchanged (the given audit is dropped).
    fn with_audit(mut self, audit: PlatformImageFailureAudit) -> Self {
        match &mut self {
            Self::Request { audit: slot, .. }
            | Self::Upstream { audit: slot, .. }
            | Self::ResponseParse { audit: slot, .. }
            | Self::MissingImage { audit: slot, .. } => {
                *slot = Some(audit);
            }
            Self::InvalidConfig { .. } | Self::InvalidRequest { .. } => {}
        }
        self
    }
}
/// Wrapper around a successfully parsed JSON response body.
struct ParsedJsonPayload {
/// The parsed JSON document.
payload: Value,
}
// Unit tests for the pure helpers of this module; none of them performs network I/O.
#[cfg(test)]
mod tests {
use super::*;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use serde_json::json;
// The request body trims the prompt, appends the negative prompt, normalizes the
// size, clamps `n` to 4 and never embeds reference images.
#[test]
fn request_body_normalizes_size_prompt_and_candidate_count() {
let body = build_vector_engine_image_request_body(
" 风雨夜里的街道 ",
Some(" 低清,水印 "),
" 1:1 ",
10,
&["data:image/png;base64,AAAA".to_string()],
);
assert_eq!(body["model"], GPT_IMAGE_2_MODEL);
assert_eq!(body["size"], "1024x1024");
assert_eq!(body["n"], 4);
assert_eq!(body["prompt"], "风雨夜里的街道\n避免:低清,水印");
assert!(body.get("image").is_none());
}
// Both a root base URL and one already ending in `/v1` produce the same endpoints.
#[test]
fn provider_urls_normalize_root_and_v1_base_urls() {
let root_settings = VectorEngineImageSettings {
base_url: "https://vector.example".to_string(),
api_key: "test-key".to_string(),
request_timeout_ms: 1_000,
};
let v1_settings = VectorEngineImageSettings {
base_url: "https://vector.example/v1".to_string(),
api_key: "test-key".to_string(),
request_timeout_ms: 1_000,
};
assert_eq!(
vector_engine_images_generation_url(&root_settings),
"https://vector.example/v1/images/generations"
);
assert_eq!(
vector_engine_images_generation_url(&v1_settings),
"https://vector.example/v1/images/generations"
);
assert_eq!(
vector_engine_images_edit_url(&root_settings),
"https://vector.example/v1/images/edits"
);
assert_eq!(
vector_engine_images_edit_url(&v1_settings),
"https://vector.example/v1/images/edits"
);
}
// A PNG data URL and a raw base64 PNG both keep their bytes, MIME type and extension.
#[test]
fn data_url_and_base64_image_decoding_preserves_image_metadata() {
let data_url = format!(
"data:image/png;base64,{}",
BASE64_STANDARD.encode(b"\x89PNG\r\n\x1A\nrest")
);
let reference = parse_reference_image_data_url(&data_url, 2)
.expect("data url should parse")
.expect("image data url should be accepted");
assert_eq!(reference.file_name, "reference-2.png");
assert_eq!(reference.mime_type, "image/png");
assert_eq!(reference.bytes, b"\x89PNG\r\n\x1A\nrest");
let image = decode_generated_image_base64(BASE64_STANDARD.encode(b"\x89PNG\r\n\x1A\nrest").as_str())
.expect("base64 image should decode");
assert_eq!(image.extension, "png");
assert_eq!(image.mime_type, "image/png");
assert_eq!(image.bytes, b"\x89PNG\r\n\x1A\nrest");
}
// Each error variant maps to the expected status hint, and an attached audit is
// returned unchanged by `audit()`.
#[test]
fn error_status_hints_and_audit_fields_are_structured() {
let audit = PlatformImageFailureAudit {
provider: VECTOR_ENGINE_PROVIDER,
endpoint: "https://vector.example/v1/images/generations".to_string(),
operation: "图片生成失败".to_string(),
failure_stage: "upstream_status",
status_code: Some(504),
status_class: Some("5xx"),
timeout: true,
retryable: true,
error_message: "上游超时".to_string(),
error_source: Some("read timeout".to_string()),
raw_excerpt: Some("{\"error\":\"timeout\"}".to_string()),
latency_ms: Some(987),
prompt_chars: Some(64),
reference_image_count: Some(2),
image_model: Some(VECTOR_ENGINE_GPT_IMAGE_2_MODEL),
};
let request_error = PlatformImageError::Request {
provider: VECTOR_ENGINE_PROVIDER,
message: "请求发送失败".to_string(),
endpoint: Some("https://vector.example/v1/images/generations".to_string()),
timeout: true,
connect: false,
request: true,
body: false,
status_code: None,
source: None,
audit: None,
};
let invalid_config = PlatformImageError::InvalidConfig {
provider: VECTOR_ENGINE_PROVIDER,
message: "缺少配置".to_string(),
};
let invalid_request = PlatformImageError::InvalidRequest {
provider: VECTOR_ENGINE_PROVIDER,
message: "请求不合法".to_string(),
};
// Classified as a timeout purely through its message / raw excerpt text.
let upstream_timeout = PlatformImageError::Upstream {
provider: VECTOR_ENGINE_PROVIDER,
message: "upstream timeout".to_string(),
upstream_status: 502,
raw_excerpt: "deadline has elapsed".to_string(),
audit: Some(audit.clone()),
};
assert_eq!(invalid_config.status_hint(), PlatformImageStatusHint::ServiceUnavailable);
assert_eq!(invalid_request.status_hint(), PlatformImageStatusHint::BadRequest);
assert_eq!(request_error.status_hint(), PlatformImageStatusHint::GatewayTimeout);
assert_eq!(upstream_timeout.status_hint(), PlatformImageStatusHint::GatewayTimeout);
assert_eq!(
PlatformImageError::MissingImage {
provider: VECTOR_ENGINE_PROVIDER,
message: "缺图".to_string(),
audit: Some(audit.clone()),
}
.status_hint(),
PlatformImageStatusHint::BadGateway
);
let audit_ref = upstream_timeout.audit().expect("audit should be preserved");
assert_eq!(audit_ref.provider, VECTOR_ENGINE_PROVIDER);
assert_eq!(audit_ref.endpoint, "https://vector.example/v1/images/generations");
assert_eq!(audit_ref.status_code, Some(504));
assert_eq!(audit_ref.status_class, Some("5xx"));
assert!(audit_ref.timeout);
assert!(audit_ref.retryable);
assert_eq!(audit_ref.reference_image_count, Some(2));
assert_eq!(audit_ref.image_model, Some(VECTOR_ENGINE_GPT_IMAGE_2_MODEL));
assert!(invalid_config.audit().is_none());
assert!(invalid_request.audit().is_none());
}
// URL extraction keeps only distinct HTTP(S) URLs (the `ftp://` entry is dropped);
// `b64_json` array entries are returned as-is, in order.
#[test]
fn extract_image_urls_and_b64_values_are_deduped() {
let payload = json!({
"data": [
{"image": "https://example.com/a.png"},
{"url": "https://example.com/a.png"},
{"image_url": "ftp://example.com/b.png"},
{"url": "https://example.com/b.png"}
],
"nested": {
"b64_json": ["YWJj", "ZGVm"]
}
});
assert_eq!(
extract_image_urls(&payload),
vec![
"https://example.com/a.png".to_string(),
"https://example.com/b.png".to_string()
]
);
assert_eq!(
extract_b64_images(&payload),
vec!["YWJj".to_string(), "ZGVm".to_string()]
);
}
}