Files
Genarrative/server-rs/crates/api-server/src/hyper3d_generation.rs
2026-05-10 22:20:54 +08:00

1081 lines
36 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::time::Duration;
use axum::{
Json,
extract::{State, rejection::JsonRejection},
http::StatusCode,
response::Response,
};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use reqwest::{header, multipart};
use serde_json::{Value, json};
use shared_contracts::hyper3d as contract;
use crate::{
api_response::json_success_body, http_error::AppError, request_context::RequestContext,
state::AppState,
};
/// Provider identifier echoed in every response and error payload of this module.
const HYPER3D_PROVIDER: &str = "hyper3d-rodin";
/// Rodin tier sent with every submission and echoed back in submit responses.
const RODIN_GEN2_TIER: &str = "Gen-2";

// Defaults used when a client omits (or blanks) an optional generation option.
const DEFAULT_GEOMETRY_FILE_FORMAT: &str = "glb";
const DEFAULT_MATERIAL: &str = "PBR";
const DEFAULT_QUALITY: &str = "medium";
const DEFAULT_MESH_MODE: &str = "Quad";
const DEFAULT_CONDITION_MODE: &str = "concat";

// Request limits enforced before anything is forwarded upstream.
const MAX_PROMPT_CHARS: usize = 2000;
const MAX_NEGATIVE_PROMPT_CHARS: usize = 1000;
const MAX_IMAGE_COUNT: usize = 5;
/// Per-image cap on decoded reference-image bytes (10 MiB).
const MAX_IMAGE_BYTES: usize = 10 << 20;
/// Resolved connection settings for the Hyper3D Rodin API.
///
/// `Debug` is implemented by hand so the bearer API key is never written to logs or
/// error dumps that format this struct with `{:?}`.
#[derive(Clone)]
struct Hyper3dSettings {
    /// Base URL with surrounding whitespace and any trailing `/` removed.
    base_url: String,
    /// Secret bearer token sent in the `Authorization` header.
    api_key: String,
    /// Whole-request timeout in milliseconds.
    request_timeout_ms: u64,
}

