Files
Genarrative/server-rs/crates/platform-audio/src/client.rs
2026-05-26 13:18:13 +08:00

256 lines
8.1 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::error::Error;
use reqwest::header;
use serde_json::Value;
use crate::response::{
extract_audio_urls, extract_string_by_path, find_first_string_by_key, normalize_task_status,
};
use crate::{
AudioError, AudioTaskKind, AudioTaskResponse, BackgroundMusicTaskRequest,
SoundEffectTaskRequest, VectorEngineAudioSettings, build_background_music_task_body,
build_sound_effect_task_body,
};
pub fn build_vector_engine_audio_http_client(
settings: &VectorEngineAudioSettings,
) -> Result<reqwest::Client, AudioError> {
reqwest::Client::builder()
.timeout(std::time::Duration::from_millis(
settings.request_timeout_ms.max(1),
))
.build()
.map_err(|error| {
AudioError::invalid_config(format!(
"构造 VectorEngine 音频生成 HTTP 客户端失败:{error}"
))
})
}
/// Submits a Suno background-music generation task through VectorEngine.
///
/// The request is converted into a JSON body by [`build_background_music_task_body`] and
/// POSTed to the submit path of [`AudioTaskKind::BackgroundMusic`]. The task ID is looked up,
/// in order, as:
///
/// 1. the string at path `["data"]`;
/// 2. the first `task_id` key found in the response;
/// 3. the first `taskId` key found in the response.
///
/// Candidates that are empty or whitespace-only are skipped.
///
/// # Errors
///
/// Propagates body-building, transport and response-parsing errors. Returns an
/// [`AudioError::missing_audio`] error when no non-blank task ID can be found.
pub async fn submit_background_music_task(
    http_client: &reqwest::Client,
    settings: &VectorEngineAudioSettings,
    request: BackgroundMusicTaskRequest,
) -> Result<AudioTaskResponse, AudioError> {
    let body = build_background_music_task_body(request)?;
    let response = post_vector_engine_json(
        http_client,
        settings,
        AudioTaskKind::BackgroundMusic.submit_path(),
        body,
        "提交 Suno 背景音乐任务失败",
    )
    .await?;
    // A blank string is not a usable task ID. Previously a blank `data` value
    // short-circuited the fallbacks and was returned as the ID. Filtering each candidate
    // lets the later keys be tried, and reports a fully blank response as "missing ID".
    let non_blank = |value: &String| !value.trim().is_empty();
    let task_id = extract_string_by_path(&response, &["data"])
        .filter(non_blank)
        .or_else(|| find_first_string_by_key(&response, "task_id").filter(non_blank))
        .or_else(|| find_first_string_by_key(&response, "taskId").filter(non_blank))
        .ok_or_else(|| {
            AudioError::missing_audio("提交 Suno 背景音乐任务失败:上游未返回任务 ID")
        })?;
    Ok(AudioTaskResponse {
        kind: AudioTaskKind::BackgroundMusic,
        task_id,
        provider: AudioTaskKind::BackgroundMusic.provider().to_string(),
        status: "submitted".to_string(),
    })
}
/// Submits a Vidu sound-effect generation task through VectorEngine.
///
/// The request is converted into a JSON body by [`build_sound_effect_task_body`] and POSTed
/// to the submit path of [`AudioTaskKind::SoundEffect`]. The task ID is the first `task_id`
/// (or, failing that, `taskId`) string in the response. The reported status is the first
/// `state` string, or `"created"` when the response carries none.
///
/// # Errors
///
/// Propagates body-building, transport and response-parsing errors. Returns an
/// [`AudioError::missing_audio`] error when the response has no task ID.
pub async fn submit_sound_effect_task(
    http_client: &reqwest::Client,
    settings: &VectorEngineAudioSettings,
    request: SoundEffectTaskRequest,
) -> Result<AudioTaskResponse, AudioError> {
    let kind = AudioTaskKind::SoundEffect;
    let payload = build_sound_effect_task_body(request)?;
    let upstream = post_vector_engine_json(
        http_client,
        settings,
        kind.submit_path(),
        payload,
        "提交 Vidu 音效任务失败",
    )
    .await?;

    let Some(task_id) = find_first_string_by_key(&upstream, "task_id")
        .or_else(|| find_first_string_by_key(&upstream, "taskId"))
    else {
        return Err(AudioError::missing_audio(
            "提交 Vidu 音效任务失败:上游未返回任务 ID",
        ));
    };
    let status = match find_first_string_by_key(&upstream, "state") {
        Some(state) => state,
        None => String::from("created"),
    };

    Ok(AudioTaskResponse {
        kind,
        task_id,
        provider: kind.provider().to_string(),
        status,
    })
}
/// GETs the raw JSON task payload for `task_id` from the fetch path of `kind`.
///
/// The failure-context string attached to errors names the upstream provider for `kind`.
async fn fetch_audio_task_payload(
    http_client: &reqwest::Client,
    settings: &VectorEngineAudioSettings,
    kind: AudioTaskKind,
    task_id: &str,
) -> Result<Value, AudioError> {
    let endpoint = kind.fetch_path(task_id);
    let context = match kind {
        AudioTaskKind::BackgroundMusic => "查询 Suno 背景音乐任务失败",
        AudioTaskKind::SoundEffect => "查询 Vidu 音效任务失败",
    };
    get_vector_engine_json(http_client, settings, &endpoint, context).await
}
/// Queries an audio task and returns its normalized status together with any audio URLs.
///
/// The status is read from the first `status`, `state` or `Status` string in the task
/// payload and passed through `normalize_task_status`. An empty string is used when none
/// is present.
///
/// For background-music tasks whose payload contains no audio URLs, a second request is
/// made to `/suno/act/wav/{clip_id}`. The clip ID is the non-blank string at path
/// `["data"]`, and the URLs from that response are returned instead.
///
/// NOTE(review): the wav request is issued regardless of the task status, and any error
/// from it fails the whole call.
///
/// # Errors
///
/// Propagates transport and response-parsing errors from either request.
pub async fn resolve_audio_task_download_urls(
    http_client: &reqwest::Client,
    settings: &VectorEngineAudioSettings,
    kind: AudioTaskKind,
    task_id: &str,
) -> Result<(String, Vec<String>), AudioError> {
    let payload = fetch_audio_task_payload(http_client, settings, kind, task_id).await?;

    // Upstreams spell the status field differently; the first key that is present wins.
    let raw_status = ["status", "state", "Status"]
        .into_iter()
        .find_map(|key| find_first_string_by_key(&payload, key));
    let status = normalize_task_status(raw_status.as_deref().unwrap_or(""));

    let audio_urls = extract_audio_urls(&payload);
    if kind != AudioTaskKind::BackgroundMusic || !audio_urls.is_empty() {
        return Ok((status, audio_urls));
    }

    let Some(clip_id) =
        extract_string_by_path(&payload, &["data"]).filter(|value| !value.trim().is_empty())
    else {
        return Ok((status, audio_urls));
    };

    let wav_endpoint = format!("/suno/act/wav/{}", urlencoding::encode(clip_id.as_str()));
    let wav_payload = get_vector_engine_json(
        http_client,
        settings,
        &wav_endpoint,
        "获取 Suno wav 音频失败",
    )
    .await?;
    Ok((status, extract_audio_urls(&wav_payload)))
}
async fn get_vector_engine_json(
http_client: &reqwest::Client,
settings: &VectorEngineAudioSettings,
path: &str,
failure_context: &str,
) -> Result<Value, AudioError> {
let response = http_client
.get(format!(
"{}{}",
settings.base_url.trim_end_matches('/'),
path
))
.header(
header::AUTHORIZATION,
format!("Bearer {}", settings.api_key),
)
.header(header::ACCEPT, "application/json")
.send()
.await
.map_err(|error| map_reqwest_error(failure_context, path, error))?;
parse_vector_engine_response(response, failure_context).await
}
/// Sends an authenticated JSON POST of `body` to `{base_url}{path}` and parses the response.
///
/// Trailing `/` characters are stripped from `settings.base_url` before `path` is appended.
/// The request carries a `Bearer` token built from `settings.api_key` and sets both `Accept`
/// and `Content-Type` to `application/json`. Transport errors go through
/// `map_reqwest_error` with `path` as the endpoint; the response is handled by
/// `parse_vector_engine_response`.
async fn post_vector_engine_json(
    http_client: &reqwest::Client,
    settings: &VectorEngineAudioSettings,
    path: &str,
    body: Value,
    failure_context: &str,
) -> Result<Value, AudioError> {
    let url = format!("{}{path}", settings.base_url.trim_end_matches('/'));
    let pending = http_client
        .post(url)
        .header(header::AUTHORIZATION, format!("Bearer {}", settings.api_key))
        .header(header::ACCEPT, "application/json")
        .header(header::CONTENT_TYPE, "application/json")
        .json(&body);

    let response = pending
        .send()
        .await
        .map_err(|error| map_reqwest_error(failure_context, path, error))?;
    parse_vector_engine_response(response, failure_context).await
}
/// Reads a VectorEngine HTTP response and turns it into a JSON value or an [`AudioError`].
///
/// The body is checked in this order:
///
/// 1. The body is read as text. A read failure becomes a request error carrying the HTTP
///    status.
/// 2. A non-2xx status becomes an upstream error with `failure_context` as the message and
///    the first 800 characters of the body attached.
/// 3. The body is parsed as JSON. A parse failure becomes a response-parse error.
/// 4. If the payload has a string `code` that is not `success` / `succeeded` / `ok`
///    (case-insensitive, surrounding whitespace ignored), it becomes an upstream error. The
///    message is the payload's string `message`, falling back to `failure_context`. A
///    non-string or absent `code` is not checked.
async fn parse_vector_engine_response(
    response: reqwest::Response,
    failure_context: &str,
) -> Result<Value, AudioError> {
    let http_status = response.status();
    let body_text = match response.text().await {
        Ok(text) => text,
        Err(error) => {
            return Err(AudioError::request(
                format!("{failure_context}:读取响应失败:{error}"),
                None,
                false,
                false,
                false,
                true,
                Some(http_status.as_u16()),
                None,
            ));
        }
    };

    if !http_status.is_success() {
        return Err(AudioError::upstream(
            failure_context.to_string(),
            http_status.as_u16(),
            truncate_raw(&body_text),
        ));
    }

    let payload: Value = match serde_json::from_str(&body_text) {
        Ok(value) => value,
        Err(error) => {
            return Err(AudioError::response_parse(
                format!("{failure_context}:解析响应失败:{error}"),
                truncate_raw(&body_text),
            ));
        }
    };

    // Some upstreams report failures with HTTP 200 plus a non-success `code` string.
    let rejected_by_code = payload
        .get("code")
        .and_then(Value::as_str)
        .is_some_and(|code| {
            !matches!(
                code.trim().to_ascii_lowercase().as_str(),
                "success" | "succeeded" | "ok"
            )
        });
    if rejected_by_code {
        let message = payload
            .get("message")
            .and_then(Value::as_str)
            .unwrap_or(failure_context)
            .to_string();
        return Err(AudioError::upstream(
            message,
            http_status.as_u16(),
            truncate_raw(&body_text),
        ));
    }

    Ok(payload)
}
/// Converts a transport-level `reqwest` error into an [`AudioError::request`] error.
///
/// `failure_context` is the caller's description of the failed operation, and `endpoint` is
/// the request path that was being called. The following are forwarded unchanged from
/// `error`:
///
/// - its timeout / connect / request / body classification flags;
/// - any HTTP status it carries;
/// - the text of its underlying source error.
fn map_reqwest_error(failure_context: &str, endpoint: &str, error: reqwest::Error) -> AudioError {
    AudioError::request(
        // Separate the context from the transport error text. Without the separator the
        // message read e.g. "…任务失败error sending request…"; this matches the
        // `{context}:{detail}` shape used by the other messages in this module.
        format!("{failure_context}:{error}"),
        Some(endpoint.to_string()),
        error.is_timeout(),
        error.is_connect(),
        error.is_request(),
        error.is_body(),
        error.status().map(|status| status.as_u16()),
        Error::source(&error).map(ToString::to_string),
    )
}
/// Returns at most the first 800 characters of `raw_text` as an owned `String`.
///
/// Characters are counted as Unicode scalar values, not bytes, so the cut never splits a
/// multi-byte character.
fn truncate_raw(raw_text: &str) -> String {
    const MAX_CHARS: usize = 800;
    // Byte offset of the first character past the limit; the whole string if there is none.
    let cut = raw_text
        .char_indices()
        .nth(MAX_CHARS)
        .map_or(raw_text.len(), |(byte_index, _)| byte_index);
    raw_text[..cut].to_owned()
}