Files
Genarrative/server-rs/crates/api-server/src/volcengine_speech.rs
2026-05-08 21:46:11 +08:00

552 lines
22 KiB
Rust

use axum::{
Json,
body::Body,
extract::{
State,
ws::{Message as ClientWsMessage, WebSocket, WebSocketUpgrade},
},
http::{HeaderValue, StatusCode, header},
response::Response,
};
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use platform_speech::{
AsrAudioConfig, AsrFrameKind, PublicSpeechConfig, PublicSpeechEndpoints, SpeechError,
TtsAudioParams, TtsBidirectionClientEvent, TtsSseRequest, UpstreamWsError, UpstreamWsMessage,
VolcengineSpeechClient, VolcengineSpeechConfig, build_asr_frame, build_asr_full_client_request,
build_tts_bidirection_frame_from_client_event, default_asr_request_payload,
parse_asr_response_frame, parse_tts_response_frame, tts_response_to_client_value,
};
use serde_json::{Value, json};
use tracing::{info, warn};
use crate::{
api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
request_context::RequestContext, state::AppState,
};
/// Provider tag placed in the `provider` field of error details and WebSocket error messages
/// produced by this module.
const PROVIDER: &str = "volcengine-speech";
/// Serves the browser-facing speech configuration.
///
/// The payload comes from `public_speech_config` (resource ids, default audio parameters and
/// proxy endpoint paths) and is wrapped by `json_success_body` together with the request
/// context.
pub async fn get_volcengine_speech_config(
    State(state): State<AppState>,
    axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
) -> Json<Value> {
    let context = Some(&request_context);
    let public_config = public_speech_config(&state);
    json_success_body(context, public_config)
}
pub async fn stream_volcengine_asr(
State(state): State<AppState>,
axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
ws: WebSocketUpgrade,
) -> Result<Response, Response> {
let client = build_speech_client(&state)
.map_err(|error| map_speech_error(error).into_response_with_context(None))?;
let user_id = authenticated.claims().user_id().to_string();
Ok(ws.on_upgrade(move |socket| proxy_asr_websocket(socket, client, user_id)))
}
pub async fn stream_volcengine_tts_bidirection(
State(state): State<AppState>,
ws: WebSocketUpgrade,
) -> Result<Response, Response> {
let client = build_speech_client(&state)
.map_err(|error| map_speech_error(error).into_response_with_context(None))?;
Ok(ws.on_upgrade(move |socket| proxy_tts_bidirection_websocket(socket, client)))
}
pub async fn stream_volcengine_tts_sse(
State(state): State<AppState>,
axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
payload: Result<Json<TtsSseRequest>, axum::extract::rejection::JsonRejection>,
) -> Result<Response, Response> {
let Json(payload) = payload.map_err(|rejection| {
AppError::from_status(StatusCode::BAD_REQUEST)
.with_message(format!("请求体 JSON 不合法:{rejection}"))
.into_response_with_context(Some(&request_context))
})?;
let client = build_speech_client(&state).map_err(|error| {
map_speech_error(error).into_response_with_context(Some(&request_context))
})?;
let upstream_request = client
.build_tts_sse_upstream_request(payload, authenticated.claims().user_id())
.map_err(|error| {
map_speech_error(error).into_response_with_context(Some(&request_context))
})?;
let http_client = reqwest::Client::builder()
.timeout(upstream_request.timeout)
.build()
.map_err(|error| {
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
.with_details(json!({
"provider": PROVIDER,
"message": format!("构造火山语音 HTTP 客户端失败:{error}"),
}))
.into_response_with_context(Some(&request_context))
})?;
let upstream_response = http_client
.post(upstream_request.url)
.headers(upstream_request.headers)
.json(&upstream_request.body)
.send()
.await
.map_err(|error| {
AppError::from_status(StatusCode::BAD_GATEWAY)
.with_details(json!({
"provider": PROVIDER,
"message": format!("请求火山 TTS SSE 失败:{error}"),
}))
.into_response_with_context(Some(&request_context))
})?;
let status = upstream_response.status();
let log_id = upstream_response
.headers()
.get("X-Tt-Logid")
.and_then(|value| value.to_str().ok())
.map(ToOwned::to_owned);
if !status.is_success() {
let raw_text = upstream_response.text().await.unwrap_or_default();
return Err(AppError::from_status(StatusCode::BAD_GATEWAY)
.with_details(json!({
"provider": PROVIDER,
"status": status.as_u16(),
"logId": log_id,
"rawExcerpt": raw_text.chars().take(800).collect::<String>(),
}))
.into_response_with_context(Some(&request_context)));
}
let byte_stream = upstream_response
.bytes_stream()
.map_err(std::io::Error::other);
let mut response = Response::new(Body::from_stream(byte_stream));
*response.status_mut() = StatusCode::OK;
response.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("text/event-stream; charset=utf-8"),
);
response
.headers_mut()
.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));
if let Some(log_id) = log_id.and_then(|value| HeaderValue::from_str(&value).ok()) {
response.headers_mut().insert("x-volcengine-logid", log_id);
}
Ok(response)
}
async fn proxy_asr_websocket(socket: WebSocket, client: VolcengineSpeechClient, user_id: String) {
let (mut browser_sender, mut browser_receiver) = socket.split();
let Ok((upstream, response_headers)) = client.connect_asr().await else {
let _ = browser_sender
.send(ClientWsMessage::Text(
json!({
"type": "error",
"provider": PROVIDER,
"message": "连接火山 ASR WebSocket 失败",
})
.to_string()
.into(),
))
.await;
return;
};
if let Some(log_id) = response_headers.get("x-tt-logid") {
info!(%log_id, "火山 ASR WebSocket 已连接");
}
let (mut upstream_sender, mut upstream_receiver) = upstream.split();
let mut has_sent_start = false;
let mut last_audio_sent = false;
let browser_to_upstream = async {
while let Some(message) = browser_receiver.next().await {
match message {
Ok(ClientWsMessage::Text(text)) => {
let value = serde_json::from_str::<Value>(text.as_str()).unwrap_or_else(|_| {
json!({
"request": {
"context": text.as_str(),
}
})
});
if value
.get("type")
.and_then(Value::as_str)
.is_some_and(|kind| kind.eq_ignore_ascii_case("finish"))
{
let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?;
upstream_sender
.send(UpstreamWsMessage::Binary(frame.into()))
.await
.map_err(map_ws_send_error)?;
last_audio_sent = true;
continue;
}
if !has_sent_start {
let payload = default_asr_request_payload(&user_id, Some(value));
let frame = build_asr_full_client_request(&payload)?;
upstream_sender
.send(UpstreamWsMessage::Binary(frame.into()))
.await
.map_err(map_ws_send_error)?;
has_sent_start = true;
}
}
Ok(ClientWsMessage::Binary(bytes)) => {
if !has_sent_start {
let payload = default_asr_request_payload(&user_id, None);
let frame = build_asr_full_client_request(&payload)?;
upstream_sender
.send(UpstreamWsMessage::Binary(frame.into()))
.await
.map_err(map_ws_send_error)?;
has_sent_start = true;
}
let frame = build_asr_frame(AsrFrameKind::Audio, &bytes)?;
upstream_sender
.send(UpstreamWsMessage::Binary(frame.into()))
.await
.map_err(map_ws_send_error)?;
}
Ok(ClientWsMessage::Close(_)) => break,
Ok(ClientWsMessage::Ping(bytes)) => {
upstream_sender
.send(UpstreamWsMessage::Ping(bytes))
.await
.map_err(map_ws_send_error)?;
}
Ok(ClientWsMessage::Pong(_)) => {}
Err(error) => {
return Err(SpeechError::Upstream(format!(
"读取浏览器 ASR WebSocket 失败:{error}"
)));
}
}
}
if has_sent_start && !last_audio_sent {
let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?;
let _ = upstream_sender
.send(UpstreamWsMessage::Binary(frame.into()))
.await;
}
Ok::<(), SpeechError>(())
};
let upstream_to_browser = async {
while let Some(message) = upstream_receiver.next().await {
match message {
Ok(UpstreamWsMessage::Binary(bytes)) => {
let parsed = parse_asr_response_frame(&bytes)?;
let value = json!({
"type": "asr_response",
"sequence": parsed.sequence,
"payload": parsed.payload,
"errorCode": parsed.error_code,
});
browser_sender
.send(ClientWsMessage::Text(value.to_string().into()))
.await
.map_err(map_client_ws_send_error)?;
}
Ok(UpstreamWsMessage::Text(text)) => {
browser_sender
.send(ClientWsMessage::Text(text.to_string().into()))
.await
.map_err(map_client_ws_send_error)?;
}
Ok(UpstreamWsMessage::Close(_)) => {
let _ = browser_sender.send(ClientWsMessage::Close(None)).await;
break;
}
Ok(UpstreamWsMessage::Ping(bytes)) => {
browser_sender
.send(ClientWsMessage::Ping(bytes))
.await
.map_err(map_client_ws_send_error)?;
}
Ok(UpstreamWsMessage::Pong(_)) => {}
Ok(UpstreamWsMessage::Frame(_)) => {}
Err(error) => {
return Err(SpeechError::Upstream(format!(
"读取火山 ASR WebSocket 失败:{error}"
)));
}
}
}
Ok::<(), SpeechError>(())
};
let mut browser_to_upstream = Box::pin(browser_to_upstream);
let mut upstream_to_browser = Box::pin(upstream_to_browser);
let result = tokio::select! {
result = &mut browser_to_upstream => result,
result = &mut upstream_to_browser => result,
};
if let Err(error) = result {
warn!(error = %error, "火山 ASR WebSocket 代理中断");
}
}
async fn proxy_tts_bidirection_websocket(socket: WebSocket, client: VolcengineSpeechClient) {
let (mut browser_sender, mut browser_receiver) = socket.split();
let Ok((upstream, response_headers)) = client.connect_tts_bidirection().await else {
let _ = browser_sender
.send(ClientWsMessage::Text(
json!({
"type": "error",
"provider": PROVIDER,
"message": "连接火山 TTS WebSocket 失败",
})
.to_string()
.into(),
))
.await;
return;
};
if let Some(log_id) = response_headers.get("x-tt-logid") {
info!(%log_id, "火山 TTS WebSocket 已连接");
}
let (mut upstream_sender, mut upstream_receiver) = upstream.split();
let browser_to_upstream = async {
while let Some(message) = browser_receiver.next().await {
match message {
Ok(ClientWsMessage::Text(text)) => {
let event = serde_json::from_str::<TtsBidirectionClientEvent>(text.as_str())
.map_err(|error| {
SpeechError::InvalidFrame(format!(
"TTS 浏览器事件 JSON 不合法:{error}"
))
})?;
let frame = build_tts_bidirection_frame_from_client_event(event)?;
upstream_sender
.send(UpstreamWsMessage::Binary(frame.into()))
.await
.map_err(map_ws_send_error)?;
}
Ok(ClientWsMessage::Close(_)) => break,
Ok(ClientWsMessage::Ping(bytes)) => {
upstream_sender
.send(UpstreamWsMessage::Ping(bytes))
.await
.map_err(map_ws_send_error)?;
}
Ok(ClientWsMessage::Binary(_)) | Ok(ClientWsMessage::Pong(_)) => {}
Err(error) => {
return Err(SpeechError::Upstream(format!(
"读取浏览器 TTS WebSocket 失败:{error}"
)));
}
}
}
Ok::<(), SpeechError>(())
};
let upstream_to_browser = async {
while let Some(message) = upstream_receiver.next().await {
match message {
Ok(UpstreamWsMessage::Binary(bytes)) => {
let parsed = parse_tts_response_frame(&bytes)?;
if let Some(audio) = parsed.audio.clone() {
browser_sender
.send(ClientWsMessage::Binary(audio.into()))
.await
.map_err(map_client_ws_send_error)?;
}
if parsed.payload.is_some() || parsed.error_code.is_some() {
browser_sender
.send(ClientWsMessage::Text(
tts_response_to_client_value(&parsed).to_string().into(),
))
.await
.map_err(map_client_ws_send_error)?;
}
}
Ok(UpstreamWsMessage::Text(text)) => {
browser_sender
.send(ClientWsMessage::Text(text.to_string().into()))
.await
.map_err(map_client_ws_send_error)?;
}
Ok(UpstreamWsMessage::Close(_)) => {
let _ = browser_sender.send(ClientWsMessage::Close(None)).await;
break;
}
Ok(UpstreamWsMessage::Ping(bytes)) => {
browser_sender
.send(ClientWsMessage::Ping(bytes))
.await
.map_err(map_client_ws_send_error)?;
}
Ok(UpstreamWsMessage::Pong(_)) => {}
Ok(UpstreamWsMessage::Frame(_)) => {}
Err(error) => {
return Err(SpeechError::Upstream(format!(
"读取火山 TTS WebSocket 失败:{error}"
)));
}
}
}
Ok::<(), SpeechError>(())
};
let mut browser_to_upstream = Box::pin(browser_to_upstream);
let mut upstream_to_browser = Box::pin(upstream_to_browser);
let result = tokio::select! {
result = &mut browser_to_upstream => result,
result = &mut upstream_to_browser => result,
};
if let Err(error) = result {
warn!(error = %error, "火山 TTS WebSocket 代理中断");
}
}
/// Builds a `VolcengineSpeechClient` from the Volcengine speech settings in the app config.
///
/// # Errors
/// Propagates the error returned by `VolcengineSpeechConfig::new`.
fn build_speech_client(state: &AppState) -> Result<VolcengineSpeechClient, SpeechError> {
    let settings = &state.config;
    let speech_config = VolcengineSpeechConfig::new(
        settings.volcengine_speech_api_key.clone(),
        settings.volcengine_speech_app_id.clone(),
        settings.volcengine_speech_access_key.clone(),
        settings.volcengine_speech_asr_resource_id.clone(),
        settings.volcengine_speech_tts_resource_id.clone(),
        settings.volcengine_speech_asr_ws_url.clone(),
        settings.volcengine_speech_tts_bidirection_ws_url.clone(),
        settings.volcengine_speech_tts_sse_url.clone(),
        settings.volcengine_speech_request_timeout_ms,
    )?;
    Ok(VolcengineSpeechClient::new(speech_config))
}
/// Assembles the configuration exposed to the browser: the ASR/TTS resource ids, the default
/// audio parameters and the proxy endpoint paths. No key or credential fields are copied.
fn public_speech_config(state: &AppState) -> PublicSpeechConfig {
    let settings = &state.config;
    let endpoints = PublicSpeechEndpoints {
        asr_stream: "/api/speech/volcengine/asr/stream",
        tts_bidirection: "/api/speech/volcengine/tts/bidirection",
        tts_sse: "/api/speech/volcengine/tts/sse",
    };
    PublicSpeechConfig {
        asr_resource_id: settings.volcengine_speech_asr_resource_id.clone(),
        tts_resource_id: settings.volcengine_speech_tts_resource_id.clone(),
        asr_audio: AsrAudioConfig::default(),
        tts_audio: TtsAudioParams::default(),
        endpoints,
    }
}
/// Converts a `SpeechError` into an `AppError` with provider details.
///
/// `InvalidConfig` maps to `503 Service Unavailable`; every other variant maps to
/// `502 Bad Gateway`. The error message is carried in the `message` detail field.
fn map_speech_error(error: SpeechError) -> AppError {
    // Pick the status and pull out the message first, then build the error once.
    let (status, message) = match error {
        SpeechError::InvalidConfig(message) => (StatusCode::SERVICE_UNAVAILABLE, message),
        SpeechError::InvalidHeader(message)
        | SpeechError::InvalidFrame(message)
        | SpeechError::Serialize(message)
        | SpeechError::Io(message)
        | SpeechError::Upstream(message) => (StatusCode::BAD_GATEWAY, message),
    };
    let details = json!({
        "provider": PROVIDER,
        "message": message,
    });
    AppError::from_status(status).with_details(details)
}
/// Wraps a failure to send a frame to the upstream WebSocket as `SpeechError::Upstream`.
fn map_ws_send_error(error: UpstreamWsError) -> SpeechError {
    let message = format!("发送火山语音 WebSocket 帧失败:{error}");
    SpeechError::Upstream(message)
}
/// Wraps a failure to send a frame to the browser WebSocket as `SpeechError::Upstream`.
fn map_client_ws_send_error(error: axum::Error) -> SpeechError {
    let message = format!("发送浏览器语音 WebSocket 帧失败:{error}");
    SpeechError::Upstream(message)
}
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use serde_json::Value;
    use tower::ServiceExt;

    use super::*;
    use crate::{app::build_router, config::AppConfig, state::AppState};

    /// Builds a GET request for the speech config route, optionally with a bearer token.
    fn speech_config_request(bearer_token: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/api/speech/volcengine/config");
        if let Some(token) = bearer_token {
            builder = builder.header("authorization", format!("Bearer {token}"));
        }
        builder.body(Body::empty()).expect("request should build")
    }

    /// Without credentials the config route must answer 401.
    #[tokio::test]
    async fn speech_config_route_requires_authentication() {
        let state = AppState::new(AppConfig::default()).expect("state should build");
        let response = build_router(state)
            .oneshot(speech_config_request(None))
            .await
            .expect("request should complete");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// An authenticated caller gets the public config, which must not contain the API key.
    #[tokio::test]
    async fn speech_config_route_returns_no_secret_fields() {
        let config = AppConfig {
            volcengine_speech_api_key: Some("secret-key".to_string()),
            ..AppConfig::default()
        };
        let state = AppState::new(config).expect("state should build");
        state
            .seed_test_phone_user_with_password("13800138088", "Password123")
            .await;
        let app = build_router(state);

        // Log in with the seeded user to obtain a bearer token.
        let credentials = json!({
            "phone": "13800138088",
            "password": "Password123"
        })
        .to_string();
        let login_request = Request::builder()
            .method("POST")
            .uri("/api/auth/entry")
            .header("content-type", "application/json")
            .body(Body::from(credentials))
            .expect("login request should build");
        let login_response = app
            .clone()
            .oneshot(login_request)
            .await
            .expect("login should complete");
        let login_bytes = login_response
            .into_body()
            .collect()
            .await
            .expect("login body should collect")
            .to_bytes();
        let login_json: Value =
            serde_json::from_slice(&login_bytes).expect("login body should be json");
        let token = login_json["token"].as_str().expect("token should exist");

        let response = app
            .oneshot(speech_config_request(Some(token)))
            .await
            .expect("request should complete");
        assert_eq!(response.status(), StatusCode::OK);

        let body_bytes = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let body_text = String::from_utf8_lossy(&body_bytes);
        assert!(!body_text.contains("secret-key"));
        assert!(!body_text.contains("apiKey"));
        assert!(body_text.contains("asrResourceId"));
    }
}