impl std::fmt::Debug for Hyper3dSettings {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Hyper3dSettings")
            .field("base_url", &self.base_url)
            // Never print the secret itself.
            .field("api_key", &"<redacted>")
            .field("request_timeout_ms", &self.request_timeout_ms)
            .finish()
    }
}
/// A reference image decoded from a base64 `data:` URL, ready to be attached to the
/// Rodin multipart form as an `images` file part.
#[derive(Clone, Debug)]
struct DecodedImageDataUrl {
/// Decoded image bytes (non-empty and at most `MAX_IMAGE_BYTES`).
bytes: Vec<u8>,
/// MIME type of the image, used for the multipart part's content type.
mime_type: String,
/// Synthetic upload file name such as `reference-01.png`.
file_name: String,
}
/// Validated generation options shared by text-to-model and image-to-model submissions.
///
/// Built by `SubmitOptions::new`, which applies defaults and canonicalizes enum values, and
/// serialized into form fields by `append_common_submit_fields`.
#[derive(Clone, Debug)]
struct SubmitOptions {
/// Optional seed forwarded as the `seed` form field.
seed: Option<u32>,
/// Canonical output format: one of glb / usdz / fbx / obj / stl.
geometry_file_format: String,
/// Canonical material mode: PBR / Shaded / All.
material: String,
/// Canonical quality: high / medium / low / extra-low.
quality: String,
/// Canonical mesh mode: Quad / Raw.
mesh_mode: String,
/// Deduplicated addon names (currently only `HighPack` is accepted).
addons: Vec<String>,
/// Optional bounding box: exactly three finite, positive extents, sent as a JSON array.
bbox_condition: Option<Vec<f32>>,
/// Whether preview rendering is requested (defaults to `true`).
preview_render: bool,
}
/// Axum handler for Hyper3D text-to-model submission.
///
/// Rejects malformed JSON with 400, otherwise delegates to `submit_text_to_model` and
/// wraps the result in the standard success envelope, or renders the error with the
/// request context attached.
pub async fn submit_hyper3d_text_to_model(
    State(state): State<AppState>,
    axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
    payload: Result<Json<contract::Hyper3dTextToModelRequest>, JsonRejection>,
) -> Result<Json<Value>, Response> {
    let Json(request) = parse_json_payload(&request_context, payload)?;
    let outcome = submit_text_to_model(&state, request).await;
    match outcome {
        Ok(body) => Ok(json_success_body(Some(&request_context), body)),
        Err(error) => Err(error.into_response_with_context(Some(&request_context))),
    }
}
/// Axum handler for Hyper3D image-to-model submission.
///
/// Rejects malformed JSON with 400, otherwise delegates to `submit_image_to_model` and
/// wraps the result in the standard success envelope, or renders the error with the
/// request context attached.
pub async fn submit_hyper3d_image_to_model(
    State(state): State<AppState>,
    axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
    payload: Result<Json<contract::Hyper3dImageToModelRequest>, JsonRejection>,
) -> Result<Json<Value>, Response> {
    let Json(request) = parse_json_payload(&request_context, payload)?;
    let outcome = submit_image_to_model(&state, request).await;
    match outcome {
        Ok(body) => Ok(json_success_body(Some(&request_context), body)),
        Err(error) => Err(error.into_response_with_context(Some(&request_context))),
    }
}
/// Axum handler that polls the status of a submitted Hyper3D task.
///
/// Rejects malformed JSON with 400, otherwise delegates to `query_task_status` and wraps
/// the result in the standard success envelope, or renders the error with the request
/// context attached.
pub async fn get_hyper3d_task_status(
    State(state): State<AppState>,
    axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
    payload: Result<Json<contract::Hyper3dTaskStatusRequest>, JsonRejection>,
) -> Result<Json<Value>, Response> {
    let Json(request) = parse_json_payload(&request_context, payload)?;
    let outcome = query_task_status(&state, request).await;
    match outcome {
        Ok(body) => Ok(json_success_body(Some(&request_context), body)),
        Err(error) => Err(error.into_response_with_context(Some(&request_context))),
    }
}
/// Axum handler that lists the downloadable files of a finished Hyper3D task.
///
/// Rejects malformed JSON with 400, otherwise delegates to `query_downloads` and wraps the
/// result in the standard success envelope, or renders the error with the request context
/// attached.
pub async fn get_hyper3d_downloads(
    State(state): State<AppState>,
    axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
    payload: Result<Json<contract::Hyper3dDownloadRequest>, JsonRejection>,
) -> Result<Json<Value>, Response> {
    let Json(request) = parse_json_payload(&request_context, payload)?;
    let outcome = query_downloads(&state, request).await;
    match outcome {
        Ok(body) => Ok(json_success_body(Some(&request_context), body)),
        Err(error) => Err(error.into_response_with_context(Some(&request_context))),
    }
}
/// Validates a text-to-model request and submits it to Rodin's `/rodin` endpoint.
///
/// # Errors
/// - 503 when the Hyper3D base URL or API key is not configured.
/// - 400 when the prompt, negative prompt or generation options are invalid.
/// - 500 when the HTTP client cannot be constructed.
/// - 502 when the upstream call fails or the response lacks a task uuid / subscription key.
async fn submit_text_to_model(
    state: &AppState,
    payload: contract::Hyper3dTextToModelRequest,
) -> Result<contract::Hyper3dTaskSubmitResponse, AppError> {
    let settings = require_hyper3d_settings(state)?;

    // Validate every client-supplied field before building the HTTP client, so malformed
    // requests are rejected without paying for client construction.
    let prompt = normalize_required_text(&payload.prompt, "prompt", MAX_PROMPT_CHARS)?;
    let options = SubmitOptions::from_text_request(&payload)?;
    let negative_prompt = normalize_optional_limited_text(
        payload.negative_prompt.as_deref(),
        MAX_NEGATIVE_PROMPT_CHARS,
    )?;

    let http_client = build_hyper3d_http_client(&settings)?;
    let mut form = multipart::Form::new()
        .text("tier", RODIN_GEN2_TIER.to_string())
        .text("prompt", prompt);
    form = append_common_submit_fields(form, &options)?;
    if let Some(negative_prompt) = negative_prompt {
        form = form.text("negative_prompt", negative_prompt);
    }

    let response = post_hyper3d_multipart(
        &http_client,
        &settings,
        "/rodin",
        form,
        "提交 Hyper3D 文生模型任务失败",
    )
    .await?;
    build_submit_response(contract::Hyper3dGenerationMode::TextToModel, response)
}
/// Validates an image-to-model request and submits it to Rodin's `/rodin` endpoint.
///
/// Reference images may be given as remote URLs (`image_urls`, forwarded as text fields)
/// or as base64 data URLs (`image_data_urls`, decoded and attached as file parts).
/// Blank URL entries are ignored. At least one and at most `MAX_IMAGE_COUNT` images are
/// required.
///
/// # Errors
/// - 503 when Hyper3D is not configured.
/// - 400 for invalid options, prompt, image count or image data.
/// - 500 when the HTTP client cannot be constructed.
/// - 502 when the upstream call fails or returns no task identifiers.
pub(crate) async fn submit_image_to_model(
    state: &AppState,
    payload: contract::Hyper3dImageToModelRequest,
) -> Result<contract::Hyper3dTaskSubmitResponse, AppError> {
    let settings = require_hyper3d_settings(state)?;
    let options = SubmitOptions::from_image_request(&payload)?;
    let condition_mode = normalize_enum(
        payload.condition_mode.as_deref(),
        DEFAULT_CONDITION_MODE,
        &["concat", "fuse"],
        "conditionMode",
    )?;
    let prompt = normalize_optional_limited_text(payload.prompt.as_deref(), MAX_PROMPT_CHARS)?;

    // Drop blank URLs before counting. Otherwise `[" "]` would satisfy the "at least one
    // image" rule while sending no image at all, and blanks would eat into the limit.
    let image_urls: Vec<&str> = payload
        .image_urls
        .iter()
        .map(|value| value.trim())
        .filter(|value| !value.is_empty())
        .collect();
    let image_count = payload.image_data_urls.len() + image_urls.len();
    if image_count == 0 {
        return Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                "provider": HYPER3D_PROVIDER,
                "field": "imageDataUrls",
                "message": "图生模型至少需要一张参考图",
            })),
        );
    }
    if image_count > MAX_IMAGE_COUNT {
        return Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                "provider": HYPER3D_PROVIDER,
                "field": "imageDataUrls",
                "message": format!("图生模型最多支持 {} 张参考图", MAX_IMAGE_COUNT),
            })),
        );
    }
    // Decode only after the count check, so oversized batches are rejected before any
    // base64 work is done.
    let images = decode_image_data_urls(&payload.image_data_urls)?;

    let http_client = build_hyper3d_http_client(&settings)?;
    let mut form = multipart::Form::new().text("tier", RODIN_GEN2_TIER.to_string());
    form = append_common_submit_fields(form, &options)?;
    form = form.text("condition_mode", condition_mode);
    if let Some(prompt) = prompt {
        form = form.text("prompt", prompt);
    }
    for image_url in image_urls {
        form = form.text("image_urls", image_url.to_string());
    }
    for image in images {
        let part = multipart::Part::bytes(image.bytes)
            .file_name(image.file_name)
            .mime_str(&image.mime_type)
            .map_err(|error| {
                AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                    "provider": HYPER3D_PROVIDER,
                    "message": format!("构造图生模型图片字段失败:{error}"),
                }))
            })?;
        form = form.part("images", part);
    }

    let response = post_hyper3d_multipart(
        &http_client,
        &settings,
        "/rodin",
        form,
        "提交 Hyper3D 图生模型任务失败",
    )
    .await?;
    build_submit_response(contract::Hyper3dGenerationMode::ImageToModel, response)
}
pub(crate) async fn query_task_status(
state: &AppState,
payload: contract::Hyper3dTaskStatusRequest,
) -> Result<contract::Hyper3dTaskStatusResponse, AppError> {
let settings = require_hyper3d_settings(state)?;
let http_client = build_hyper3d_http_client(&settings)?;
// 中文注释Hyper3D 返回的 subscriptionKey 是上游 opaque token只做非空校验不做人为 256 字符截断。
let subscription_key =
normalize_required_opaque_text(&payload.subscription_key, "subscriptionKey")?;
let response = post_hyper3d_json(
&http_client,
&settings,
"/status",
json!({ "subscription_key": subscription_key }),
"查询 Hyper3D 模型任务状态失败",
)
.await?;
let jobs = extract_job_statuses(&response);
let status = normalize_task_status(
find_first_string_by_key(&response, "status")
.or_else(|| jobs.first().map(|job| job.status.clone()))
.as_deref()
.unwrap_or("unknown"),
);
Ok(contract::Hyper3dTaskStatusResponse {
ok: true,
provider: HYPER3D_PROVIDER.to_string(),
status,
jobs,
raw: response,
})
}
/// Asks Rodin's `/download` endpoint for the result files of a task.
///
/// The task uuid is trimmed and limited to 256 characters. Every `http(s)` URL found
/// anywhere in the response is returned (deduplicated), together with the raw payload.
pub(crate) async fn query_downloads(
    state: &AppState,
    payload: contract::Hyper3dDownloadRequest,
) -> Result<contract::Hyper3dDownloadResponse, AppError> {
    let settings = require_hyper3d_settings(state)?;
    let http_client = build_hyper3d_http_client(&settings)?;
    let task_uuid = normalize_required_text(&payload.task_uuid, "taskUuid", 256)?;
    let request_body = json!({ "task_uuid": task_uuid });
    let upstream = post_hyper3d_json(
        &http_client,
        &settings,
        "/download",
        request_body,
        "获取 Hyper3D 模型下载列表失败",
    )
    .await?;
    let files = extract_download_files(&upstream);
    Ok(contract::Hyper3dDownloadResponse {
        ok: true,
        provider: HYPER3D_PROVIDER.to_string(),
        files,
        raw: upstream,
    })
}
impl SubmitOptions {
    /// Validates the generation options carried by a text-to-model request.
    fn from_text_request(payload: &contract::Hyper3dTextToModelRequest) -> Result<Self, AppError> {
        let addons = payload.addons.clone();
        let bbox_condition = payload.bbox_condition.clone();
        Self::new(
            payload.seed,
            payload.geometry_file_format.as_deref(),
            payload.material.as_deref(),
            payload.quality.as_deref(),
            payload.mesh_mode.as_deref(),
            addons,
            bbox_condition,
            payload.preview_render,
        )
    }

