180 lines
6.0 KiB
Rust
180 lines
6.0 KiB
Rust
use reqwest::multipart;
|
|
use serde_json::json;
|
|
use shared_contracts::hyper3d as contract;
|
|
|
|
use crate::{
|
|
error::Hyper3dError,
|
|
request::{
|
|
build_common_submit_fields, build_submit_options_from_image,
|
|
build_submit_options_from_text, decode_image_data_urls, normalize_condition_mode,
|
|
normalize_optional_limited_text, normalize_required_opaque_text, normalize_required_text,
|
|
},
|
|
response::{
|
|
build_submit_response, extract_download_files, extract_job_statuses,
|
|
resolve_hyper3d_overall_status,
|
|
},
|
|
transport::{post_hyper3d_json, post_hyper3d_multipart},
|
|
types::{
|
|
HYPER3D_PROVIDER, Hyper3dSettings, MAX_IMAGE_COUNT, MAX_NEGATIVE_PROMPT_CHARS,
|
|
MAX_PROMPT_CHARS, RODIN_GEN2_TIER,
|
|
},
|
|
};
|
|
|
|
pub fn build_hyper3d_http_client(
|
|
settings: &Hyper3dSettings,
|
|
) -> Result<reqwest::Client, Hyper3dError> {
|
|
reqwest::Client::builder()
|
|
.timeout(std::time::Duration::from_millis(
|
|
settings.request_timeout_ms.max(1),
|
|
))
|
|
.build()
|
|
.map_err(|error| {
|
|
Hyper3dError::invalid_config(
|
|
"build_hyper3d_http_client",
|
|
format!("构造 Hyper3D HTTP 客户端失败:{error}"),
|
|
)
|
|
})
|
|
}
|
|
|
|
pub async fn submit_text_to_model(
|
|
state: &Hyper3dSettings,
|
|
payload: contract::Hyper3dTextToModelRequest,
|
|
) -> Result<contract::Hyper3dTaskSubmitResponse, Hyper3dError> {
|
|
let http_client = build_hyper3d_http_client(state)?;
|
|
let prompt = normalize_required_text(&payload.prompt, "prompt", MAX_PROMPT_CHARS)?;
|
|
let options = build_submit_options_from_text(&payload)?;
|
|
let mut form = multipart::Form::new()
|
|
.text("tier", RODIN_GEN2_TIER.to_string())
|
|
.text("prompt", prompt);
|
|
form = build_common_submit_fields(form, &options)?;
|
|
if let Some(negative_prompt) = normalize_optional_limited_text(
|
|
payload.negative_prompt.as_deref(),
|
|
MAX_NEGATIVE_PROMPT_CHARS,
|
|
)? {
|
|
form = form.text("negative_prompt", negative_prompt);
|
|
}
|
|
|
|
let response = post_hyper3d_multipart(
|
|
&http_client,
|
|
state,
|
|
"/rodin",
|
|
form,
|
|
"提交 Hyper3D 文生模型任务失败",
|
|
)
|
|
.await?;
|
|
|
|
build_submit_response(contract::Hyper3dGenerationMode::TextToModel, response)
|
|
}
|
|
|
|
pub async fn submit_image_to_model(
|
|
state: &Hyper3dSettings,
|
|
payload: contract::Hyper3dImageToModelRequest,
|
|
) -> Result<contract::Hyper3dTaskSubmitResponse, Hyper3dError> {
|
|
let http_client = build_hyper3d_http_client(state)?;
|
|
let options = build_submit_options_from_image(&payload)?;
|
|
let mut form = multipart::Form::new().text("tier", RODIN_GEN2_TIER.to_string());
|
|
form = build_common_submit_fields(form, &options)?;
|
|
let condition_mode = normalize_condition_mode(payload.condition_mode.as_deref())?;
|
|
form = form.text("condition_mode", condition_mode);
|
|
if let Some(prompt) =
|
|
normalize_optional_limited_text(payload.prompt.as_deref(), MAX_PROMPT_CHARS)?
|
|
{
|
|
form = form.text("prompt", prompt);
|
|
}
|
|
for image_url in payload
|
|
.image_urls
|
|
.iter()
|
|
.map(|value| value.trim())
|
|
.filter(|value| !value.is_empty())
|
|
{
|
|
form = form.text("image_urls", image_url.to_string());
|
|
}
|
|
for image in decode_image_data_urls(&payload.image_data_urls)? {
|
|
let part = multipart::Part::bytes(image.bytes)
|
|
.file_name(image.file_name)
|
|
.mime_str(&image.mime_type)
|
|
.map_err(|error| {
|
|
Hyper3dError::invalid_request(
|
|
Some("imageDataUrls"),
|
|
format!("构造图生模型图片字段失败:{error}"),
|
|
)
|
|
})?;
|
|
form = form.part("images", part);
|
|
}
|
|
|
|
if payload.image_data_urls.is_empty() && payload.image_urls.is_empty() {
|
|
return Err(Hyper3dError::invalid_request(
|
|
Some("imageDataUrls"),
|
|
"图生模型至少需要一张参考图",
|
|
));
|
|
}
|
|
if payload.image_data_urls.len() + payload.image_urls.len() > MAX_IMAGE_COUNT {
|
|
return Err(Hyper3dError::invalid_request(
|
|
Some("imageDataUrls"),
|
|
format!("图生模型最多支持 {} 张参考图", MAX_IMAGE_COUNT),
|
|
));
|
|
}
|
|
|
|
let response = post_hyper3d_multipart(
|
|
&http_client,
|
|
state,
|
|
"/rodin",
|
|
form,
|
|
"提交 Hyper3D 图生模型任务失败",
|
|
)
|
|
.await?;
|
|
|
|
build_submit_response(contract::Hyper3dGenerationMode::ImageToModel, response)
|
|
}
|
|
|
|
pub async fn query_task_status(
|
|
state: &Hyper3dSettings,
|
|
payload: contract::Hyper3dTaskStatusRequest,
|
|
) -> Result<contract::Hyper3dTaskStatusResponse, Hyper3dError> {
|
|
let http_client = build_hyper3d_http_client(state)?;
|
|
let subscription_key =
|
|
normalize_required_opaque_text(&payload.subscription_key, "subscriptionKey")?;
|
|
let response = post_hyper3d_json(
|
|
&http_client,
|
|
state,
|
|
"/status",
|
|
json!({ "subscription_key": subscription_key }),
|
|
"查询 Hyper3D 模型任务状态失败",
|
|
)
|
|
.await?;
|
|
|
|
let jobs = extract_job_statuses(&response);
|
|
let status = resolve_hyper3d_overall_status(&response, &jobs);
|
|
|
|
Ok(contract::Hyper3dTaskStatusResponse {
|
|
ok: true,
|
|
provider: HYPER3D_PROVIDER.to_string(),
|
|
status,
|
|
jobs,
|
|
raw: response,
|
|
})
|
|
}
|
|
|
|
pub async fn query_downloads(
|
|
state: &Hyper3dSettings,
|
|
payload: contract::Hyper3dDownloadRequest,
|
|
) -> Result<contract::Hyper3dDownloadResponse, Hyper3dError> {
|
|
let http_client = build_hyper3d_http_client(state)?;
|
|
let task_uuid = normalize_required_text(&payload.task_uuid, "taskUuid", 256)?;
|
|
let response = post_hyper3d_json(
|
|
&http_client,
|
|
state,
|
|
"/download",
|
|
json!({ "task_uuid": task_uuid }),
|
|
"获取 Hyper3D 模型下载列表失败",
|
|
)
|
|
.await?;
|
|
|
|
Ok(contract::Hyper3dDownloadResponse {
|
|
ok: true,
|
|
provider: HYPER3D_PROVIDER.to_string(),
|
|
files: extract_download_files(&response),
|
|
raw: response,
|
|
})
|
|
}
|