This commit is contained in:
2026-05-08 20:48:29 +08:00
parent abf1f1ebea
commit 94975e4735
82 changed files with 7786 additions and 1012 deletions

View File

@@ -145,6 +145,10 @@ use crate::{
begin_story_runtime_session, begin_story_session, continue_story,
get_story_runtime_projection, get_story_session_state, resolve_story_runtime_action,
},
vector_engine_audio_generation::{
create_visual_novel_background_music_task, create_visual_novel_sound_effect_task,
publish_visual_novel_background_music_asset, publish_visual_novel_sound_effect_asset,
},
visual_novel::{
compile_visual_novel_session, create_visual_novel_session, delete_visual_novel_work,
execute_visual_novel_action, get_visual_novel_run, get_visual_novel_session,
@@ -153,6 +157,10 @@ use crate::{
start_visual_novel_run, stream_visual_novel_action, stream_visual_novel_message,
submit_visual_novel_message, update_visual_novel_work,
},
volcengine_speech::{
get_volcengine_speech_config, stream_volcengine_asr, stream_volcengine_tts_bidirection,
stream_volcengine_tts_sse,
},
wechat_auth::{bind_wechat_phone, handle_wechat_callback, start_wechat_login},
};
@@ -312,6 +320,34 @@ pub fn build_router(state: AppState) -> Router {
require_bearer_auth,
)),
)
.route(
"/api/speech/volcengine/config",
get(get_volcengine_speech_config).route_layer(middleware::from_fn_with_state(
state.clone(),
require_bearer_auth,
)),
)
.route(
"/api/speech/volcengine/asr/stream",
get(stream_volcengine_asr).route_layer(middleware::from_fn_with_state(
state.clone(),
require_bearer_auth,
)),
)
.route(
"/api/speech/volcengine/tts/bidirection",
get(stream_volcengine_tts_bidirection).route_layer(middleware::from_fn_with_state(
state.clone(),
require_bearer_auth,
)),
)
.route(
"/api/speech/volcengine/tts/sse",
post(stream_volcengine_tts_sse).route_layer(middleware::from_fn_with_state(
state.clone(),
require_bearer_auth,
)),
)
.route(
"/api/runtime/chat/character/suggestions",
post(generate_runtime_character_chat_suggestions).route_layer(
@@ -1571,6 +1607,30 @@ fn visual_novel_router(state: AppState) -> Router<AppState> {
require_bearer_auth,
)),
)
.route(
"/api/creation/visual-novel/audio/background-music",
post(create_visual_novel_background_music_task).route_layer(
middleware::from_fn_with_state(state.clone(), require_bearer_auth),
),
)
.route(
"/api/creation/visual-novel/audio/background-music/{task_id}/asset",
post(publish_visual_novel_background_music_asset).route_layer(
middleware::from_fn_with_state(state.clone(), require_bearer_auth),
),
)
.route(
"/api/creation/visual-novel/audio/sound-effect",
post(create_visual_novel_sound_effect_task).route_layer(
middleware::from_fn_with_state(state.clone(), require_bearer_auth),
),
)
.route(
"/api/creation/visual-novel/audio/sound-effect/{task_id}/asset",
post(publish_visual_novel_sound_effect_asset).route_layer(
middleware::from_fn_with_state(state.clone(), require_bearer_auth),
),
)
.route(
"/api/runtime/visual-novel/gallery",
get(list_visual_novel_gallery),

View File

@@ -1458,7 +1458,7 @@ mod tests {
endpoint: &str,
signed_at: time::OffsetDateTime,
) -> Result<String, Box<dyn Error>> {
let date = signed_at.date().to_string().replace('-', "");
let date = format_oss_v4_signature_scope_date(signed_at);
let region = endpoint
.trim()
.split('.')
@@ -1470,17 +1470,22 @@ mod tests {
}
fn build_oss_v4_signature_date(signed_at: time::OffsetDateTime) -> String {
let date = signed_at.date().to_string().replace('-', "");
let time = signed_at
.time()
.to_string()
.split('.')
.next()
.unwrap_or("00:00:00")
.replace(':', "");
format!(
"{}T{:02}{:02}{:02}Z",
format_oss_v4_signature_scope_date(signed_at),
signed_at.hour(),
signed_at.minute(),
signed_at.second()
)
}
debug_assert_eq!(time.len(), 6);
format!("{date}T{time}Z")
/// Formats the date component of the OSS V4 signature scope as `YYYYMMDD`,
/// zero-padded, from the signing timestamp.
fn format_oss_v4_signature_scope_date(signed_at: time::OffsetDateTime) -> String {
    // `time::Month` is an enum; cast to its numeric value (1-12) for padding.
    let (year, month, day) = (signed_at.year(), signed_at.month() as u8, signed_at.day());
    format!("{year:04}{month:02}{day:02}")
}
fn build_oss_v4_canonical_uri(bucket: &str, object_key: Option<&str>) -> String {

View File

@@ -4,6 +4,11 @@ use platform_llm::{
DEFAULT_ARK_BASE_URL, DEFAULT_MAX_RETRIES, DEFAULT_REQUEST_TIMEOUT_MS,
DEFAULT_RETRY_BACKOFF_MS, LlmProvider,
};
use platform_speech::{
DEFAULT_ASR_RESOURCE_ID, DEFAULT_ASR_WS_URL,
DEFAULT_REQUEST_TIMEOUT_MS as DEFAULT_SPEECH_REQUEST_TIMEOUT_MS,
DEFAULT_TTS_BIDIRECTION_WS_URL, DEFAULT_TTS_RESOURCE_ID, DEFAULT_TTS_SSE_URL,
};
const DEFAULT_INTERNAL_API_SECRET: &str = "genarrative-dev-internal-bridge";
const DEFAULT_AUTH_STORE_PATH: &str = "server-rs/.data/auth-store.json";
@@ -92,6 +97,18 @@ pub struct AppConfig {
pub apimart_base_url: String,
pub apimart_api_key: Option<String>,
pub apimart_image_request_timeout_ms: u64,
pub vector_engine_base_url: String,
pub vector_engine_api_key: Option<String>,
pub vector_engine_audio_request_timeout_ms: u64,
pub volcengine_speech_api_key: Option<String>,
pub volcengine_speech_app_id: Option<String>,
pub volcengine_speech_access_key: Option<String>,
pub volcengine_speech_asr_resource_id: String,
pub volcengine_speech_tts_resource_id: String,
pub volcengine_speech_asr_ws_url: String,
pub volcengine_speech_tts_bidirection_ws_url: String,
pub volcengine_speech_tts_sse_url: String,
pub volcengine_speech_request_timeout_ms: u64,
pub draft_asset_generation_max_concurrent_requests: usize,
pub ark_character_video_base_url: String,
pub ark_character_video_api_key: Option<String>,
@@ -187,6 +204,18 @@ impl Default for AppConfig {
apimart_base_url: String::new(),
apimart_api_key: None,
apimart_image_request_timeout_ms: 180_000,
vector_engine_base_url: String::new(),
vector_engine_api_key: None,
vector_engine_audio_request_timeout_ms: 180_000,
volcengine_speech_api_key: None,
volcengine_speech_app_id: None,
volcengine_speech_access_key: None,
volcengine_speech_asr_resource_id: DEFAULT_ASR_RESOURCE_ID.to_string(),
volcengine_speech_tts_resource_id: DEFAULT_TTS_RESOURCE_ID.to_string(),
volcengine_speech_asr_ws_url: DEFAULT_ASR_WS_URL.to_string(),
volcengine_speech_tts_bidirection_ws_url: DEFAULT_TTS_BIDIRECTION_WS_URL.to_string(),
volcengine_speech_tts_sse_url: DEFAULT_TTS_SSE_URL.to_string(),
volcengine_speech_request_timeout_ms: DEFAULT_SPEECH_REQUEST_TIMEOUT_MS,
draft_asset_generation_max_concurrent_requests: 4,
ark_character_video_base_url: String::new(),
ark_character_video_api_key: None,
@@ -544,6 +573,54 @@ impl AppConfig {
config.apimart_image_request_timeout_ms = apimart_image_request_timeout_ms;
}
if let Some(vector_engine_base_url) = read_first_non_empty_env(&["VECTOR_ENGINE_BASE_URL"])
{
config.vector_engine_base_url = vector_engine_base_url;
}
config.vector_engine_api_key = read_first_non_empty_env(&["VECTOR_ENGINE_API_KEY"]);
if let Some(vector_engine_audio_request_timeout_ms) =
read_first_positive_u64_env(&["VECTOR_ENGINE_AUDIO_REQUEST_TIMEOUT_MS"])
{
config.vector_engine_audio_request_timeout_ms = vector_engine_audio_request_timeout_ms;
}
config.volcengine_speech_api_key =
read_first_non_empty_env(&["VOLCENGINE_SPEECH_API_KEY", "VOLCENGINE_API_KEY"]);
config.volcengine_speech_app_id =
read_first_non_empty_env(&["VOLCENGINE_SPEECH_APP_ID", "VOLCENGINE_ACCESS_KEY_ID"]);
config.volcengine_speech_access_key = read_first_non_empty_env(&[
"VOLCENGINE_SPEECH_ACCESS_KEY",
"VOLCENGINE_SECRET_ACCESS_KEY",
]);
if let Some(asr_resource_id) =
read_first_non_empty_env(&["VOLCENGINE_SPEECH_ASR_RESOURCE_ID"])
{
config.volcengine_speech_asr_resource_id = asr_resource_id;
}
if let Some(tts_resource_id) =
read_first_non_empty_env(&["VOLCENGINE_SPEECH_TTS_RESOURCE_ID"])
{
config.volcengine_speech_tts_resource_id = tts_resource_id;
}
if let Some(asr_ws_url) = read_first_non_empty_env(&["VOLCENGINE_SPEECH_ASR_WS_URL"]) {
config.volcengine_speech_asr_ws_url = asr_ws_url;
}
if let Some(tts_bidirection_ws_url) =
read_first_non_empty_env(&["VOLCENGINE_SPEECH_TTS_BIDIRECTION_WS_URL"])
{
config.volcengine_speech_tts_bidirection_ws_url = tts_bidirection_ws_url;
}
if let Some(tts_sse_url) = read_first_non_empty_env(&["VOLCENGINE_SPEECH_TTS_SSE_URL"]) {
config.volcengine_speech_tts_sse_url = tts_sse_url;
}
if let Some(request_timeout_ms) =
read_first_positive_u64_env(&["VOLCENGINE_SPEECH_REQUEST_TIMEOUT_MS"])
{
config.volcengine_speech_request_timeout_ms = request_timeout_ms;
}
if let Some(max_concurrent_requests) = read_first_usize_env(&[
"GENARRATIVE_DRAFT_ASSET_GENERATION_MAX_CONCURRENT_REQUESTS",
"DRAFT_ASSET_GENERATION_MAX_CONCURRENT_REQUESTS",
@@ -831,6 +908,7 @@ mod tests {
assert!(config.llm_model.is_empty());
assert!(config.llm_base_url.is_empty());
assert!(config.apimart_base_url.is_empty());
assert!(config.vector_engine_base_url.is_empty());
assert!(config.ark_character_video_base_url.is_empty());
assert!(config.ark_character_video_model.is_empty());
assert!(config.dashscope_scene_image_model.is_empty());
@@ -859,6 +937,7 @@ mod tests {
std::env::remove_var("GENARRATIVE_LLM_BASE_URL");
std::env::remove_var("GENARRATIVE_LLM_MODEL");
std::env::remove_var("APIMART_BASE_URL");
std::env::remove_var("VECTOR_ENGINE_BASE_URL");
std::env::remove_var("DASHSCOPE_SCENE_IMAGE_MODEL");
std::env::remove_var("DASHSCOPE_REFERENCE_IMAGE_MODEL");
std::env::remove_var("DASHSCOPE_COVER_IMAGE_MODEL");
@@ -871,6 +950,7 @@ mod tests {
);
std::env::set_var("GENARRATIVE_LLM_MODEL", "internal-text-model");
std::env::set_var("APIMART_BASE_URL", "https://image.internal.example/v1");
std::env::set_var("VECTOR_ENGINE_BASE_URL", "https://audio.internal.example");
std::env::set_var("DASHSCOPE_SCENE_IMAGE_MODEL", "scene-model");
std::env::set_var("DASHSCOPE_REFERENCE_IMAGE_MODEL", "reference-model");
std::env::set_var("DASHSCOPE_COVER_IMAGE_MODEL", "cover-model");
@@ -886,6 +966,10 @@ mod tests {
assert_eq!(config.llm_base_url, "https://llm.internal.example/v1");
assert_eq!(config.llm_model, "internal-text-model");
assert_eq!(config.apimart_base_url, "https://image.internal.example/v1");
assert_eq!(
config.vector_engine_base_url,
"https://audio.internal.example"
);
assert_eq!(config.dashscope_scene_image_model, "scene-model");
assert_eq!(config.dashscope_reference_image_model, "reference-model");
assert_eq!(config.dashscope_cover_image_model, "cover-model");
@@ -900,6 +984,7 @@ mod tests {
std::env::remove_var("GENARRATIVE_LLM_BASE_URL");
std::env::remove_var("GENARRATIVE_LLM_MODEL");
std::env::remove_var("APIMART_BASE_URL");
std::env::remove_var("VECTOR_ENGINE_BASE_URL");
std::env::remove_var("DASHSCOPE_SCENE_IMAGE_MODEL");
std::env::remove_var("DASHSCOPE_REFERENCE_IMAGE_MODEL");
std::env::remove_var("DASHSCOPE_COVER_IMAGE_MODEL");

View File

@@ -1,2 +1,3 @@
// Model routed for RPG story generation.
pub(crate) const RPG_STORY_LLM_MODEL: &str = "doubao-seed-character-251128";
// Model routed for creation-template text generation.
pub(crate) const CREATION_TEMPLATE_LLM_MODEL: &str = "deepseek-v3-2-251201";
// Vision-capable model used to name the first puzzle level from its generated image.
pub(crate) const PUZZLE_LEVEL_NAME_VISION_LLM_MODEL: &str = "gpt-4o-mini";

View File

@@ -68,7 +68,9 @@ mod square_hole_agent_turn;
mod state;
mod story_battles;
mod story_sessions;
mod vector_engine_audio_generation;
mod visual_novel;
mod volcengine_speech;
mod wechat_auth;
mod wechat_provider;
mod work_author;

View File

@@ -177,6 +177,7 @@ pub(crate) fn build_openai_image_request_body(
Value::String(build_prompt_with_negative(prompt, negative_prompt)),
),
("n".to_string(), json!(candidate_count.clamp(1, 4))),
("official_fallback".to_string(), Value::Bool(true)),
(
"size".to_string(),
Value::String(normalize_image_size(size)),
@@ -613,6 +614,7 @@ mod tests {
assert_eq!(body["model"], GPT_IMAGE_2_MODEL);
assert_eq!(body["size"], "16:9");
assert_eq!(body["n"], 2);
assert_eq!(body["official_fallback"], true);
assert_eq!(body["image_urls"][0], "data:image/png;base64,abcd");
assert!(body["prompt"].as_str().unwrap_or_default().contains("避免"));
}

View File

@@ -3,7 +3,7 @@
/// 模型只负责把画面描述压缩成可直接展示的中文关卡名;写回草稿和作品卡由业务路由处理。
pub(crate) const PUZZLE_FIRST_LEVEL_NAME_SYSTEM_PROMPT: &str = r#"你是一个中文拼图关卡命名编辑。
你会收到拼图第一关的画面描述。请生成 1 个适合直接展示在游戏关卡卡片上的中文关卡名。
你会收到拼图第一关的画面描述,部分请求还会附带已经生成完成的正式图片。请综合图片内容和画面描述,生成 1 个适合直接展示在游戏关卡卡片上的中文关卡名。
硬约束:
1. 只输出 JSON不要输出 Markdown、解释或代码块。
@@ -21,6 +21,13 @@ pub(crate) fn build_puzzle_first_level_name_user_prompt(picture_description: &st
)
}
/// Builds the user-turn text for the vision-based first-level naming prompt.
///
/// The picture description is trimmed before being embedded so stray
/// surrounding whitespace does not leak into the prompt.
pub(crate) fn build_puzzle_first_level_name_vision_user_text(picture_description: &str) -> String {
    let description = picture_description.trim();
    format!("画面描述:{description}\n\n请观察随消息附带的正式拼图图片,生成第一关关卡名。")
}
#[cfg(test)]
mod tests {
use super::*;
@@ -32,4 +39,12 @@ mod tests {
assert!(prompt.contains("画面描述:一只猫在雨夜灯牌下回头。"));
assert!(prompt.contains("第一关关卡名"));
}
#[test]
fn level_name_vision_prompt_mentions_generated_image() {
    // The vision user text must keep the textual scene description and
    // explicitly point the model at the attached generated puzzle image.
    let prompt = build_puzzle_first_level_name_vision_user_text("一只猫在雨夜灯牌下回头。");
    assert!(prompt.contains("画面描述:一只猫在雨夜灯牌下回头。"));
    assert!(prompt.contains("正式拼图图片"));
}
}

View File

@@ -13,12 +13,13 @@ use axum::{
},
};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use image::ImageFormat;
use module_assets::{
AssetObjectAccessPolicy, AssetObjectFieldError, build_asset_entity_binding_input,
build_asset_object_upsert_input, generate_asset_binding_id, generate_asset_object_id,
};
use module_puzzle::{PuzzleGeneratedImageCandidate, PuzzleRuntimeLevelStatus};
use platform_llm::{LlmMessage, LlmTextRequest};
use platform_llm::{LlmMessage, LlmMessageContentPart, LlmTextRequest};
use platform_oss::{
LegacyAssetPrefix, OssHeadObjectRequest, OssObjectAccess, OssPutObjectRequest,
OssSignedGetObjectUrlRequest,
@@ -78,7 +79,7 @@ use crate::{
},
auth::AuthenticatedAccessToken,
http_error::AppError,
llm_model_routing::CREATION_TEMPLATE_LLM_MODEL,
llm_model_routing::{CREATION_TEMPLATE_LLM_MODEL, PUZZLE_LEVEL_NAME_VISION_LLM_MODEL},
platform_errors::map_oss_error,
prompt::puzzle::{
draft::{
@@ -88,6 +89,7 @@ use crate::{
image::{PUZZLE_DEFAULT_NEGATIVE_PROMPT, build_puzzle_image_prompt},
level_name::{
PUZZLE_FIRST_LEVEL_NAME_SYSTEM_PROMPT, build_puzzle_first_level_name_user_prompt,
build_puzzle_first_level_name_vision_user_text,
},
tags::{PUZZLE_TAG_GENERATION_SYSTEM_PROMPT, build_puzzle_tag_generation_user_prompt},
},
@@ -112,6 +114,7 @@ const PUZZLE_ENTITY_KIND: &str = "puzzle_work";
const PUZZLE_GENERATED_IMAGE_SIZE: &str = "1024*1024";
const PUZZLE_APIMART_GENERATED_IMAGE_SIZE: &str = "1:1";
const PUZZLE_APIMART_GEMINI_RESOLUTION: &str = "1K";
const PUZZLE_LEVEL_NAME_VISION_IMAGE_MAX_SIDE: u32 = 768;
pub async fn create_puzzle_agent_session(
State(state): State<AppState>,
@@ -204,7 +207,8 @@ pub async fn generate_puzzle_onboarding_work(
PUZZLE_AGENT_API_BASE_PROVIDER,
map_puzzle_generation_endpoint_error(error),
)
})?;
})?
.into_records();
let selected = candidates.first().cloned().ok_or_else(|| {
puzzle_error_response(
&request_context,
@@ -864,8 +868,9 @@ pub async fn execute_puzzle_agent_action(
if let Some(levels_json) = levels_json.as_ref() {
draft.levels = parse_puzzle_level_records_from_module_json(levels_json)?;
}
let target_level =
let mut target_level =
select_puzzle_level_for_api(&draft, target_level_id.as_deref())?;
let fallback_level_name = target_level.level_name.clone();
let prompt = resolve_puzzle_level_image_prompt(
payload.prompt_text.as_deref(),
&target_level.picture_description,
@@ -886,10 +891,32 @@ pub async fn execute_puzzle_agent_action(
)
.await
.map_err(map_puzzle_generation_endpoint_error)?;
if candidates.is_empty() {
return Err(AppError::from_status(StatusCode::BAD_GATEWAY).with_details(
json!({
"provider": PUZZLE_AGENT_API_BASE_PROVIDER,
"message": "拼图候选图生成结果为空",
}),
));
}
if let Some(refined_level_name) = generate_puzzle_first_level_name_from_image(
&state,
target_level.picture_description.as_str(),
&candidates[0].downloaded_image,
)
.await
{
target_level.level_name = refined_level_name;
}
let generated_level_name = target_level.level_name.clone();
let levels_json_with_generated_name =
Some(serialize_puzzle_level_records_for_module(
&build_puzzle_levels_with_primary_name(&draft, &target_level),
)?);
let candidates_json = serde_json::to_string(
&candidates
.iter()
.map(to_puzzle_generated_image_candidate)
.map(|candidate| to_puzzle_generated_image_candidate(&candidate.record))
.collect::<Vec<_>>(),
)
.map_err(|error| {
@@ -904,7 +931,7 @@ pub async fn execute_puzzle_agent_action(
session_id: session.session_id.clone(),
owner_user_id: owner_user_id.clone(),
level_id: Some(target_level.level_id.clone()),
levels_json,
levels_json: levels_json_with_generated_name,
candidates_json,
saved_at_micros: now,
})
@@ -925,9 +952,15 @@ pub async fn execute_puzzle_agent_action(
let fallback_session =
replace_puzzle_session_draft_snapshot(session, draft, now);
Ok(apply_generated_puzzle_candidates_to_session_snapshot(
fallback_session,
apply_generated_puzzle_first_level_name_to_session_snapshot(
fallback_session,
target_level.level_id.as_str(),
generated_level_name.as_str(),
fallback_level_name.as_str(),
now,
),
target_level.level_id.as_str(),
candidates,
candidates.into_records(),
now,
))
}
@@ -2830,6 +2863,91 @@ async fn generate_puzzle_first_level_name(state: &AppState, picture_description:
build_fallback_puzzle_first_level_name(picture_description)
}
/// Refines the first puzzle level's name by showing the generated image to a
/// vision LLM.
///
/// Returns `Some(name)` only when the whole pipeline succeeds; every failure
/// path returns `None` so the caller keeps the text-derived fallback name:
/// - no vision-capable LLM client is configured,
/// - the image could not be encoded as a data URL,
/// - the LLM request fails, or
/// - the model output cannot be parsed into a level name.
async fn generate_puzzle_first_level_name_from_image(
    state: &AppState,
    picture_description: &str,
    image: &PuzzleDownloadedImage,
) -> Option<String> {
    // Vision naming is best-effort: silently skip when no client is configured.
    let Some(llm_client) = state.creative_agent_gpt5_client() else {
        return None;
    };
    let Some(image_data_url) = build_puzzle_level_name_image_data_url(image) else {
        tracing::warn!(
            provider = PUZZLE_AGENT_API_BASE_PROVIDER,
            picture_chars = picture_description.chars().count(),
            "拼图首关名图片输入压缩失败,保留文本关卡名"
        );
        return None;
    };
    let user_text = build_puzzle_first_level_name_vision_user_text(picture_description);
    // Multimodal request: system naming rules + user turn of text and image.
    let response = llm_client
        .request_text(
            LlmTextRequest::new(vec![
                LlmMessage::system(PUZZLE_FIRST_LEVEL_NAME_SYSTEM_PROMPT),
                LlmMessage::user_multimodal(vec![
                    LlmMessageContentPart::InputText { text: user_text },
                    LlmMessageContentPart::InputImage {
                        image_url: image_data_url,
                    },
                ]),
            ])
            .with_model(PUZZLE_LEVEL_NAME_VISION_LLM_MODEL)
            // Level names are short; a small token budget is sufficient.
            .with_max_tokens(80),
        )
        .await;
    match response {
        Ok(response) => {
            // Parse failure logs a warning and falls back, same as a request error.
            parse_puzzle_first_level_name_from_text(response.content.as_str()).or_else(|| {
                tracing::warn!(
                    provider = PUZZLE_AGENT_API_BASE_PROVIDER,
                    model = PUZZLE_LEVEL_NAME_VISION_LLM_MODEL,
                    picture_chars = picture_description.chars().count(),
                    "拼图首关名视觉模型返回非法,保留文本关卡名"
                );
                None
            })
        }
        Err(error) => {
            tracing::warn!(
                provider = PUZZLE_AGENT_API_BASE_PROVIDER,
                model = PUZZLE_LEVEL_NAME_VISION_LLM_MODEL,
                picture_chars = picture_description.chars().count(),
                error = %error,
                "拼图首关名视觉生成失败,保留文本关卡名"
            );
            None
        }
    }
}
/// Encodes a downloaded puzzle image as a base64 `data:` URL suitable for a
/// vision LLM input, downscaling it first when the bytes decode successfully.
///
/// Returns `Some` in all cases today; the `Option` is kept so callers can
/// treat encoding as fallible.
fn build_puzzle_level_name_image_data_url(image: &PuzzleDownloadedImage) -> Option<String> {
    // Resizing is best-effort: fall back to the original bytes on failure.
    let payload = match resize_puzzle_level_name_image_bytes(image.bytes.as_slice()) {
        Some(resized) => resized,
        None => image.bytes.clone(),
    };
    // A successful resize re-encodes as PNG; detect that via the PNG magic bytes.
    let mime_type = if payload.starts_with(b"\x89PNG\r\n\x1A\n") {
        "image/png"
    } else {
        image.mime_type.as_str()
    };
    let encoded = BASE64_STANDARD.encode(&payload);
    let normalized_mime = normalize_puzzle_downloaded_image_mime_type(mime_type);
    Some(format!("data:{normalized_mime};base64,{encoded}"))
}
/// Decodes image bytes, scales them to fit within the vision input bound
/// (aspect ratio preserved), and re-encodes as PNG.
///
/// Returns `None` when the bytes cannot be decoded or the PNG re-encode fails.
fn resize_puzzle_level_name_image_bytes(bytes: &[u8]) -> Option<Vec<u8>> {
    let decoded = image::load_from_memory(bytes).ok()?;
    let max_side = PUZZLE_LEVEL_NAME_VISION_IMAGE_MAX_SIDE;
    // Triangle filtering is a cheap, adequate choice for LLM vision input.
    let scaled = decoded.resize(max_side, max_side, image::imageops::FilterType::Triangle);
    let mut buffer = std::io::Cursor::new(Vec::new());
    scaled.write_to(&mut buffer, ImageFormat::Png).ok()?;
    Some(buffer.into_inner())
}
fn parse_puzzle_first_level_name_from_text(text: &str) -> Option<String> {
let trimmed = text.trim();
let json_text = if let Some(start) = trimmed.find('{')
@@ -2985,9 +3103,6 @@ async fn compile_puzzle_draft_with_initial_cover(
let generated_level_name =
generate_puzzle_first_level_name(state, &target_level.picture_description).await;
target_level.level_name = generated_level_name.clone();
let levels_json_with_generated_name = Some(serialize_puzzle_level_records_for_module(
&build_puzzle_levels_with_primary_name(&draft, &target_level),
)?);
let image_prompt = resolve_puzzle_draft_cover_prompt(
prompt_text,
&target_level.picture_description,
@@ -3008,19 +3123,32 @@ async fn compile_puzzle_draft_with_initial_cover(
.await?;
let selected_candidate_id = candidates
.iter()
.find(|candidate| candidate.selected)
.find(|candidate| candidate.record.selected)
.or_else(|| candidates.first())
.map(|candidate| candidate.candidate_id.clone())
.map(|candidate| candidate.record.candidate_id.clone())
.ok_or_else(|| {
AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
"provider": PUZZLE_AGENT_API_BASE_PROVIDER,
"message": "拼图候选图生成结果为空",
}))
})?;
if let Some(refined_level_name) = generate_puzzle_first_level_name_from_image(
state,
target_level.picture_description.as_str(),
&candidates[0].downloaded_image,
)
.await
{
target_level.level_name = refined_level_name;
}
let generated_level_name = target_level.level_name.clone();
let levels_json_with_generated_name = Some(serialize_puzzle_level_records_for_module(
&build_puzzle_levels_with_primary_name(&draft, &target_level),
)?);
let candidates_json = serde_json::to_string(
&candidates
.iter()
.map(to_puzzle_generated_image_candidate)
.map(|candidate| to_puzzle_generated_image_candidate(&candidate.record))
.collect::<Vec<_>>(),
)
.map_err(|error| {
@@ -3061,7 +3189,7 @@ async fn compile_puzzle_draft_with_initial_cover(
now,
),
target_level.level_id.as_str(),
candidates.clone(),
candidates.into_records(),
now,
);
Ok((session, true))
@@ -3138,9 +3266,6 @@ async fn compile_puzzle_draft_with_uploaded_cover(
let generated_level_name =
generate_puzzle_first_level_name(state, &target_level.picture_description).await;
target_level.level_name = generated_level_name.clone();
let levels_json_with_generated_name = Some(serialize_puzzle_level_records_for_module(
&build_puzzle_levels_with_primary_name(&draft, &target_level),
)?);
let image_prompt = resolve_puzzle_draft_cover_prompt(
prompt_text,
&target_level.picture_description,
@@ -3152,6 +3277,24 @@ async fn compile_puzzle_draft_with_uploaded_cover(
compiled_session.session_id,
target_level.candidates.len() + 1
);
let uploaded_downloaded_image = PuzzleDownloadedImage {
extension: puzzle_mime_to_extension(uploaded_image.mime_type.as_str()).to_string(),
mime_type: normalize_puzzle_downloaded_image_mime_type(uploaded_image.mime_type.as_str()),
bytes: uploaded_image.bytes,
};
if let Some(refined_level_name) = generate_puzzle_first_level_name_from_image(
state,
target_level.picture_description.as_str(),
&uploaded_downloaded_image,
)
.await
{
target_level.level_name = refined_level_name;
}
let generated_level_name = target_level.level_name.clone();
let levels_json_with_generated_name = Some(serialize_puzzle_level_records_for_module(
&build_puzzle_levels_with_primary_name(&draft, &target_level),
)?);
let persisted_upload = persist_puzzle_generated_asset(
state,
owner_user_id.as_str(),
@@ -3159,13 +3302,7 @@ async fn compile_puzzle_draft_with_uploaded_cover(
&target_level.level_name,
candidate_id.as_str(),
"uploaded-direct",
PuzzleDownloadedImage {
extension: puzzle_mime_to_extension(uploaded_image.mime_type.as_str()).to_string(),
mime_type: normalize_puzzle_downloaded_image_mime_type(
uploaded_image.mime_type.as_str(),
),
bytes: uploaded_image.bytes,
},
uploaded_downloaded_image,
current_utc_micros(),
)
.await?;
@@ -3865,7 +4002,7 @@ async fn generate_puzzle_image_candidates(
image_model: Option<&str>,
candidate_count: u32,
candidate_start_index: usize,
) -> Result<Vec<PuzzleGeneratedImageCandidateRecord>, AppError> {
) -> Result<Vec<GeneratedPuzzleImageCandidate>, AppError> {
let count = candidate_count.clamp(1, 1);
let resolved_model = resolve_puzzle_image_model(image_model);
let actual_prompt = build_puzzle_image_prompt(level_name, prompt);
@@ -3914,6 +4051,7 @@ async fn generate_puzzle_image_candidates(
"{session_id}-candidate-{}",
candidate_start_index + index + 1
);
let downloaded_image = image.clone();
let asset = persist_puzzle_generated_asset(
state,
owner_user_id,
@@ -3926,30 +4064,22 @@ async fn generate_puzzle_image_candidates(
)
.await
.map_err(map_puzzle_generation_endpoint_error)?;
items.push(PuzzleGeneratedImageCandidateResponse {
candidate_id,
image_src: asset.image_src,
asset_id: asset.asset_id,
prompt: prompt.to_string(),
actual_prompt: Some(actual_prompt.clone()),
source_type: resolved_model.candidate_source_type().to_string(),
// 单图生成结果总是直接成为当前正式图。
selected: index == 0,
items.push(GeneratedPuzzleImageCandidate {
record: PuzzleGeneratedImageCandidateRecord {
candidate_id,
image_src: asset.image_src,
asset_id: asset.asset_id,
prompt: prompt.to_string(),
actual_prompt: Some(actual_prompt.clone()),
source_type: resolved_model.candidate_source_type().to_string(),
// 单图生成结果总是直接成为当前正式图。
selected: index == 0,
},
downloaded_image,
});
}
Ok(items
.into_iter()
.map(|candidate| PuzzleGeneratedImageCandidateRecord {
candidate_id: candidate.candidate_id,
image_src: candidate.image_src,
asset_id: candidate.asset_id,
prompt: candidate.prompt,
actual_prompt: candidate.actual_prompt,
source_type: candidate.source_type,
selected: candidate.selected,
})
.collect())
Ok(items)
}
#[cfg(test)]
@@ -3977,6 +4107,7 @@ mod tests {
assert_eq!(body["size"], PUZZLE_APIMART_GENERATED_IMAGE_SIZE);
assert_eq!(body["resolution"], PUZZLE_APIMART_GEMINI_RESOLUTION);
assert_eq!(body["n"], 1);
assert_eq!(body["official_fallback"], true);
assert_eq!(body["image_urls"][0], "data:image/png;base64,abcd");
assert!(
body["prompt"]
@@ -4014,6 +4145,7 @@ mod tests {
prompt_text: None,
reference_image_src: None,
image_model: Some(PUZZLE_IMAGE_MODEL_GPT_IMAGE_2.to_string()),
ai_redraw: None,
candidate_count: Some(1),
candidate_id: None,
level_id: Some("puzzle-level-1".to_string()),
@@ -4073,6 +4205,26 @@ mod tests {
);
}
#[test]
fn puzzle_level_name_image_data_url_downsizes_generated_image() {
    // Encode a tiny in-memory PNG, feed it through the data-URL builder, and
    // verify the result is a non-empty base64 PNG data URL.
    let image = image::DynamicImage::ImageRgb8(image::RgbImage::new(4, 4));
    let mut cursor = std::io::Cursor::new(Vec::new());
    image
        .write_to(&mut cursor, ImageFormat::Png)
        .expect("test image should encode");
    let downloaded = PuzzleDownloadedImage {
        extension: "png".to_string(),
        mime_type: "image/png".to_string(),
        bytes: cursor.into_inner(),
    };
    let data_url = build_puzzle_level_name_image_data_url(&downloaded)
        .expect("data url should be generated");
    assert!(data_url.starts_with("data:image/png;base64,"));
    // Non-empty payload after the data-URL header.
    assert!(data_url.len() > "data:image/png;base64,".len());
}
#[test]
fn puzzle_first_level_name_snapshot_defaults_work_title() {
let levels_json = serde_json::to_string(&vec![json!({
@@ -4091,6 +4243,7 @@ mod tests {
prompt_text: None,
reference_image_src: None,
image_model: Some(PUZZLE_IMAGE_MODEL_GPT_IMAGE_2.to_string()),
ai_redraw: None,
candidate_count: Some(1),
candidate_id: None,
level_id: Some("puzzle-level-1".to_string()),
@@ -4181,6 +4334,30 @@ struct PuzzleGeneratedImages {
images: Vec<PuzzleDownloadedImage>,
}
/// A freshly generated puzzle image candidate: the persistable record plus the
/// downloaded image bytes, kept together so the bytes can be reused (e.g. as
/// vision-LLM input) without re-downloading.
struct GeneratedPuzzleImageCandidate {
    record: PuzzleGeneratedImageCandidateRecord,
    downloaded_image: PuzzleDownloadedImage,
}

impl GeneratedPuzzleImageCandidate {
    /// Drops the image bytes and keeps only the persistable record.
    fn into_record(self) -> PuzzleGeneratedImageCandidateRecord {
        self.record
    }
}

/// Convenience conversion from a list of candidates to a list of records.
trait GeneratedPuzzleImageCandidatesExt {
    fn into_records(self) -> Vec<PuzzleGeneratedImageCandidateRecord>;
}

impl GeneratedPuzzleImageCandidatesExt for Vec<GeneratedPuzzleImageCandidate> {
    fn into_records(self) -> Vec<PuzzleGeneratedImageCandidateRecord> {
        self.into_iter()
            .map(GeneratedPuzzleImageCandidate::into_record)
            .collect()
    }
}
#[derive(Clone)]
struct PuzzleDownloadedImage {
extension: String,
mime_type: String,
@@ -4361,6 +4538,7 @@ fn build_puzzle_apimart_image_request_body(
Value::String(build_puzzle_apimart_prompt(prompt, negative_prompt)),
),
("n".to_string(), json!(candidate_count.clamp(1, 1))),
("official_fallback".to_string(), Value::Bool(true)),
("size".to_string(), Value::String(size.to_string())),
]);
body.insert(

View File

@@ -787,7 +787,8 @@ fn build_creative_agent_gpt5_client(
config.apimart_image_request_timeout_ms,
0,
config.llm_retry_backoff_ms,
)?;
)?
.with_official_fallback(true);
Ok(Some(LlmClient::new(llm_config)?))
}
@@ -888,5 +889,6 @@ mod tests {
client.config().responses_url(),
"https://api.apimart.test/v1/responses"
);
assert!(client.config().official_fallback());
}
}

View File

@@ -0,0 +1,973 @@
use std::{collections::BTreeMap, time::Duration};
use axum::{
Json,
extract::{Path, State, rejection::JsonRejection},
http::StatusCode,
response::Response,
};
use module_assets::{
AssetObjectAccessPolicy, build_asset_entity_binding_input, build_asset_object_upsert_input,
generate_asset_binding_id, generate_asset_object_id,
};
use platform_oss::{LegacyAssetPrefix, OssObjectAccess, OssPutObjectRequest};
use reqwest::header;
use serde_json::{Map, Value, json};
use shared_contracts::visual_novel as contract;
use crate::{
api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
platform_errors::map_oss_error, request_context::RequestContext, state::AppState,
};
// Provider labels reported in responses and error payloads.
const VECTOR_ENGINE_PROVIDER: &str = "vector-engine";
const VECTOR_ENGINE_SUNO_PROVIDER: &str = "vector-engine-suno";
const VECTOR_ENGINE_VIDU_PROVIDER: &str = "vector-engine-vidu";
// Default upstream models: Suno for background music, Vidu for sound effects.
const SUNO_DEFAULT_MODEL: &str = "chirp-v4";
const VIDU_AUDIO_MODEL: &str = "audio1.0";
// Entity/asset identifiers used when binding generated audio to a scene.
const AUDIO_ENTITY_KIND: &str = "visual_novel_scene";
const MUSIC_ASSET_KIND: &str = "visual_novel_music";
const AMBIENT_SOUND_ASSET_KIND: &str = "visual_novel_ambient_sound";
const MUSIC_SLOT: &str = "music";
const AMBIENT_SOUND_SLOT: &str = "ambient_sound";
// Input length limits enforced before submitting to the upstream providers.
const SUNO_PROMPT_MAX_CHARS: usize = 5_000;
const SUNO_TITLE_MAX_CHARS: usize = 80;
const SUNO_TAGS_MAX_CHARS: usize = 160;
const VIDU_PROMPT_MAX_CHARS: usize = 1_500;
// Default sound-effect duration in seconds (request may override, clamped 2..=10).
const DEFAULT_SOUND_EFFECT_DURATION_SECONDS: u8 = 5;
// Upper bound (40 MiB) for generated audio payloads.
const MAX_GENERATED_AUDIO_BYTES: usize = 40 * 1024 * 1024;
/// Connection settings for the vector-engine audio gateway, produced by
/// `require_vector_engine_audio_settings` from application config.
#[derive(Clone, Debug)]
struct VectorEngineAudioSettings {
    base_url: String,
    api_key: String,
    request_timeout_ms: u64,
}
/// Raw audio payload plus the MIME type and file extension needed to
/// persist it as an asset object.
#[derive(Clone, Debug)]
struct DownloadedAudio {
    bytes: Vec<u8>,
    mime_type: String,
    extension: String,
}
/// Which audio slot a generation task targets. The slot determines the
/// provider label, asset kind, binding slot name, and file naming.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum AudioAssetSlot {
    BackgroundMusic,
    SoundEffect,
}

impl AudioAssetSlot {
    /// Maps to the shared contract enum used in API payloads.
    fn contract_kind(self) -> contract::VisualNovelAudioGenerationKind {
        match self {
            Self::BackgroundMusic => contract::VisualNovelAudioGenerationKind::BackgroundMusic,
            Self::SoundEffect => contract::VisualNovelAudioGenerationKind::SoundEffect,
        }
    }

    /// Provider label: Suno handles music, Vidu handles sound effects.
    fn provider(self) -> &'static str {
        match self {
            Self::BackgroundMusic => VECTOR_ENGINE_SUNO_PROVIDER,
            Self::SoundEffect => VECTOR_ENGINE_VIDU_PROVIDER,
        }
    }

    /// Asset-kind string for this slot.
    fn asset_kind(self) -> &'static str {
        match self {
            Self::BackgroundMusic => MUSIC_ASSET_KIND,
            Self::SoundEffect => AMBIENT_SOUND_ASSET_KIND,
        }
    }

    /// Entity-binding slot name for this audio kind.
    fn slot(self) -> &'static str {
        match self {
            Self::BackgroundMusic => MUSIC_SLOT,
            Self::SoundEffect => AMBIENT_SOUND_SLOT,
        }
    }

    /// File-name stem used when persisting the audio object.
    fn file_stem(self) -> &'static str {
        match self {
            Self::BackgroundMusic => "background-music",
            Self::SoundEffect => "sound-effect",
        }
    }
}
/// POST handler: submits a Suno background-music generation task through the
/// vector-engine gateway and returns the upstream task id.
///
/// Validates and trims `prompt`, `title`, and optional `tags` against the
/// Suno character limits, defaults the model to `SUNO_DEFAULT_MODEL`, and
/// surfaces a gateway error when the upstream response carries no task id.
pub async fn create_visual_novel_background_music_task(
    State(state): State<AppState>,
    axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
    payload: Result<Json<contract::CreateVisualNovelBackgroundMusicRequest>, JsonRejection>,
) -> Result<Json<Value>, Response> {
    let Json(payload) = parse_json_payload(&request_context, payload)?;
    // Fails fast when the vector-engine audio gateway is not configured.
    let settings = require_vector_engine_audio_settings(&state)?;
    let http_client = build_vector_engine_audio_http_client(&settings)?;
    let prompt = normalize_limited_text(&payload.prompt, "prompt", SUNO_PROMPT_MAX_CHARS)?;
    let title = normalize_limited_text(&payload.title, "title", SUNO_TITLE_MAX_CHARS)?;
    let tags = payload
        .tags
        .as_deref()
        .map(|value| normalize_limited_text(value, "tags", SUNO_TAGS_MAX_CHARS))
        .transpose()?;
    let model = normalize_optional_text(payload.model.as_deref())
        .unwrap_or_else(|| SUNO_DEFAULT_MODEL.to_string());
    // "mv" is Suno's model-version field; "task: generate" selects plain generation.
    let mut body = Map::from_iter([
        ("prompt".to_string(), Value::String(prompt)),
        ("mv".to_string(), Value::String(model)),
        ("title".to_string(), Value::String(title)),
        ("task".to_string(), Value::String("generate".to_string())),
    ]);
    if let Some(tags) = tags {
        body.insert("tags".to_string(), Value::String(tags));
    }
    let response = post_vector_engine_json(
        &http_client,
        &settings,
        "/suno/submit/music",
        Value::Object(body),
        "提交 Suno 背景音乐任务失败",
    )
    .await?;
    // The upstream may return the id as a top-level "data" string or nested
    // under "task_id"/"taskId"; try each shape in turn.
    let task_id = extract_string_by_path(&response, &["data"])
        .or_else(|| find_first_string_by_key(&response, "task_id"))
        .or_else(|| find_first_string_by_key(&response, "taskId"))
        .ok_or_else(|| {
            vector_engine_bad_gateway("提交 Suno 背景音乐任务失败:上游未返回任务 ID")
        })?;
    Ok(json_success_body(
        Some(&request_context),
        contract::VisualNovelAudioGenerationTaskResponse {
            kind: contract::VisualNovelAudioGenerationKind::BackgroundMusic,
            task_id,
            provider: VECTOR_ENGINE_SUNO_PROVIDER.to_string(),
            status: "submitted".to_string(),
        },
    ))
}
/// POST handler: submits a Vidu text-to-audio sound-effect task upstream.
///
/// Normalizes the prompt, clamps the requested duration to 2–10 seconds
/// (presumably the provider-supported range — confirm against Vidu docs),
/// posts to `/ent/v2/text2audio`, and returns the upstream task id/status.
pub async fn create_visual_novel_sound_effect_task(
    State(state): State<AppState>,
    axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
    payload: Result<Json<contract::CreateVisualNovelSoundEffectRequest>, JsonRejection>,
) -> Result<Json<Value>, Response> {
    let Json(payload) = parse_json_payload(&request_context, payload)?;
    let settings = require_vector_engine_audio_settings(&state)?;
    let http_client = build_vector_engine_audio_http_client(&settings)?;
    let prompt = normalize_limited_text(&payload.prompt, "prompt", VIDU_PROMPT_MAX_CHARS)?;
    let duration = payload
        .duration
        .unwrap_or(DEFAULT_SOUND_EFFECT_DURATION_SECONDS)
        .clamp(2, 10);
    let mut body = Map::from_iter([
        (
            "model".to_string(),
            Value::String(VIDU_AUDIO_MODEL.to_string()),
        ),
        ("prompt".to_string(), Value::String(prompt)),
        ("duration".to_string(), json!(duration)),
    ]);
    // Seed is optional; forwarded verbatim for reproducible generations.
    if let Some(seed) = payload.seed {
        body.insert("seed".to_string(), json!(seed));
    }
    let response = post_vector_engine_json(
        &http_client,
        &settings,
        "/ent/v2/text2audio",
        Value::Object(body),
        "提交 Vidu 音效任务失败",
    )
    .await?;
    let task_id = find_first_string_by_key(&response, "task_id")
        .or_else(|| find_first_string_by_key(&response, "taskId"))
        .ok_or_else(|| vector_engine_bad_gateway("提交 Vidu 音效任务失败:上游未返回任务 ID"))?;
    // Vidu reports progress under "state"; default to "created" when absent.
    let status = find_first_string_by_key(&response, "state").unwrap_or_else(|| "created".into());
    Ok(json_success_body(
        Some(&request_context),
        contract::VisualNovelAudioGenerationTaskResponse {
            kind: contract::VisualNovelAudioGenerationKind::SoundEffect,
            task_id,
            provider: VECTOR_ENGINE_VIDU_PROVIDER.to_string(),
            status,
        },
    ))
}
pub async fn publish_visual_novel_background_music_asset(
State(state): State<AppState>,
Path(task_id): Path<String>,
axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
payload: Result<Json<contract::PublishVisualNovelGeneratedAudioAssetRequest>, JsonRejection>,
) -> Result<Json<Value>, Response> {
publish_generated_audio_asset(
&state,
&request_context,
authenticated.claims().user_id(),
task_id,
parse_json_payload(&request_context, payload)?.0,
AudioAssetSlot::BackgroundMusic,
)
.await
.map(|payload| json_success_body(Some(&request_context), payload))
.map_err(|error| error.into_response_with_context(Some(&request_context)))
}
pub async fn publish_visual_novel_sound_effect_asset(
State(state): State<AppState>,
Path(task_id): Path<String>,
axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
payload: Result<Json<contract::PublishVisualNovelGeneratedAudioAssetRequest>, JsonRejection>,
) -> Result<Json<Value>, Response> {
publish_generated_audio_asset(
&state,
&request_context,
authenticated.claims().user_id(),
task_id,
parse_json_payload(&request_context, payload)?.0,
AudioAssetSlot::SoundEffect,
)
.await
.map(|payload| json_success_body(Some(&request_context), payload))
.map_err(|error| error.into_response_with_context(Some(&request_context)))
}
/// Shared publish flow: polls the upstream task, and when a downloadable audio
/// URL exists, downloads it and persists it as a bound scene asset.
///
/// Returns a "pending" response (no asset fields) while the task is still
/// running, errors for failed tasks, and a "completed" response with asset
/// ids once the audio has been persisted.
async fn publish_generated_audio_asset(
    state: &AppState,
    _request_context: &RequestContext,
    owner_user_id: &str,
    task_id: String,
    payload: contract::PublishVisualNovelGeneratedAudioAssetRequest,
    slot: AudioAssetSlot,
) -> Result<contract::VisualNovelGeneratedAudioAssetResponse, AppError> {
    let task_id = normalize_limited_text(&task_id, "taskId", 160)?;
    let scene_id = normalize_limited_text(&payload.scene_id, "sceneId", 160)?;
    let profile_id = normalize_optional_text(payload.profile_id.as_deref());
    let settings = require_vector_engine_audio_settings(state)?;
    let http_client = build_vector_engine_audio_http_client(&settings)?;
    let task_payload = fetch_audio_task_payload(&http_client, &settings, slot, &task_id).await?;
    // Upstream spells the status field differently per provider; try each.
    let status = normalize_task_status(
        find_first_string_by_key(&task_payload, "status")
            .or_else(|| find_first_string_by_key(&task_payload, "state"))
            .or_else(|| find_first_string_by_key(&task_payload, "Status"))
            .as_deref()
            .unwrap_or(""),
    );
    let mut audio_urls = extract_audio_urls(&task_payload);
    // Suno fallback: when the fetch payload only carries a clip id in "data",
    // a second call to /suno/act/wav/{clip_id} yields the wav download URL.
    if slot == AudioAssetSlot::BackgroundMusic && audio_urls.is_empty() {
        if let Some(clip_id) = extract_string_by_path(&task_payload, &["data"])
            .filter(|value| !value.trim().is_empty())
        {
            let wav_payload = get_vector_engine_json(
                &http_client,
                &settings,
                &format!("/suno/act/wav/{}", encode_path_segment(clip_id.as_str())),
                "获取 Suno wav 音频失败",
            )
            .await?;
            audio_urls = extract_audio_urls(&wav_payload);
        }
    }
    // Still running with nothing to download: report the raw status so the
    // client can keep polling.
    if is_pending_task_status(&status) && audio_urls.is_empty() {
        return Ok(contract::VisualNovelGeneratedAudioAssetResponse {
            kind: slot.contract_kind(),
            task_id,
            provider: slot.provider().to_string(),
            status,
            asset_object_id: None,
            asset_kind: None,
            audio_src: None,
        });
    }
    if is_failed_task_status(&status) {
        return Err(vector_engine_bad_gateway(
            "音频生成任务失败,请调整提示词后重试",
        ));
    }
    // Only the first discovered URL is used; extract_audio_urls preserves
    // first-occurrence order.
    let audio_url = audio_urls
        .into_iter()
        .next()
        .ok_or_else(|| vector_engine_bad_gateway("音频生成尚未返回可下载地址"))?;
    let audio = download_generated_audio(&http_client, &audio_url, slot.provider()).await?;
    let persisted = persist_generated_audio_asset(
        state,
        &http_client,
        owner_user_id,
        profile_id,
        scene_id,
        &task_id,
        slot,
        audio,
    )
    .await?;
    Ok(contract::VisualNovelGeneratedAudioAssetResponse {
        kind: slot.contract_kind(),
        task_id,
        provider: slot.provider().to_string(),
        status: "completed".to_string(),
        asset_object_id: Some(persisted.asset_object_id),
        asset_kind: Some(slot.asset_kind().to_string()),
        audio_src: Some(persisted.audio_src),
    })
}
/// Fetches the raw upstream task payload for `task_id` from the provider
/// that matches the slot (Suno for music, Vidu for sound effects).
async fn fetch_audio_task_payload(
    http_client: &reqwest::Client,
    settings: &VectorEngineAudioSettings,
    slot: AudioAssetSlot,
    task_id: &str,
) -> Result<Value, AppError> {
    let encoded_task_id = encode_path_segment(task_id);
    let (path, failure_context) = match slot {
        AudioAssetSlot::BackgroundMusic => (
            format!("/suno/fetch/{encoded_task_id}"),
            "查询 Suno 背景音乐任务失败",
        ),
        AudioAssetSlot::SoundEffect => (
            format!("/ent/v2/tasks/{encoded_task_id}/creations"),
            "查询 Vidu 音效任务失败",
        ),
    };
    get_vector_engine_json(http_client, settings, &path, failure_context).await
}
/// Outcome of uploading a generated audio file to OSS and registering it.
#[derive(Clone, Debug)]
struct PersistedAudioAsset {
    // Id of the confirmed asset-object record (from SpacetimeDB).
    asset_object_id: String,
    // Legacy public path returned by OSS, usable as the audio source URL.
    audio_src: String,
}
/// Persists downloaded audio: uploads it to OSS, verifies via HEAD, records
/// an asset object, and binds that object to the scene entity.
///
/// The four steps must run in this order — the HEAD result feeds the asset
/// record, and the binding references the confirmed asset id.
async fn persist_generated_audio_asset(
    state: &AppState,
    http_client: &reqwest::Client,
    owner_user_id: &str,
    profile_id: Option<String>,
    scene_id: String,
    task_id: &str,
    slot: AudioAssetSlot,
    audio: DownloadedAudio,
) -> Result<PersistedAudioAsset, AppError> {
    // 503 when OSS credentials are not configured for this deployment.
    let oss_client = state.oss_client().ok_or_else(|| {
        AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
            "provider": "aliyun-oss",
            "reason": "OSS 未完成环境变量配置",
        }))
    })?;
    let file_name = format!("{}-{}.{}", slot.file_stem(), task_id, audio.extension);
    // Step 1: upload the audio bytes under
    // .../visual-novel/{profile|draft}/{scene}/{slot}/{file}.
    let put_result = oss_client
        .put_object(
            http_client,
            OssPutObjectRequest {
                prefix: LegacyAssetPrefix::CustomWorldScenes,
                path_segments: vec![
                    "visual-novel".to_string(),
                    profile_id.clone().unwrap_or_else(|| "draft".to_string()),
                    scene_id.clone(),
                    slot.slot().to_string(),
                ],
                file_name,
                content_type: Some(audio.mime_type.clone()),
                access: OssObjectAccess::Private,
                metadata: build_audio_asset_metadata(
                    owner_user_id,
                    profile_id.as_deref(),
                    &scene_id,
                    slot,
                ),
                body: audio.bytes,
            },
        )
        .await
        .map_err(|error| map_oss_error(error, "aliyun-oss"))?;
    // Step 2: HEAD the uploaded object to obtain authoritative size/etag/
    // content-type for the asset record.
    let head = oss_client
        .head_object(
            http_client,
            platform_oss::OssHeadObjectRequest {
                object_key: put_result.object_key.clone(),
            },
        )
        .await
        .map_err(|error| map_oss_error(error, "aliyun-oss"))?;
    let now_micros = current_utc_micros();
    // Step 3: confirm the asset object in SpacetimeDB; falls back to the
    // downloaded mime type when HEAD returns no content type.
    let asset_object = state
        .spacetime_client()
        .confirm_asset_object(
            build_asset_object_upsert_input(
                generate_asset_object_id(now_micros),
                head.bucket,
                head.object_key,
                AssetObjectAccessPolicy::Private,
                head.content_type.or(Some(audio.mime_type)),
                head.content_length,
                head.etag,
                slot.asset_kind().to_string(),
                Some(task_id.to_string()),
                Some(owner_user_id.to_string()),
                profile_id.clone(),
                Some(scene_id.clone()),
                now_micros,
            )
            .map_err(map_asset_field_error)?,
        )
        .await
        .map_err(map_spacetime_error)?;
    // Step 4: bind the confirmed asset to the scene entity under the slot.
    state
        .spacetime_client()
        .bind_asset_object_to_entity(
            build_asset_entity_binding_input(
                generate_asset_binding_id(now_micros),
                asset_object.asset_object_id.clone(),
                AUDIO_ENTITY_KIND.to_string(),
                scene_id,
                slot.slot().to_string(),
                slot.asset_kind().to_string(),
                Some(owner_user_id.to_string()),
                profile_id,
                now_micros,
            )
            .map_err(map_asset_field_error)?,
        )
        .await
        .map_err(map_spacetime_error)?;
    Ok(PersistedAudioAsset {
        asset_object_id: asset_object.asset_object_id,
        audio_src: put_result.legacy_public_path,
    })
}
/// Builds the OSS object metadata recorded alongside a generated audio file.
/// The profile id is only included when one was provided.
fn build_audio_asset_metadata(
    owner_user_id: &str,
    profile_id: Option<&str>,
    scene_id: &str,
    slot: AudioAssetSlot,
) -> BTreeMap<String, String> {
    let mut metadata = BTreeMap::new();
    metadata.insert("asset-kind".to_string(), slot.asset_kind().to_string());
    metadata.insert("owner-user-id".to_string(), owner_user_id.to_string());
    metadata.insert("entity-kind".to_string(), AUDIO_ENTITY_KIND.to_string());
    metadata.insert("entity-id".to_string(), scene_id.to_string());
    metadata.insert("slot".to_string(), slot.slot().to_string());
    metadata.insert("provider".to_string(), slot.provider().to_string());
    if let Some(profile_id) = profile_id {
        metadata.insert("profile-id".to_string(), profile_id.to_string());
    }
    metadata
}
/// Reads VectorEngine audio settings from app config, returning 503 with a
/// provider-tagged reason when the base URL or API key is missing.
fn require_vector_engine_audio_settings(
    state: &AppState,
) -> Result<VectorEngineAudioSettings, AppError> {
    // Trailing slashes are stripped so path concatenation stays well-formed.
    let base_url = state
        .config
        .vector_engine_base_url
        .trim()
        .trim_end_matches('/');
    if base_url.is_empty() {
        return Err(
            AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
                "provider": VECTOR_ENGINE_PROVIDER,
                "reason": "VECTOR_ENGINE_BASE_URL 未配置",
            })),
        );
    }
    // Whitespace-only keys are treated the same as absent keys.
    let api_key = state
        .config
        .vector_engine_api_key
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .ok_or_else(|| {
            AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
                "provider": VECTOR_ENGINE_PROVIDER,
                "reason": "VECTOR_ENGINE_API_KEY 未配置",
            }))
        })?;
    Ok(VectorEngineAudioSettings {
        base_url: base_url.to_string(),
        api_key: api_key.to_string(),
        // Guard against a zero timeout, which reqwest would treat literally.
        request_timeout_ms: state.config.vector_engine_audio_request_timeout_ms.max(1),
    })
}
fn build_vector_engine_audio_http_client(
settings: &VectorEngineAudioSettings,
) -> Result<reqwest::Client, AppError> {
reqwest::Client::builder()
.timeout(Duration::from_millis(settings.request_timeout_ms))
.build()
.map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({
"provider": VECTOR_ENGINE_PROVIDER,
"message": format!("构造 VectorEngine 音频生成 HTTP 客户端失败:{error}"),
}))
})
}
/// POSTs a JSON body to the VectorEngine gateway at `path` with bearer auth,
/// mapping both transport failures and upstream error payloads to 502.
async fn post_vector_engine_json(
    http_client: &reqwest::Client,
    settings: &VectorEngineAudioSettings,
    path: &str,
    body: Value,
    failure_context: &str,
) -> Result<Value, AppError> {
    let response = http_client
        .post(format!("{}{}", settings.base_url, path))
        .header(
            header::AUTHORIZATION,
            format!("Bearer {}", settings.api_key),
        )
        .header(header::ACCEPT, "application/json")
        .header(header::CONTENT_TYPE, "application/json")
        .json(&body)
        .send()
        .await
        .map_err(|error| vector_engine_bad_gateway(format!("{failure_context}{error}")))?;
    // Shared response handling: status check, JSON parse, and "code" check.
    parse_vector_engine_response(response, failure_context).await
}
/// GETs JSON from the VectorEngine gateway at `path` with bearer auth,
/// mapping both transport failures and upstream error payloads to 502.
async fn get_vector_engine_json(
    http_client: &reqwest::Client,
    settings: &VectorEngineAudioSettings,
    path: &str,
    failure_context: &str,
) -> Result<Value, AppError> {
    let response = http_client
        .get(format!("{}{}", settings.base_url, path))
        .header(
            header::AUTHORIZATION,
            format!("Bearer {}", settings.api_key),
        )
        .header(header::ACCEPT, "application/json")
        .send()
        .await
        .map_err(|error| vector_engine_bad_gateway(format!("{failure_context}{error}")))?;
    // Shared response handling: status check, JSON parse, and "code" check.
    parse_vector_engine_response(response, failure_context).await
}
/// Validates a VectorEngine HTTP response: non-2xx statuses and unparsable
/// bodies become 502 with a truncated raw excerpt; a JSON body carrying a
/// string "code" other than success/succeeded/ok is also treated as failure.
async fn parse_vector_engine_response(
    response: reqwest::Response,
    failure_context: &str,
) -> Result<Value, AppError> {
    let status = response.status();
    // Read the body even on error statuses so the excerpt can be reported.
    let raw_text = response.text().await.map_err(|error| {
        vector_engine_bad_gateway(format!("{failure_context}:读取响应失败:{error}"))
    })?;
    if !status.is_success() {
        return Err(
            AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
                "provider": VECTOR_ENGINE_PROVIDER,
                "message": failure_context,
                "status": status.as_u16(),
                "rawExcerpt": truncate_raw(raw_text.as_str()),
            })),
        );
    }
    let payload = serde_json::from_str::<Value>(&raw_text).map_err(|error| {
        AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
            "provider": VECTOR_ENGINE_PROVIDER,
            "message": format!("{failure_context}:解析响应失败:{error}"),
            "rawExcerpt": truncate_raw(raw_text.as_str()),
        }))
    })?;
    // A missing "code" or a non-string code is treated as success; only an
    // explicit non-success code string is rejected (case-insensitive).
    if let Some(code) = payload.get("code").and_then(Value::as_str)
        && !matches!(
            code.trim().to_ascii_lowercase().as_str(),
            "success" | "succeeded" | "ok"
        )
    {
        return Err(
            AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
                "provider": VECTOR_ENGINE_PROVIDER,
                // Prefer the upstream message; fall back to our context label.
                "message": payload
                    .get("message")
                    .and_then(Value::as_str)
                    .unwrap_or(failure_context),
                "code": code,
            })),
        );
    }
    Ok(payload)
}
/// Downloads a generated audio file from `audio_url` and normalizes its
/// mime type / extension for persistence.
///
/// FIX: the original buffered the entire response body *before* checking the
/// HTTP status, so upstream error pages were downloaded (and counted toward
/// the size limit) for nothing. The status is now checked first, so a failed
/// download never reads the body.
async fn download_generated_audio(
    http_client: &reqwest::Client,
    audio_url: &str,
    provider: &str,
) -> Result<DownloadedAudio, AppError> {
    let response = http_client
        .get(audio_url)
        .send()
        .await
        .map_err(|error| vector_engine_bad_gateway(format!("下载生成音频失败:{error}")))?;
    let status = response.status();
    // Fail before touching the body so we never buffer an error payload.
    if !status.is_success() {
        return Err(
            AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
                "provider": provider,
                "message": "下载生成音频失败",
                "status": status.as_u16(),
            })),
        );
    }
    // Default to audio/mpeg when the server sends no usable content type.
    let content_type = response
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("audio/mpeg")
        .to_string();
    let body = response
        .bytes()
        .await
        .map_err(|error| vector_engine_bad_gateway(format!("读取生成音频内容失败:{error}")))?;
    // Reject empty payloads and anything over the configured size cap.
    if body.is_empty() || body.len() > MAX_GENERATED_AUDIO_BYTES {
        return Err(
            AppError::from_status(StatusCode::BAD_GATEWAY).with_details(json!({
                "provider": provider,
                "message": "生成音频内容为空或超过大小上限",
            })),
        );
    }
    let mime_type = normalize_audio_mime_type(&content_type, audio_url);
    Ok(DownloadedAudio {
        extension: audio_mime_to_extension(&mime_type).to_string(),
        mime_type,
        bytes: body.to_vec(),
    })
}
fn extract_audio_urls(payload: &Value) -> Vec<String> {
let mut urls = Vec::new();
collect_audio_url_strings(payload, &mut urls);
let mut deduped = Vec::new();
for url in urls {
if !deduped.contains(&url) {
deduped.push(url);
}
}
deduped
}
/// Recursively walks a JSON value, pushing candidate audio URLs into `output`.
///
/// Two paths can push: an object entry whose key looks audio-related and
/// whose string value is an HTTP(S) URL, and any bare string that is an
/// HTTP(S) URL with a known audio extension. Duplicates are possible here
/// and are removed by the caller.
fn collect_audio_url_strings(value: &Value, output: &mut Vec<String>) {
    match value {
        Value::Object(object) => {
            for (key, child) in object {
                if looks_like_audio_url_key(key) {
                    if let Some(raw) = child.as_str() {
                        if looks_like_http_url(raw) {
                            output.push(raw.trim().to_string());
                        }
                    }
                }
                // Always recurse so nested structures are still scanned.
                collect_audio_url_strings(child, output);
            }
        }
        Value::Array(items) => {
            items
                .iter()
                .for_each(|item| collect_audio_url_strings(item, output));
        }
        Value::String(raw) => {
            if looks_like_http_url(raw) && looks_like_audio_url(raw) {
                output.push(raw.trim().to_string());
            }
        }
        _ => {}
    }
}
/// Heuristic: does this JSON key name suggest its value is an audio URL?
/// Matches audio-related substrings or any key ending in "url" (which also
/// covers "_url" and the bare key "url").
fn looks_like_audio_url_key(key: &str) -> bool {
    const AUDIO_HINTS: [&str; 4] = ["audio", "wav", "mp3", "fileurl"];
    let normalized = key.trim().to_ascii_lowercase();
    AUDIO_HINTS.iter().any(|hint| normalized.contains(hint)) || normalized.ends_with("url")
}
/// True when the trimmed value starts with an http:// or https:// scheme
/// (case-insensitive).
fn looks_like_http_url(value: &str) -> bool {
    let lowered = value.trim().to_ascii_lowercase();
    ["http://", "https://"]
        .iter()
        .any(|scheme| lowered.starts_with(scheme))
}
/// True when the URL path (query string stripped) ends with a known audio
/// file extension, case-insensitively.
fn looks_like_audio_url(value: &str) -> bool {
    const AUDIO_EXTENSIONS: [&str; 7] = [".mp3", ".wav", ".m4a", ".aac", ".ogg", ".webm", ".flac"];
    let base = value
        .trim()
        .split('?')
        .next()
        .unwrap_or_default()
        .to_ascii_lowercase();
    AUDIO_EXTENSIONS.iter().any(|ext| base.ends_with(ext))
}
/// Canonicalizes a Content-Type header into one of the supported audio mime
/// types; unknown or non-audio values fall back to inferring from the URL.
fn normalize_audio_mime_type(content_type: &str, audio_url: &str) -> String {
    // Drop any ";charset=..." parameter before matching.
    let declared = content_type.split(';').next().unwrap_or("").trim();
    let canonical = match declared {
        "audio/mpeg" | "audio/mp3" => Some("audio/mpeg"),
        "audio/wav" | "audio/wave" | "audio/x-wav" => Some("audio/wav"),
        "audio/ogg" => Some("audio/ogg"),
        "audio/webm" => Some("audio/webm"),
        "audio/aac" => Some("audio/aac"),
        "audio/flac" => Some("audio/flac"),
        "audio/mp4" | "audio/x-m4a" => Some("audio/mp4"),
        _ => None,
    };
    match canonical {
        Some(mime) => mime.to_string(),
        None => mime_type_from_audio_url(audio_url),
    }
}
/// Infers an audio mime type from the URL's file extension (query string
/// ignored); defaults to audio/mpeg when nothing matches.
fn mime_type_from_audio_url(audio_url: &str) -> String {
    const EXTENSION_MIME_TABLE: [(&str, &str); 6] = [
        (".wav", "audio/wav"),
        (".ogg", "audio/ogg"),
        (".webm", "audio/webm"),
        (".aac", "audio/aac"),
        (".flac", "audio/flac"),
        (".m4a", "audio/mp4"),
    ];
    let path = audio_url
        .split('?')
        .next()
        .unwrap_or_default()
        .to_ascii_lowercase();
    EXTENSION_MIME_TABLE
        .iter()
        .find(|(extension, _)| path.ends_with(extension))
        .map(|(_, mime)| (*mime).to_string())
        .unwrap_or_else(|| "audio/mpeg".to_string())
}
/// Maps a canonical audio mime type to the file extension used when naming
/// the uploaded asset; unknown types default to "mp3".
fn audio_mime_to_extension(mime_type: &str) -> &'static str {
    const MIME_EXTENSION_TABLE: [(&str, &str); 6] = [
        ("audio/wav", "wav"),
        ("audio/ogg", "ogg"),
        ("audio/webm", "webm"),
        ("audio/aac", "aac"),
        ("audio/flac", "flac"),
        ("audio/mp4", "m4a"),
    ];
    MIME_EXTENSION_TABLE
        .iter()
        .find(|(mime, _)| *mime == mime_type)
        .map(|(_, extension)| *extension)
        .unwrap_or("mp3")
}
/// Trims `value` and enforces that it is non-empty and at most `max_chars`
/// characters; violations become 400 responses naming the offending field.
fn normalize_limited_text(
    value: &str,
    field: &'static str,
    max_chars: usize,
) -> Result<String, AppError> {
    let trimmed = value.trim();
    if trimmed.is_empty() {
        return Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                "provider": VECTOR_ENGINE_PROVIDER,
                "field": field,
                "message": format!("{field} 不能为空"),
            })),
        );
    }
    // Character count (not byte length) so CJK text is limited fairly.
    if trimmed.chars().count() > max_chars {
        return Err(
            AppError::from_status(StatusCode::BAD_REQUEST).with_details(json!({
                "provider": VECTOR_ENGINE_PROVIDER,
                "field": field,
                "message": format!("{field} 超过 {} 字符", max_chars),
            })),
        );
    }
    Ok(trimmed.to_string())
}
/// Trims an optional string, collapsing whitespace-only values to `None`.
fn normalize_optional_text(value: Option<&str>) -> Option<String> {
    match value {
        Some(raw) => {
            let trimmed = raw.trim();
            if trimmed.is_empty() {
                None
            } else {
                Some(trimmed.to_string())
            }
        }
        None => None,
    }
}
/// Canonicalizes an upstream task status: lowercased with spaces replaced by
/// underscores; success synonyms collapse to "completed", and an empty
/// status is reported as "processing".
fn normalize_task_status(status: &str) -> String {
    const COMPLETED_ALIASES: [&str; 6] = [
        "finish",
        "finished",
        "complete",
        "completed",
        "success",
        "succeeded",
    ];
    let canonical = status.trim().to_ascii_lowercase().replace(' ', "_");
    if canonical.is_empty() {
        "processing".to_string()
    } else if COMPLETED_ALIASES.contains(&canonical.as_str()) {
        "completed".to_string()
    } else {
        canonical
    }
}
/// True for normalized statuses that mean the task is still in flight.
fn is_pending_task_status(status: &str) -> bool {
    const PENDING_STATUSES: [&str; 7] = [
        "created",
        "pending",
        "queued",
        "processing",
        "running",
        "submitted",
        "started",
    ];
    PENDING_STATUSES.contains(&status)
}
/// True for normalized statuses that mean the task terminally failed.
fn is_failed_task_status(status: &str) -> bool {
    const FAILED_STATUSES: [&str; 6] = [
        "failed",
        "error",
        "canceled",
        "cancelled",
        "rejected",
        "expired",
    ];
    FAILED_STATUSES.contains(&status)
}
/// Depth-first search for the first string value stored under a key that
/// matches `target_key` case-insensitively; the result is trimmed.
fn find_first_string_by_key(value: &Value, target_key: &str) -> Option<String> {
    match value {
        Value::Object(object) => object.iter().find_map(|(key, child)| {
            if key.eq_ignore_ascii_case(target_key) {
                if let Some(text) = child.as_str() {
                    return Some(text.trim().to_string());
                }
            }
            // Matched-but-non-string values still get searched recursively.
            find_first_string_by_key(child, target_key)
        }),
        Value::Array(items) => items
            .iter()
            .find_map(|item| find_first_string_by_key(item, target_key)),
        _ => None,
    }
}
/// Follows `path` key-by-key into a JSON object tree and returns the trimmed
/// string at the end, or `None` if any step is missing or not a string.
fn extract_string_by_path(value: &Value, path: &[&str]) -> Option<String> {
    let target = path
        .iter()
        .try_fold(value, |node, key| node.get(*key))?;
    target.as_str().map(|text| text.trim().to_owned())
}
/// Percent-encodes a value so it can be embedded as a URL path segment.
fn encode_path_segment(value: &str) -> String {
    let encoded = urlencoding::encode(value);
    encoded.into_owned()
}
/// Keeps at most 800 characters of upstream text for error excerpts,
/// cutting on a character boundary so multi-byte text stays valid.
fn truncate_raw(raw_text: &str) -> String {
    match raw_text.char_indices().nth(800) {
        Some((cut_at, _)) => raw_text[..cut_at].to_string(),
        None => raw_text.to_string(),
    }
}
/// Current wall-clock time (UTC) expressed as UNIX microseconds.
fn current_utc_micros() -> i64 {
    let now = time::OffsetDateTime::now_utc();
    shared_kernel::offset_datetime_to_unix_micros(now)
}
/// Converts an asset-object field-validation error into a 400 response.
fn map_asset_field_error(error: module_assets::AssetObjectFieldError) -> AppError {
    let details = json!({
        "provider": "asset-object",
        "message": error.to_string(),
    });
    AppError::from_status(StatusCode::BAD_REQUEST).with_details(details)
}
/// Converts a SpacetimeDB client failure into a 502 response.
fn map_spacetime_error(error: spacetime_client::SpacetimeClientError) -> AppError {
    let details = json!({
        "provider": "spacetimedb",
        "message": error.to_string(),
    });
    AppError::from_status(StatusCode::BAD_GATEWAY).with_details(details)
}
/// Builds a 502 error tagged with the VectorEngine provider and `message`.
fn vector_engine_bad_gateway(message: impl Into<String>) -> AppError {
    let details = json!({
        "provider": VECTOR_ENGINE_PROVIDER,
        "message": message.into(),
    });
    AppError::from_status(StatusCode::BAD_GATEWAY).with_details(details)
}
fn parse_json_payload<T>(
request_context: &RequestContext,
payload: Result<Json<T>, JsonRejection>,
) -> Result<Json<T>, Response> {
payload.map_err(|rejection| {
AppError::from_status(StatusCode::BAD_REQUEST)
.with_message(format!("请求体 JSON 不合法:{rejection}"))
.into_response_with_context(Some(request_context))
})
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Content-type canonicalization plus URL-extension fallback.
    #[test]
    fn normalizes_audio_mime_type_from_content_type_and_url() {
        assert_eq!(
            normalize_audio_mime_type("audio/x-wav; charset=utf-8", "https://x/a.bin"),
            "audio/wav"
        );
        assert_eq!(
            normalize_audio_mime_type("application/octet-stream", "https://x/a.m4a?token=1"),
            "audio/mp4"
        );
        assert_eq!(audio_mime_to_extension("audio/mp4"), "m4a");
    }
    /// Audio URLs are found recursively via key heuristics.
    #[test]
    fn extracts_nested_audio_urls() {
        let payload = json!({
            "Response": {
                "Status": "FINISH",
                "Task": {
                    "Output": {
                        "FileInfos": [
                            { "FileUrl": "https://cdn.example.test/audio.wav" }
                        ]
                    }
                }
            }
        });
        assert_eq!(
            extract_audio_urls(&payload),
            vec!["https://cdn.example.test/audio.wav".to_string()]
        );
    }
    /// Status normalization and the pending/failed classifiers stay in sync.
    #[test]
    fn vector_engine_task_status_is_stable() {
        assert_eq!(normalize_task_status("FINISH"), "completed");
        assert!(is_pending_task_status("processing"));
        assert!(is_failed_task_status("failed"));
    }
    /// Over-long prompts are rejected with 400.
    #[test]
    fn validates_prompt_length() {
        // FIX: the original used "".repeat(..), which always yields an empty
        // string, so this test exercised the empty-input branch instead of
        // the length cap. Repeat a non-empty character to exceed the limit.
        let prompt = "p".repeat(VIDU_PROMPT_MAX_CHARS + 1);
        let error = normalize_limited_text(&prompt, "prompt", VIDU_PROMPT_MAX_CHARS)
            .expect_err("long prompt should fail");
        assert_eq!(error.status_code(), StatusCode::BAD_REQUEST);
    }
}