    /// Validates the generation options carried by an image-to-model request.
    fn from_image_request(
        payload: &contract::Hyper3dImageToModelRequest,
    ) -> Result<Self, AppError> {
        let addons = payload.addons.clone();
        let bbox_condition = payload.bbox_condition.clone();
        Self::new(
            payload.seed,
            payload.geometry_file_format.as_deref(),
            payload.material.as_deref(),
            payload.quality.as_deref(),
            payload.mesh_mode.as_deref(),
            addons,
            bbox_condition,
            payload.preview_render,
        )
    }

    /// Applies defaults and canonicalizes every option; the first invalid option wins.
    ///
    /// The eight inputs map one-to-one onto the struct's fields. A parameter struct would
    /// just duplicate `SubmitOptions`, hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    fn new(
        seed: Option<u32>,
        geometry_file_format: Option<&str>,
        material: Option<&str>,
        quality: Option<&str>,
        mesh_mode: Option<&str>,
        addons: Vec<String>,
        bbox_condition: Option<Vec<f32>>,
        preview_render: Option<bool>,
    ) -> Result<Self, AppError> {
        let geometry_file_format = normalize_enum(
            geometry_file_format,
            DEFAULT_GEOMETRY_FILE_FORMAT,
            &["glb", "usdz", "fbx", "obj", "stl"],
            "geometryFileFormat",
        )?;
        let material = normalize_enum(
            material,
            DEFAULT_MATERIAL,
            &["PBR", "Shaded", "All"],
            "material",
        )?;
        let quality = normalize_enum(
            quality,
            DEFAULT_QUALITY,
            &["high", "medium", "low", "extra-low"],
            "quality",
        )?;
        let mesh_mode =
            normalize_enum(mesh_mode, DEFAULT_MESH_MODE, &["Quad", "Raw"], "meshMode")?;
        let addons = normalize_addons(addons)?;
        let bbox_condition = normalize_bbox_condition(bbox_condition)?;
        Ok(Self {
            seed,
            geometry_file_format,
            material,
            quality,
            mesh_mode,
            addons,
            bbox_condition,
            // Preview rendering is opt-out.
            preview_render: preview_render.unwrap_or(true),
        })
    }
}
fn append_common_submit_fields(
mut form: multipart::Form,
options: &SubmitOptions,
) -> Result<multipart::Form, AppError> {
form = form
.text(
"geometry_file_format",
options.geometry_file_format.to_string(),
)
.text("material", options.material.to_string())
.text("quality", options.quality.to_string())
.text("mesh_mode", options.mesh_mode.to_string())
.text("preview_render", options.preview_render.to_string());
if let Some(seed) = options.seed {
form = form.text("seed", seed.to_string());
}
for addon in &options.addons {
form = form.text("addons", addon.to_string());
}
if let Some(bbox_condition) = &options.bbox_condition {
form = form.text(
"bbox_condition",
serde_json::to_string(bbox_condition).map_err(|error| {
AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
"provider": HYPER3D_PROVIDER,
"field": "bboxCondition",
"message": format!("bboxCondition 序列化失败:{error}"),
}))
})?,
);
}
Ok(form)
}
/// Resolves Hyper3D connection settings from the application config.
///
/// The base URL is trimmed and stripped of trailing slashes; the API key is trimmed. The
/// timeout is clamped to at least 1 ms.
///
/// # Errors
/// 503 when the base URL is blank or the API key is missing/blank.
fn require_hyper3d_settings(state: &AppState) -> Result<Hyper3dSettings, AppError> {
    let config = &state.config;

    let base_url = config.hyper3d_base_url.trim().trim_end_matches('/');
    if base_url.is_empty() {
        return Err(
            AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
                "provider": HYPER3D_PROVIDER,
                "reason": "HYPER3D_BASE_URL 未配置",
                "message": "Hyper3D Rodin 服务地址未配置,请设置 HYPER3D_BASE_URL 或 RODIN_BASE_URL 后重启 api-server。",
            })),
        );
    }

    let Some(api_key) = config
        .hyper3d_api_key
        .as_deref()
        .map(str::trim)
        .filter(|key| !key.is_empty())
    else {
        return Err(
            AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
                "provider": HYPER3D_PROVIDER,
                "reason": "HYPER3D_API_KEY 未配置",
                "message": "Hyper3D Rodin API Key 未配置,请在本地私密环境设置 HYPER3D_API_KEY 或 RODIN_API_KEY 后重启 api-server。",
            })),
        );
    };

    Ok(Hyper3dSettings {
        base_url: base_url.to_string(),
        api_key: api_key.to_string(),
        request_timeout_ms: config.hyper3d_model_request_timeout_ms.max(1),
    })
}
/// Builds a reqwest client whose whole-request timeout comes from `settings`.
///
/// NOTE(review): a new client, and therefore a new connection pool, is built for every
/// API call. Holding one client in shared state would allow connection reuse.
///
/// # Errors
/// 500 when the client cannot be constructed.
fn build_hyper3d_http_client(settings: &Hyper3dSettings) -> Result<reqwest::Client, AppError> {
    let timeout = Duration::from_millis(settings.request_timeout_ms);
    match reqwest::Client::builder().timeout(timeout).build() {
        Ok(client) => Ok(client),
        Err(error) => Err(
            AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({
                "provider": HYPER3D_PROVIDER,
                "message": format!("构造 Hyper3D HTTP 客户端失败:{error}"),
            })),
        ),
    }
}
async fn post_hyper3d_multipart(
http_client: &reqwest::Client,
settings: &Hyper3dSettings,
path: &str,
form: multipart::Form,
failure_context: &str,
) -> Result<Value, AppError> {
let response = http_client
.post(format!("{}{}", settings.base_url, path))
.header(
header::AUTHORIZATION,
format!("Bearer {}", settings.api_key),
)
.header(header::ACCEPT, "application/json")
.multipart(form)
.send()
.await
.map_err(|error| hyper3d_bad_gateway(format!("{failure_context}{error}")))?;
parse_hyper3d_response(response, failure_context).await
}
async fn post_hyper3d_json(
http_client: &reqwest::Client,
settings: &Hyper3dSettings,
path: &str,
body: Value,
failure_context: &str,
) -> Result<Value, AppError> {
let response = http_client
.post(format!("{}{}", settings.base_url, path))
.header(
header::AUTHORIZATION,
format!("Bearer {}", settings.api_key),
)
.header(header::ACCEPT, "application/json")
.header(header::CONTENT_TYPE, "application/json")
.json(&body)
.send()
.await
.map_err(|error| hyper3d_bad_gateway(format!("{failure_context}{error}")))?;
parse_hyper3d_response(response, failure_context).await
}
async fn parse_hyper3d_response(
response: reqwest::Response,
failure_context: &str,
) -> Result<Value, AppError> {
let status = response.status();
let raw_text = response.text().await.map_err(|error| {
hyper3d_bad_gateway(format!("{failure_context}:读取上游响应失败:{error}"))
})?;
if !status.is_success() {
return Err(
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": HYPER3D_PROVIDER,
"message": parse_api_error_message(&raw_text, failure_context),
"status": status.as_u16(),
"rawExcerpt": truncate_raw(&raw_text),
})),
);
}
serde_json::from_str::<Value>(&raw_text).map_err(|error| {
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": HYPER3D_PROVIDER,
"message": format!("{failure_context}:解析上游 JSON 失败:{error}"),
"rawExcerpt": truncate_raw(&raw_text),
}))
})
}
/// Converts Rodin's submit reply into the API's submit response.
///
/// The task uuid and subscription key are looked up on the root object first, then
/// anywhere in the payload. Job uuids and an optional message are collected on a
/// best-effort basis.
///
/// # Errors
/// 502 when the task uuid or subscription key cannot be found.
fn build_submit_response(
    mode: contract::Hyper3dGenerationMode,
    response: Value,
) -> Result<contract::Hyper3dTaskSubmitResponse, AppError> {
    let task_uuid = match find_root_string_by_keys(&response, &["uuid", "task_uuid", "taskUuid"]) {
        Some(uuid) => uuid,
        None => find_first_string_by_keys(&response, &["task_uuid", "taskUuid"])
            .ok_or_else(|| hyper3d_bad_gateway("Hyper3D 已响应,但未返回任务 uuid"))?,
    };

    let subscription_key_names = ["subscription_key", "subscriptionKey"];
    let subscription_key = find_root_string_by_keys(&response, &subscription_key_names)
        .or_else(|| find_first_string_by_keys(&response, &subscription_key_names))
        .ok_or_else(|| hyper3d_bad_gateway("Hyper3D 已响应,但未返回 subscription_key"))?;

    Ok(contract::Hyper3dTaskSubmitResponse {
        ok: true,
        provider: HYPER3D_PROVIDER.to_string(),
        mode,
        job_uuids: extract_job_uuids(&response),
        message: find_first_string_by_keys(&response, &["message", "detail"]),
        task_uuid,
        subscription_key,
        tier: RODIN_GEN2_TIER.to_string(),
    })
}
fn extract_job_statuses(payload: &Value) -> Vec<contract::Hyper3dJobStatusPayload> {
let Some(array) = find_first_array_by_keys(payload, &["jobs", "tasks"]) else {
return Vec::new();
};
array
.iter()
.filter_map(|value| {
let status = find_first_string_by_keys(value, &["status", "state"])
.map(|value| normalize_task_status(&value))?;
Some(contract::Hyper3dJobStatusPayload {
uuid: find_first_string_by_keys(value, &["uuid", "task_uuid", "taskUuid"]),
progress: find_first_f64_by_keys(value, &["progress", "percentage"])
.map(|value| value as f32),
message: find_first_string_by_keys(value, &["message", "detail", "error"]),
status,
})
})
.collect()
}
/// Collects job uuids, in first-seen order and without duplicates.
///
/// Uuids come first from the entries of the first `jobs` array, then from any
/// `job_uuids`/`jobUuids`/`uuids` string or string-array fields anywhere in the payload.
fn extract_job_uuids(payload: &Value) -> Vec<String> {
    let from_job_entries = find_first_array_by_keys(payload, &["jobs"])
        .into_iter()
        .flatten()
        .filter_map(|job| find_first_string_by_keys(job, &["uuid", "task_uuid", "taskUuid"]));
    let from_uuid_lists = collect_strings_by_keys(payload, &["job_uuids", "jobUuids", "uuids"]);

    let mut unique: Vec<String> = Vec::new();
    for candidate in from_job_entries.chain(from_uuid_lists) {
        if !unique.contains(&candidate) {
            unique.push(candidate);
        }
    }
    unique
}
/// Returns every downloadable file found anywhere in the payload, keeping the first
/// entry for each distinct URL.
fn extract_download_files(payload: &Value) -> Vec<contract::Hyper3dDownloadFilePayload> {
    let mut collected = Vec::new();
    collect_download_files(payload, &mut collected);
    collected.into_iter().fold(
        Vec::new(),
        |mut unique: Vec<contract::Hyper3dDownloadFilePayload>, candidate| {
            if unique.iter().all(|seen| seen.url != candidate.url) {
                unique.push(candidate);
            }
            unique
        },
    )
}
/// Depth-first walk that records every object carrying an `http(s)` download URL.
///
/// The URL comes from the first present key among `url`, `download_url` and
/// `downloadUrl`. The name comes from the first present key among `name`, `file_name`
/// and `filename`, and falls back to `"model"`. Nested values are always visited, after
/// the containing object itself.
fn collect_download_files(value: &Value, output: &mut Vec<contract::Hyper3dDownloadFilePayload>) {
    let object = match value {
        Value::Array(items) => {
            for item in items {
                collect_download_files(item, output);
            }
            return;
        }
        Value::Object(object) => object,
        _ => return,
    };

    let url = ["url", "download_url", "downloadUrl"]
        .iter()
        .find_map(|key| object.get(*key))
        .and_then(Value::as_str)
        .map(str::trim)
        .filter(|url| url.starts_with("http://") || url.starts_with("https://"));
    if let Some(url) = url {
        let name = ["name", "file_name", "filename"]
            .iter()
            .find_map(|key| object.get(*key))
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|name| !name.is_empty())
            .unwrap_or("model");
        output.push(contract::Hyper3dDownloadFilePayload {
            name: name.to_string(),
            url: url.to_string(),
        });
    }

    for nested in object.values() {
        collect_download_files(nested, output);
    }
}
/// Decodes every data URL in order, stopping at the first invalid one.
///
/// Images are numbered from 1; the number is used for the generated file name.
fn decode_image_data_urls(values: &[String]) -> Result<Vec<DecodedImageDataUrl>, AppError> {
    let mut decoded = Vec::with_capacity(values.len());
    for (position, data_url) in values.iter().enumerate() {
        decoded.push(decode_image_data_url(data_url, position + 1)?);
    }
    Ok(decoded)
}
/// Decodes one `data:image/<type>;base64,<payload>` URL into raw image bytes.
///
/// Accepts PNG, JPEG and WebP. The non-standard `image/jpg` is normalized to the
/// registered `image/jpeg` before it is forwarded as a multipart content type. `index`
/// (1-based) only names the generated file, e.g. `reference-01.png`.
///
/// # Errors
/// 400 (`field: "imageDataUrls"`) when:
/// - the value is not a base64 image data URL, or has an unsupported type;
/// - the payload fails to decode;
/// - the decoded image is empty or larger than `MAX_IMAGE_BYTES`.
fn decode_image_data_url(value: &str, index: usize) -> Result<DecodedImageDataUrl, AppError> {
    // Longest padded base64 text that can still decode to at most MAX_IMAGE_BYTES. Anything
    // longer is rejected before decoding, so oversized uploads are never materialized.
    const MAX_ENCODED_LEN: usize = (MAX_IMAGE_BYTES + 2) / 3 * 4;

    let Some((metadata, encoded)) = value.trim().split_once(',') else {
        return Err(invalid_image_data_url("参考图必须是 data URL"));
    };
    let Some(declared_mime) = metadata
        .strip_prefix("data:")
        .and_then(|rest| rest.strip_suffix(";base64"))
        .filter(|mime| mime.starts_with("image/"))
    else {
        return Err(invalid_image_data_url(
            "参考图只支持 image/png、image/jpeg 或 image/webp 的 base64 data URL",
        ));
    };
    let (mime_type, extension) = match declared_mime {
        "image/png" => ("image/png", "png"),
        "image/jpeg" | "image/jpg" => ("image/jpeg", "jpg"),
        "image/webp" => ("image/webp", "webp"),
        _ => {
            return Err(invalid_image_data_url(
                "参考图只支持 image/png、image/jpeg 或 image/webp",
            ));
        }
    };
    if encoded.len() > MAX_ENCODED_LEN {
        return Err(invalid_image_data_url("参考图为空或超过 10MB"));
    }
    let bytes = BASE64_STANDARD
        .decode(encoded)
        .map_err(|_| invalid_image_data_url("参考图 base64 解码失败"))?;
    if bytes.is_empty() || bytes.len() > MAX_IMAGE_BYTES {
        return Err(invalid_image_data_url("参考图为空或超过 10MB"));
    }
    Ok(DecodedImageDataUrl {
        bytes,
        mime_type: mime_type.to_string(),
        file_name: format!("reference-{index:02}.{extension}"),
    })
}
/// Builds the 400 error used for any malformed reference-image data URL.
fn invalid_image_data_url(message: &str) -> AppError {
    let details = json!({
        "provider": HYPER3D_PROVIDER,
        "field": "imageDataUrls",
        "message": message,
    });
    AppError::from_status(StatusCode::BAD_REQUEST).with_details(details)
}
/// Trims a required text field and enforces a maximum length in Unicode scalar values.
///
/// # Errors
/// 400 naming `field` when the trimmed text is empty or longer than `max_chars`.
fn normalize_required_text(
    value: &str,
    field: &'static str,
    max_chars: usize,
) -> Result<String, AppError> {
    let trimmed = value.trim();
    let problem = if trimmed.is_empty() {
        format!("{field} 不能为空")
    } else if trimmed.chars().count() > max_chars {
        format!("{field} 超过 {} 字符", max_chars)
    } else {
        return Ok(trimmed.to_string());
    };
    Err(
        AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
            "provider": HYPER3D_PROVIDER,
            "field": field,
            "message": problem,
        })),
    )
}
/// Trims a required opaque token. Only emptiness is checked; no length cap is applied.
///
/// # Errors
/// 400 naming `field` when the trimmed value is empty.
fn normalize_required_opaque_text(value: &str, field: &'static str) -> Result<String, AppError> {
    match value.trim() {
        "" => Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                "provider": HYPER3D_PROVIDER,
                "field": field,
                "message": format!("{field} 不能为空"),
            })),
        ),
        token => Ok(token.to_owned()),
    }
}
/// Trims an optional text field. Missing or blank input maps to `None`.
///
/// NOTE(review): the error does not say which field was too long, unlike the other
/// validators in this module.
///
/// # Errors
/// 400 when the trimmed text is longer than `max_chars` Unicode scalar values.
fn normalize_optional_limited_text(
    value: Option<&str>,
    max_chars: usize,
) -> Result<Option<String>, AppError> {
    match value.map(str::trim) {
        None | Some("") => Ok(None),
        Some(text) if text.chars().count() > max_chars => Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                "provider": HYPER3D_PROVIDER,
                "message": format!("文本超过 {} 字符", max_chars),
            })),
        ),
        Some(text) => Ok(Some(text.to_owned())),
    }
}
/// Resolves an optional enum-like option to its canonical spelling.
///
/// Missing or blank input uses `default_value`. Matching against `allowed_values` is
/// ASCII case-insensitive, and the allowed value's own spelling is returned.
///
/// # Errors
/// 400 naming `field` and listing `allowed_values` when nothing matches.
fn normalize_enum(
    value: Option<&str>,
    default_value: &str,
    allowed_values: &[&str],
    field: &'static str,
) -> Result<String, AppError> {
    let requested = match value.map(str::trim) {
        Some(text) if !text.is_empty() => text,
        _ => default_value,
    };
    allowed_values
        .iter()
        .find(|candidate| candidate.eq_ignore_ascii_case(requested))
        .map(|canonical| (*canonical).to_string())
        .ok_or_else(|| {
            AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                "provider": HYPER3D_PROVIDER,
                "field": field,
                "message": format!("{} 取值非法", field),
                "allowed": allowed_values,
            }))
        })
}
/// Validates the requested Rodin addons, dropping blank entries and duplicates.
///
/// Only `HighPack` is supported for now. Matching is ASCII case-insensitive, consistent
/// with how `normalize_enum` treats the other options. Accepted values are returned in the
/// canonical `HighPack` spelling, so the result holds at most one entry.
///
/// # Errors
/// 400 (`field: "addons"`) for any other non-blank value.
fn normalize_addons(values: Vec<String>) -> Result<Vec<String>, AppError> {
    const HIGH_PACK: &str = "HighPack";
    let mut addons = Vec::new();
    for value in values
        .iter()
        .map(|value| value.trim())
        .filter(|value| !value.is_empty())
    {
        if !value.eq_ignore_ascii_case(HIGH_PACK) {
            return Err(
                AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                    "provider": HYPER3D_PROVIDER,
                    "field": "addons",
                    "message": "addons 首版只支持 HighPack",
                })),
            );
        }
        // Only one addon is valid, so "not yet present" is simply "list still empty".
        if addons.is_empty() {
            addons.push(HIGH_PACK.to_string());
        }
    }
    Ok(addons)
}
/// Validates an optional bounding-box condition: exactly three finite, positive values.
///
/// # Errors
/// 400 (`field: "bboxCondition"`) when the box has the wrong length or a non-positive or
/// non-finite entry.
fn normalize_bbox_condition(value: Option<Vec<f32>>) -> Result<Option<Vec<f32>>, AppError> {
    match value {
        None => Ok(None),
        Some(extents)
            if extents.len() == 3
                && extents.iter().all(|extent| extent.is_finite() && *extent > 0.0) =>
        {
            Ok(Some(extents))
        }
        Some(_) => Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                "provider": HYPER3D_PROVIDER,
                "field": "bboxCondition",
                "message": "bboxCondition 必须包含 3 个正数",
            })),
        ),
    }
}
/// Maps an upstream status word, case-insensitively and ignoring surrounding whitespace,
/// onto one of `waiting`, `generating`, `done`, `failed` or `unknown`.
fn normalize_task_status(status: &str) -> String {
    let aliases_by_status: [(&str, &[&str]); 4] = [
        ("waiting", &["waiting", "pending", "queued"]),
        ("generating", &["generating", "running", "processing"]),
        ("done", &["done", "finished", "completed", "success", "succeeded"]),
        ("failed", &["failed", "error", "canceled", "cancelled"]),
    ];
    let lowered = status.trim().to_ascii_lowercase();
    aliases_by_status
        .iter()
        .find(|(_, aliases)| aliases.contains(&lowered.as_str()))
        .map_or("unknown", |&(canonical, _)| canonical)
        .to_string()
}
fn parse_api_error_message(raw_text: &str, fallback_message: &str) -> String {
if let Ok(parsed) = serde_json::from_str::<Value>(raw_text) {
for key in ["message", "detail", "error"] {
if let Some(message) = find_first_string_by_key(&parsed, key)
&& !message.trim().is_empty()
{
return message;
}
}
}
raw_text
.trim()
.chars()
.take(240)
.collect::<String>()
.trim()
.to_string()
.chars()
.next()
.map(|_| raw_text.trim().chars().take(240).collect())
.unwrap_or_else(|| fallback_message.to_string())
}
/// Depth-first search for the first array stored under any of `keys`.
///
/// Key matching is ASCII case-insensitive. Object entries are visited in map iteration
/// order, and each entry's subtree is searched before moving on to the next entry.
fn find_first_array_by_keys<'a>(value: &'a Value, keys: &[&str]) -> Option<&'a Vec<Value>> {
    match value {
        Value::Object(object) => object.iter().find_map(|(key, child)| {
            let key_matches = keys.iter().any(|target| key.eq_ignore_ascii_case(target));
            match child.as_array() {
                Some(array) if key_matches => Some(array),
                _ => find_first_array_by_keys(child, keys),
            }
        }),
        Value::Array(items) => items
            .iter()
            .find_map(|item| find_first_array_by_keys(item, keys)),
        _ => None,
    }
}
/// Tries each key in priority order and returns the first string found for it anywhere
/// in `value` (see `find_first_string_by_key`).
fn find_first_string_by_keys(value: &Value, keys: &[&str]) -> Option<String> {
    for key in keys {
        if let Some(found) = find_first_string_by_key(value, key) {
            return Some(found);
        }
    }
    None
}
fn find_root_string_by_keys(value: &Value, keys: &[&str]) -> Option<String> {
let object = value.as_object()?;
for key in keys {
if let Some(text) = object
.iter()
.find(|(candidate, _)| candidate.eq_ignore_ascii_case(key))
.and_then(|(_, value)| value.as_str())
.map(str::trim)
.filter(|value| !value.is_empty())
{
return Some(text.to_string());
}
}
None
}
/// Depth-first search for the first string stored under `target_key` (ASCII
/// case-insensitive), returned trimmed.
///
/// The result may be an empty string. Each object entry's subtree is searched before the
/// next entry.
fn find_first_string_by_key(value: &Value, target_key: &str) -> Option<String> {
    match value {
        Value::Object(object) => object.iter().find_map(|(key, child)| match child.as_str() {
            Some(text) if key.eq_ignore_ascii_case(target_key) => Some(text.trim().to_string()),
            _ => find_first_string_by_key(child, target_key),
        }),
        Value::Array(items) => items
            .iter()
            .find_map(|item| find_first_string_by_key(item, target_key)),
        _ => None,
    }
}
/// Depth-first search for the first number stored under any of `keys` (ASCII
/// case-insensitive), converted to `f64`.
fn find_first_f64_by_keys(value: &Value, keys: &[&str]) -> Option<f64> {
    match value {
        Value::Object(object) => object.iter().find_map(|(key, child)| {
            let key_matches = keys.iter().any(|target| key.eq_ignore_ascii_case(target));
            match child.as_f64() {
                Some(number) if key_matches => Some(number),
                _ => find_first_f64_by_keys(child, keys),
            }
        }),
        Value::Array(items) => items
            .iter()
            .find_map(|item| find_first_f64_by_keys(item, keys)),
        _ => None,
    }
}
/// Collects every non-blank string under any of `keys` (see `collect_strings`), keeping
/// first-seen order and dropping duplicates.
fn collect_strings_by_keys(value: &Value, keys: &[&str]) -> Vec<String> {
    let mut found = Vec::new();
    collect_strings(value, keys, &mut found);
    found.into_iter().fold(Vec::new(), |mut unique: Vec<String>, candidate| {
        if !unique.contains(&candidate) {
            unique.push(candidate);
        }
        unique
    })
}
/// Recursively appends trimmed, non-blank strings stored under any of `keys` to `output`.
///
/// Key matching is ASCII case-insensitive. A matching key contributes either its string
/// value or the string items of its array value. Every value is still descended into,
/// whether or not its key matched.
fn collect_strings(value: &Value, keys: &[&str], output: &mut Vec<String>) {
    let object = match value {
        Value::Array(items) => {
            for item in items {
                collect_strings(item, keys, output);
            }
            return;
        }
        Value::Object(object) => object,
        _ => return,
    };
    for (key, child) in object {
        if keys.iter().any(|target| key.eq_ignore_ascii_case(target)) {
            let mut push_trimmed = |text: &str| {
                let text = text.trim();
                if !text.is_empty() {
                    output.push(text.to_string());
                }
            };
            match child {
                Value::String(text) => push_trimmed(text.as_str()),
                Value::Array(items) => items
                    .iter()
                    .filter_map(Value::as_str)
                    .for_each(push_trimmed),
                _ => {}
            }
        }
        collect_strings(child, keys, output);
    }
}
/// Returns at most the first 800 characters (Unicode scalar values) of `raw_text`, for
/// embedding upstream bodies in error details.
fn truncate_raw(raw_text: &str) -> String {
    const MAX_EXCERPT_CHARS: usize = 800;
    match raw_text.char_indices().nth(MAX_EXCERPT_CHARS) {
        Some((cut, _)) => raw_text[..cut].to_string(),
        None => raw_text.to_string(),
    }
}
/// Builds a 502 error attributed to the Hyper3D provider with the given message.
fn hyper3d_bad_gateway(message: impl Into<String>) -> AppError {
    let message: String = message.into();
    let details = json!({
        "provider": HYPER3D_PROVIDER,
        "message": message,
    });
    AppError::from_status(StatusCode::BAD_GATEWAY).with_details(details)
}
/// Unwraps an axum JSON extraction result, rendering a rejection as a 400 response that
/// carries the request context.
fn parse_json_payload<T>(
    request_context: &RequestContext,
    payload: Result<Json<T>, JsonRejection>,
) -> Result<Json<T>, Response> {
    match payload {
        Ok(body) => Ok(body),
        Err(rejection) => Err(AppError::from_status(StatusCode::BAD_REQUEST)
            .with_message(format!("请求体 JSON 不合法:{rejection}"))
            .into_response_with_context(Some(request_context))),
    }
}
// Unit tests for the pure validation, decoding and payload-extraction helpers. No
// network access: upstream payloads are inlined as JSON literals.
#[cfg(test)]
mod tests {
use super::*;
// Omitted options fall back to the module defaults, a valid addon is kept, and preview
// rendering defaults to enabled.
#[test]
fn validates_and_defaults_submit_options() {
let payload = contract::Hyper3dTextToModelRequest {
prompt: "宝箱".to_string(),
negative_prompt: None,
seed: Some(7),
geometry_file_format: None,
material: None,
quality: None,
mesh_mode: None,
addons: vec!["HighPack".to_string()],
bbox_condition: Some(vec![1.0, 2.0, 3.0]),
preview_render: None,
};
let options = SubmitOptions::from_text_request(&payload).expect("options should build");
assert_eq!(options.geometry_file_format, "glb");
assert_eq!(options.material, "PBR");
assert_eq!(options.quality, "medium");
assert_eq!(options.mesh_mode, "Quad");
assert_eq!(options.addons, vec!["HighPack"]);
assert!(options.preview_render);
}
// A zero extent makes the bounding box invalid and maps to 400.
#[test]
fn rejects_invalid_bbox_condition() {
let error = normalize_bbox_condition(Some(vec![1.0, 0.0, 3.0]))
.expect_err("invalid bbox should fail");
assert_eq!(error.status_code(), StatusCode::BAD_REQUEST);
}
// Subscription keys are opaque: surrounding whitespace is trimmed but length is not capped.
#[test]
fn accepts_opaque_subscription_key_without_length_cap() {
let long_key = "a".repeat(300);
let normalized =
normalize_required_opaque_text(&format!(" {long_key} "), "subscriptionKey")
.expect("subscription key should be accepted");
assert_eq!(normalized, long_key);
}
// A PNG data URL decodes to non-empty bytes with the expected MIME type and file name.
#[test]
fn decodes_png_data_url() {
let data_url = format!(
"data:image/png;base64,{}",
BASE64_STANDARD.encode(b"\x89PNG\r\n\x1A\nrest")
);
let image = decode_image_data_url(&data_url, 1).expect("image should decode");
assert_eq!(image.mime_type, "image/png");
assert_eq!(image.file_name, "reference-01.png");
assert!(!image.bytes.is_empty());
}
// Task uuid, subscription key and job uuids are read from a submit reply.
#[test]
fn extracts_submit_response_from_nested_payload() {
let response = build_submit_response(
contract::Hyper3dGenerationMode::TextToModel,
json!({
"uuid": "task-1",
"subscription_key": "sub-1",
"jobs": [{ "uuid": "job-1" }],
"message": "submitted"
}),
)
.expect("submit response should build");
assert_eq!(response.task_uuid, "task-1");
assert_eq!(response.subscription_key, "sub-1");
assert_eq!(response.job_uuids, vec!["job-1"]);
}
// Every http(s) file entry in a download list is returned, in order.
#[test]
fn extracts_download_files_from_list() {
let files = extract_download_files(&json!({
"list": [
{ "name": "model.glb", "url": "https://cdn.example/model.glb" },
{ "name": "preview.png", "url": "https://cdn.example/preview.png" }
]
}));
assert_eq!(files.len(), 2);
assert_eq!(files[0].name, "model.glb");
}
// Upstream status words are normalized case-insensitively.
#[test]
fn normalizes_status_values() {
assert_eq!(normalize_task_status("Waiting"), "waiting");
assert_eq!(normalize_task_status("Generating"), "generating");
assert_eq!(normalize_task_status("Done"), "done");
assert_eq!(normalize_task_status("Failed"), "failed");
}
}