View File

@@ -1532,7 +1532,10 @@ mod tests {
let summary = resolve_document_summary_for_prompt(&record, None)
.expect("document session should build summary");
assert_eq!(summary.chars().count(), VISUAL_NOVEL_DOCUMENT_SUMMARY_MAX_CHARS);
assert_eq!(
summary.chars().count(),
VISUAL_NOVEL_DOCUMENT_SUMMARY_MAX_CHARS
);
assert!(summary.contains("旧书店"));
}
@@ -1598,7 +1601,8 @@ async fn create_or_update_creation_draft(
latest_user_text: Option<String>,
) -> Result<contract::VisualNovelResultDraft, Response> {
let now_iso = current_utc_iso();
let document_summary = resolve_document_summary_for_prompt(session, latest_user_text.as_deref());
let document_summary =
resolve_document_summary_for_prompt(session, latest_user_text.as_deref());
if let Some(llm_client) = state.llm_client() {
let current_draft = session.draft.as_ref();
let recent_messages = session
@@ -1682,7 +1686,12 @@ fn resolve_document_summary_for_prompt(
(!seed_text.is_empty()).then_some(seed_text)
})?;
Some(source.chars().take(VISUAL_NOVEL_DOCUMENT_SUMMARY_MAX_CHARS).collect())
Some(
source
.chars()
.take(VISUAL_NOVEL_DOCUMENT_SUMMARY_MAX_CHARS)
.collect(),
)
}
async fn compile_visual_novel_session_inner(

View File

@@ -0,0 +1,552 @@
use axum::{
Json,
body::Body,
extract::{
State,
ws::{Message as ClientWsMessage, WebSocket, WebSocketUpgrade},
},
http::{HeaderValue, StatusCode, header},
response::{IntoResponse, Response},
};
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use platform_speech::{
AsrAudioConfig, AsrFrameKind, PublicSpeechConfig, PublicSpeechEndpoints, SpeechError,
TtsAudioParams, TtsBidirectionClientEvent, TtsSseRequest, VolcengineSpeechClient,
VolcengineSpeechConfig, build_asr_frame, build_asr_full_client_request,
build_tts_bidirection_frame_from_client_event, default_asr_request_payload,
parse_asr_response_frame, parse_tts_response_frame, tts_response_to_client_value,
};
use serde_json::{Value, json};
use tokio_tungstenite::tungstenite::Message as UpstreamWsMessage;
use tracing::{info, warn};
use crate::{
api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
request_context::RequestContext, state::AppState,
};
/// Provider tag attached to error details and client-facing messages.
const PROVIDER: &str = "volcengine-speech";
/// GET handler: returns the public (client-safe) Volcengine speech config.
pub async fn get_volcengine_speech_config(
    State(state): State<AppState>,
    axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
) -> Json<Value> {
    let config = public_speech_config(&state);
    json_success_body(Some(&request_context), config)
}
pub async fn stream_volcengine_asr(
State(state): State<AppState>,
axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
ws: WebSocketUpgrade,
) -> Result<Response, Response> {
let client = build_speech_client(&state)
.map_err(|error| map_speech_error(error).into_response_with_context(None))?;
let user_id = authenticated.claims().user_id().to_string();
Ok(ws.on_upgrade(move |socket| proxy_asr_websocket(socket, client, user_id)))
}
pub async fn stream_volcengine_tts_bidirection(
State(state): State<AppState>,
ws: WebSocketUpgrade,
) -> Result<Response, Response> {
let client = build_speech_client(&state)
.map_err(|error| map_speech_error(error).into_response_with_context(None))?;
Ok(ws.on_upgrade(move |socket| proxy_tts_bidirection_websocket(socket, client)))
}
/// POST handler: forwards a TTS request upstream and pipes the provider's
/// SSE byte stream straight back to the client.
///
/// Upstream non-2xx responses become 502 with a raw excerpt; on success the
/// body is streamed without buffering and the upstream log id is surfaced in
/// an `x-volcengine-logid` header.
pub async fn stream_volcengine_tts_sse(
    State(state): State<AppState>,
    axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
    axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
    payload: Result<Json<TtsSseRequest>, axum::extract::rejection::JsonRejection>,
) -> Result<Response, Response> {
    let Json(payload) = payload.map_err(|rejection| {
        AppError::from_status(StatusCode::BAD_REQUEST)
            .with_message(format!("请求体 JSON 不合法:{rejection}"))
            .into_response_with_context(Some(&request_context))
    })?;
    let client = build_speech_client(&state).map_err(|error| {
        map_speech_error(error).into_response_with_context(Some(&request_context))
    })?;
    // The speech client owns URL/headers/body/timeout; the user id is passed
    // through for upstream attribution.
    let upstream_request = client
        .build_tts_sse_upstream_request(payload, authenticated.claims().user_id())
        .map_err(|error| {
            map_speech_error(error).into_response_with_context(Some(&request_context))
        })?;
    let http_client = reqwest::Client::builder()
        .timeout(upstream_request.timeout)
        .build()
        .map_err(|error| {
            AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
                .with_details(json!({
                    "provider": PROVIDER,
                    "message": format!("构造火山语音 HTTP 客户端失败:{error}"),
                }))
                .into_response_with_context(Some(&request_context))
        })?;
    let upstream_response = http_client
        .post(upstream_request.url)
        .headers(upstream_request.headers)
        .json(&upstream_request.body)
        .send()
        .await
        .map_err(|error| {
            AppError::from_status(StatusCode::BAD_GATEWAY)
                .with_details(json!({
                    "provider": PROVIDER,
                    "message": format!("请求火山 TTS SSE 失败:{error}"),
                }))
                .into_response_with_context(Some(&request_context))
        })?;
    let status = upstream_response.status();
    // Volcengine's request-trace id; captured before the body is consumed.
    let log_id = upstream_response
        .headers()
        .get("X-Tt-Logid")
        .and_then(|value| value.to_str().ok())
        .map(ToOwned::to_owned);
    if !status.is_success() {
        let raw_text = upstream_response.text().await.unwrap_or_default();
        return Err(AppError::from_status(StatusCode::BAD_GATEWAY)
            .with_details(json!({
                "provider": PROVIDER,
                "status": status.as_u16(),
                "logId": log_id,
                "rawExcerpt": raw_text.chars().take(800).collect::<String>(),
            }))
            .into_response_with_context(Some(&request_context)));
    }
    // Stream the SSE bytes through unmodified; reqwest errors are adapted to
    // io::Error as required by Body::from_stream.
    let byte_stream = upstream_response
        .bytes_stream()
        .map_err(std::io::Error::other);
    let mut response = Response::new(Body::from_stream(byte_stream));
    *response.status_mut() = StatusCode::OK;
    response.headers_mut().insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("text/event-stream; charset=utf-8"),
    );
    response
        .headers_mut()
        .insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));
    // Only forwarded when the log id is a valid header value.
    if let Some(log_id) = log_id.and_then(|value| HeaderValue::from_str(&value).ok()) {
        response.headers_mut().insert("x-volcengine-logid", log_id);
    }
    Ok(response)
}
/// Bidirectional proxy between a browser WebSocket and the Volcengine ASR
/// WebSocket.
///
/// Protocol (as implemented here): the upstream "full client request" is
/// sent lazily on the first browser message — a text message supplies the
/// request payload, a binary message triggers the default payload. After
/// that, binary frames are forwarded as audio and a text message of
/// `{"type":"finish"}` (case-insensitive) sends the last-audio marker.
async fn proxy_asr_websocket(socket: WebSocket, client: VolcengineSpeechClient, user_id: String) {
    let (mut browser_sender, mut browser_receiver) = socket.split();
    // Connection failure is reported to the browser as a JSON error event;
    // the socket is then simply dropped.
    let Ok((upstream, response_headers)) = client.connect_asr().await else {
        let _ = browser_sender
            .send(ClientWsMessage::Text(
                json!({
                    "type": "error",
                    "provider": PROVIDER,
                    "message": "连接火山 ASR WebSocket 失败",
                })
                .to_string()
                .into(),
            ))
            .await;
        return;
    };
    if let Some(log_id) = response_headers.get("x-tt-logid") {
        info!(%log_id, "火山 ASR WebSocket 已连接");
    }
    let (mut upstream_sender, mut upstream_receiver) = upstream.split();
    // Tracks whether the full-client-request / last-audio frames were sent.
    let mut has_sent_start = false;
    let mut last_audio_sent = false;
    let browser_to_upstream = async {
        while let Some(message) = browser_receiver.next().await {
            match message {
                Ok(ClientWsMessage::Text(text)) => {
                    // Non-JSON text is wrapped as {"request":{"context": ...}}
                    // so it can still seed the start payload.
                    let value = serde_json::from_str::<Value>(text.as_str()).unwrap_or_else(|_| {
                        json!({
                            "request": {
                                "context": text.as_str(),
                            }
                        })
                    });
                    if value
                        .get("type")
                        .and_then(Value::as_str)
                        .is_some_and(|kind| kind.eq_ignore_ascii_case("finish"))
                    {
                        let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?;
                        upstream_sender
                            .send(UpstreamWsMessage::Binary(frame.into()))
                            .await
                            .map_err(map_ws_send_error)?;
                        last_audio_sent = true;
                        continue;
                    }
                    // NOTE(review): only the FIRST text message becomes the
                    // start payload; later non-"finish" text is ignored.
                    if !has_sent_start {
                        let payload = default_asr_request_payload(&user_id, Some(value));
                        let frame = build_asr_full_client_request(&payload)?;
                        upstream_sender
                            .send(UpstreamWsMessage::Binary(frame.into()))
                            .await
                            .map_err(map_ws_send_error)?;
                        has_sent_start = true;
                    }
                }
                Ok(ClientWsMessage::Binary(bytes)) => {
                    // Audio arriving before any text: start with defaults.
                    if !has_sent_start {
                        let payload = default_asr_request_payload(&user_id, None);
                        let frame = build_asr_full_client_request(&payload)?;
                        upstream_sender
                            .send(UpstreamWsMessage::Binary(frame.into()))
                            .await
                            .map_err(map_ws_send_error)?;
                        has_sent_start = true;
                    }
                    let frame = build_asr_frame(AsrFrameKind::Audio, &bytes)?;
                    upstream_sender
                        .send(UpstreamWsMessage::Binary(frame.into()))
                        .await
                        .map_err(map_ws_send_error)?;
                }
                Ok(ClientWsMessage::Close(_)) => break,
                Ok(ClientWsMessage::Ping(bytes)) => {
                    upstream_sender
                        .send(UpstreamWsMessage::Ping(bytes))
                        .await
                        .map_err(map_ws_send_error)?;
                }
                Ok(ClientWsMessage::Pong(_)) => {}
                Err(error) => {
                    return Err(SpeechError::Upstream(format!(
                        "读取浏览器 ASR WebSocket 失败:{error}"
                    )));
                }
            }
        }
        // Browser went away mid-session: flush a best-effort last-audio
        // marker so upstream can finalize the recognition.
        if has_sent_start && !last_audio_sent {
            let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?;
            let _ = upstream_sender
                .send(UpstreamWsMessage::Binary(frame.into()))
                .await;
        }
        Ok::<(), SpeechError>(())
    };
    let upstream_to_browser = async {
        while let Some(message) = upstream_receiver.next().await {
            match message {
                Ok(UpstreamWsMessage::Binary(bytes)) => {
                    // Binary upstream frames are decoded and re-emitted to the
                    // browser as a JSON "asr_response" event.
                    let parsed = parse_asr_response_frame(&bytes)?;
                    let value = json!({
                        "type": "asr_response",
                        "sequence": parsed.sequence,
                        "payload": parsed.payload,
                        "errorCode": parsed.error_code,
                    });
                    browser_sender
                        .send(ClientWsMessage::Text(value.to_string().into()))
                        .await
                        .map_err(map_client_ws_send_error)?;
                }
                Ok(UpstreamWsMessage::Text(text)) => {
                    browser_sender
                        .send(ClientWsMessage::Text(text))
                        .await
                        .map_err(map_client_ws_send_error)?;
                }
                Ok(UpstreamWsMessage::Close(close)) => {
                    let _ = browser_sender.send(ClientWsMessage::Close(close)).await;
                    break;
                }
                Ok(UpstreamWsMessage::Ping(bytes)) => {
                    browser_sender
                        .send(ClientWsMessage::Ping(bytes))
                        .await
                        .map_err(map_client_ws_send_error)?;
                }
                Ok(UpstreamWsMessage::Pong(_)) => {}
                Ok(UpstreamWsMessage::Frame(_)) => {}
                Err(error) => {
                    return Err(SpeechError::Upstream(format!(
                        "读取火山 ASR WebSocket 失败:{error}"
                    )));
                }
            }
        }
        Ok::<(), SpeechError>(())
    };
    let mut browser_to_upstream = Box::pin(browser_to_upstream);
    let mut upstream_to_browser = Box::pin(upstream_to_browser);
    // Whichever direction finishes first ends the proxy; select! drops (and
    // thereby cancels) the still-running direction.
    let result = tokio::select! {
        result = &mut browser_to_upstream => result,
        result = &mut upstream_to_browser => result,
    };
    if let Err(error) = result {
        warn!(error = %error, "火山 ASR WebSocket 代理中断");
    }
}
/// Bidirectional proxy between a browser WebSocket and the Volcengine TTS
/// WebSocket.
///
/// Browser → upstream: JSON text events are parsed as
/// `TtsBidirectionClientEvent` and re-encoded into the upstream binary frame
/// format. Upstream → browser: parsed frames are relayed as raw audio bytes
/// (binary) plus JSON status/error messages (text). The proxy runs until
/// either side closes or errors; the interruption is logged, not returned.
async fn proxy_tts_bidirection_websocket(socket: WebSocket, client: VolcengineSpeechClient) {
    let (mut browser_sender, mut browser_receiver) = socket.split();
    // If the upstream connection cannot be established, notify the browser
    // once and bail out. The send result is deliberately ignored (`let _`):
    // there is nothing further to do if the browser is also gone.
    let Ok((upstream, response_headers)) = client.connect_tts_bidirection().await else {
        let _ = browser_sender
            .send(ClientWsMessage::Text(
                json!({
                    "type": "error",
                    "provider": PROVIDER,
                    "message": "连接火山 TTS WebSocket 失败",
                })
                .to_string()
                .into(),
            ))
            .await;
        return;
    };
    // Log the upstream request id (if present) to help correlate issues
    // with the provider's logs.
    if let Some(log_id) = response_headers.get("x-tt-logid") {
        info!(%log_id, "火山 TTS WebSocket 已连接");
    }
    let (mut upstream_sender, mut upstream_receiver) = upstream.split();
    // Forward browser events to the upstream TTS service.
    let browser_to_upstream = async {
        while let Some(message) = browser_receiver.next().await {
            match message {
                // Text frames carry client events as JSON; invalid JSON
                // aborts this direction with an InvalidFrame error.
                Ok(ClientWsMessage::Text(text)) => {
                    let event = serde_json::from_str::<TtsBidirectionClientEvent>(text.as_str())
                        .map_err(|error| {
                            SpeechError::InvalidFrame(format!(
                                "TTS 浏览器事件 JSON 不合法:{error}"
                            ))
                        })?;
                    let frame = build_tts_bidirection_frame_from_client_event(event)?;
                    upstream_sender
                        .send(UpstreamWsMessage::Binary(frame.into()))
                        .await
                        .map_err(map_ws_send_error)?;
                }
                // Browser closed: stop forwarding in this direction.
                Ok(ClientWsMessage::Close(_)) => break,
                // Relay keepalive pings so the upstream sees liveness.
                Ok(ClientWsMessage::Ping(bytes)) => {
                    upstream_sender
                        .send(UpstreamWsMessage::Ping(bytes))
                        .await
                        .map_err(map_ws_send_error)?;
                }
                // Binary and pong frames from the browser have no meaning
                // for TTS and are dropped.
                Ok(ClientWsMessage::Binary(_)) | Ok(ClientWsMessage::Pong(_)) => {}
                Err(error) => {
                    return Err(SpeechError::Upstream(format!(
                        "读取浏览器 TTS WebSocket 失败:{error}"
                    )));
                }
            }
        }
        Ok::<(), SpeechError>(())
    };
    // Forward upstream responses back to the browser.
    let upstream_to_browser = async {
        while let Some(message) = upstream_receiver.next().await {
            match message {
                Ok(UpstreamWsMessage::Binary(bytes)) => {
                    let parsed = parse_tts_response_frame(&bytes)?;
                    // Audio chunks go to the browser as raw binary frames.
                    if let Some(audio) = parsed.audio.clone() {
                        browser_sender
                            .send(ClientWsMessage::Binary(audio.into()))
                            .await
                            .map_err(map_client_ws_send_error)?;
                    }
                    // Status payloads / error codes go as a separate JSON
                    // text frame; frames with neither are audio-only.
                    if parsed.payload.is_some() || parsed.error_code.is_some() {
                        browser_sender
                            .send(ClientWsMessage::Text(
                                tts_response_to_client_value(&parsed).to_string().into(),
                            ))
                            .await
                            .map_err(map_client_ws_send_error)?;
                    }
                }
                // Pass upstream text frames through unchanged.
                Ok(UpstreamWsMessage::Text(text)) => {
                    browser_sender
                        .send(ClientWsMessage::Text(text))
                        .await
                        .map_err(map_client_ws_send_error)?;
                }
                // Propagate the close frame (best-effort) and stop.
                Ok(UpstreamWsMessage::Close(close)) => {
                    let _ = browser_sender.send(ClientWsMessage::Close(close)).await;
                    break;
                }
                Ok(UpstreamWsMessage::Ping(bytes)) => {
                    browser_sender
                        .send(ClientWsMessage::Ping(bytes))
                        .await
                        .map_err(map_client_ws_send_error)?;
                }
                // Pongs and raw protocol frames need no forwarding.
                Ok(UpstreamWsMessage::Pong(_)) => {}
                Ok(UpstreamWsMessage::Frame(_)) => {}
                Err(error) => {
                    return Err(SpeechError::Upstream(format!(
                        "读取火山 TTS WebSocket 失败:{error}"
                    )));
                }
            }
        }
        Ok::<(), SpeechError>(())
    };
    // Drive both directions concurrently; `select!` returns as soon as one
    // side finishes (close or error) and drops/cancels the other.
    let mut browser_to_upstream = Box::pin(browser_to_upstream);
    let mut upstream_to_browser = Box::pin(upstream_to_browser);
    let result = tokio::select! {
        result = &mut browser_to_upstream => result,
        result = &mut upstream_to_browser => result,
    };
    if let Err(error) = result {
        warn!(error = %error, "火山 TTS WebSocket 代理中断");
    }
}
/// Construct a Volcengine speech client from the application configuration.
///
/// Configuration values are cloned out of `state.config`; any error raised
/// by `VolcengineSpeechConfig::new` is propagated via `?`.
fn build_speech_client(state: &AppState) -> Result<VolcengineSpeechClient, SpeechError> {
    let app_config = &state.config;
    let speech_config = VolcengineSpeechConfig::new(
        app_config.volcengine_speech_api_key.clone(),
        app_config.volcengine_speech_app_id.clone(),
        app_config.volcengine_speech_access_key.clone(),
        app_config.volcengine_speech_asr_resource_id.clone(),
        app_config.volcengine_speech_tts_resource_id.clone(),
        app_config.volcengine_speech_asr_ws_url.clone(),
        app_config.volcengine_speech_tts_bidirection_ws_url.clone(),
        app_config.volcengine_speech_tts_sse_url.clone(),
        app_config.volcengine_speech_request_timeout_ms,
    )?;
    Ok(VolcengineSpeechClient::new(speech_config))
}
/// Build the client-visible speech configuration: public resource ids,
/// default audio parameters, and the proxy endpoint paths. No credential
/// fields are included in this struct.
fn public_speech_config(state: &AppState) -> PublicSpeechConfig {
    // Endpoint paths are fixed; they mirror the routes registered in the app.
    let endpoints = PublicSpeechEndpoints {
        asr_stream: "/api/speech/volcengine/asr/stream",
        tts_bidirection: "/api/speech/volcengine/tts/bidirection",
        tts_sse: "/api/speech/volcengine/tts/sse",
    };
    PublicSpeechConfig {
        asr_resource_id: state.config.volcengine_speech_asr_resource_id.clone(),
        tts_resource_id: state.config.volcengine_speech_tts_resource_id.clone(),
        asr_audio: AsrAudioConfig::default(),
        tts_audio: TtsAudioParams::default(),
        endpoints,
    }
}
/// Translate a `SpeechError` into an HTTP `AppError`.
///
/// Configuration problems surface as 503 (the feature is unavailable);
/// every other variant — bad headers/frames, serialization, I/O, upstream —
/// surfaces as 502 (the proxy could not complete the upstream exchange).
/// The original error message is attached under `details.message`.
fn map_speech_error(error: SpeechError) -> AppError {
    let (status, message) = match error {
        SpeechError::InvalidConfig(message) => (StatusCode::SERVICE_UNAVAILABLE, message),
        SpeechError::InvalidHeader(message)
        | SpeechError::InvalidFrame(message)
        | SpeechError::Serialize(message)
        | SpeechError::Io(message)
        | SpeechError::Upstream(message) => (StatusCode::BAD_GATEWAY, message),
    };
    AppError::from_status(status).with_details(json!({
        "provider": PROVIDER,
        "message": message,
    }))
}
/// Wrap a tungstenite send failure (upstream direction) as an upstream error.
fn map_ws_send_error(error: tokio_tungstenite::tungstenite::Error) -> SpeechError {
    let detail = format!("发送火山语音 WebSocket 帧失败:{error}");
    SpeechError::Upstream(detail)
}
/// Wrap an axum send failure (browser direction) as an upstream error.
fn map_client_ws_send_error(error: axum::Error) -> SpeechError {
    let detail = format!("发送浏览器语音 WebSocket 帧失败:{error}");
    SpeechError::Upstream(detail)
}
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use serde_json::Value;
    use tower::ServiceExt;

    use super::*;
    use crate::{app::build_router, config::AppConfig, state::AppState};

    /// The speech config endpoint must reject unauthenticated requests.
    #[tokio::test]
    async fn speech_config_route_requires_authentication() {
        let app = build_router(AppState::new(AppConfig::default()).expect("state should build"));

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/api/speech/volcengine/config")
                    .body(Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should complete");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// An authenticated request succeeds and the response body must not
    /// leak credential fields — only public resource ids are exposed.
    #[tokio::test]
    async fn speech_config_route_returns_no_secret_fields() {
        // Struct-update syntax instead of default-then-reassign
        // (clippy::field_reassign_with_default).
        let config = AppConfig {
            volcengine_speech_api_key: Some("secret-key".to_string()),
            ..AppConfig::default()
        };
        let state = AppState::new(config).expect("state should build");
        state
            .seed_test_phone_user_with_password("13800138088", "Password123")
            .await;
        let app = build_router(state);

        // Log in first to obtain a bearer token for the protected route.
        let login_response = app
            .clone()
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/auth/entry")
                    .header("content-type", "application/json")
                    .body(Body::from(
                        json!({
                            "phone": "13800138088",
                            "password": "Password123"
                        })
                        .to_string(),
                    ))
                    .expect("login request should build"),
            )
            .await
            .expect("login should complete");
        let login_body = login_response
            .into_body()
            .collect()
            .await
            .expect("login body should collect")
            .to_bytes();
        let login_payload: Value =
            serde_json::from_slice(&login_body).expect("login body should be json");
        let token = login_payload["token"].as_str().expect("token should exist");

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/api/speech/volcengine/config")
                    .header("authorization", format!("Bearer {token}"))
                    .body(Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should complete");
        assert_eq!(response.status(), StatusCode::OK);

        let body = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let payload_text = String::from_utf8_lossy(&body);
        // Neither the secret value nor the secret field name may appear.
        assert!(!payload_text.contains("secret-key"));
        assert!(!payload_text.contains("apiKey"));
        // A known public field proves the payload is the expected shape.
        assert!(payload_text.contains("asrResourceId"));
    }
}