1
This commit is contained in:
552
server-rs/crates/api-server/src/volcengine_speech.rs
Normal file
552
server-rs/crates/api-server/src/volcengine_speech.rs
Normal file
@@ -0,0 +1,552 @@
|
||||
use axum::{
|
||||
Json,
|
||||
body::Body,
|
||||
extract::{
|
||||
State,
|
||||
ws::{Message as ClientWsMessage, WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
http::{HeaderValue, StatusCode, header},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
||||
use platform_speech::{
|
||||
AsrAudioConfig, AsrFrameKind, PublicSpeechConfig, PublicSpeechEndpoints, SpeechError,
|
||||
TtsAudioParams, TtsBidirectionClientEvent, TtsSseRequest, VolcengineSpeechClient,
|
||||
VolcengineSpeechConfig, build_asr_frame, build_asr_full_client_request,
|
||||
build_tts_bidirection_frame_from_client_event, default_asr_request_payload,
|
||||
parse_asr_response_frame, parse_tts_response_frame, tts_response_to_client_value,
|
||||
};
|
||||
use serde_json::{Value, json};
|
||||
use tokio_tungstenite::tungstenite::Message as UpstreamWsMessage;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
|
||||
request_context::RequestContext, state::AppState,
|
||||
};
|
||||
|
||||
/// Provider identifier embedded in the error payloads and WebSocket error messages
/// produced by this module.
const PROVIDER: &str = "volcengine-speech";
|
||||
|
||||
pub async fn get_volcengine_speech_config(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
|
||||
) -> Json<Value> {
|
||||
json_success_body(Some(&request_context), public_speech_config(&state))
|
||||
}
|
||||
|
||||
pub async fn stream_volcengine_asr(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Result<Response, Response> {
|
||||
let client = build_speech_client(&state)
|
||||
.map_err(|error| map_speech_error(error).into_response_with_context(None))?;
|
||||
let user_id = authenticated.claims().user_id().to_string();
|
||||
Ok(ws.on_upgrade(move |socket| proxy_asr_websocket(socket, client, user_id)))
|
||||
}
|
||||
|
||||
pub async fn stream_volcengine_tts_bidirection(
|
||||
State(state): State<AppState>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Result<Response, Response> {
|
||||
let client = build_speech_client(&state)
|
||||
.map_err(|error| map_speech_error(error).into_response_with_context(None))?;
|
||||
Ok(ws.on_upgrade(move |socket| proxy_tts_bidirection_websocket(socket, client)))
|
||||
}
|
||||
|
||||
pub async fn stream_volcengine_tts_sse(
|
||||
State(state): State<AppState>,
|
||||
axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
|
||||
axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
|
||||
payload: Result<Json<TtsSseRequest>, axum::extract::rejection::JsonRejection>,
|
||||
) -> Result<Response, Response> {
|
||||
let Json(payload) = payload.map_err(|rejection| {
|
||||
AppError::from_status(StatusCode::BAD_REQUEST)
|
||||
.with_message(format!("请求体 JSON 不合法:{rejection}"))
|
||||
.into_response_with_context(Some(&request_context))
|
||||
})?;
|
||||
let client = build_speech_client(&state).map_err(|error| {
|
||||
map_speech_error(error).into_response_with_context(Some(&request_context))
|
||||
})?;
|
||||
let upstream_request = client
|
||||
.build_tts_sse_upstream_request(payload, authenticated.claims().user_id())
|
||||
.map_err(|error| {
|
||||
map_speech_error(error).into_response_with_context(Some(&request_context))
|
||||
})?;
|
||||
let http_client = reqwest::Client::builder()
|
||||
.timeout(upstream_request.timeout)
|
||||
.build()
|
||||
.map_err(|error| {
|
||||
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.with_details(json!({
|
||||
"provider": PROVIDER,
|
||||
"message": format!("构造火山语音 HTTP 客户端失败:{error}"),
|
||||
}))
|
||||
.into_response_with_context(Some(&request_context))
|
||||
})?;
|
||||
let upstream_response = http_client
|
||||
.post(upstream_request.url)
|
||||
.headers(upstream_request.headers)
|
||||
.json(&upstream_request.body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|error| {
|
||||
AppError::from_status(StatusCode::BAD_GATEWAY)
|
||||
.with_details(json!({
|
||||
"provider": PROVIDER,
|
||||
"message": format!("请求火山 TTS SSE 失败:{error}"),
|
||||
}))
|
||||
.into_response_with_context(Some(&request_context))
|
||||
})?;
|
||||
let status = upstream_response.status();
|
||||
let log_id = upstream_response
|
||||
.headers()
|
||||
.get("X-Tt-Logid")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToOwned::to_owned);
|
||||
if !status.is_success() {
|
||||
let raw_text = upstream_response.text().await.unwrap_or_default();
|
||||
return Err(AppError::from_status(StatusCode::BAD_GATEWAY)
|
||||
.with_details(json!({
|
||||
"provider": PROVIDER,
|
||||
"status": status.as_u16(),
|
||||
"logId": log_id,
|
||||
"rawExcerpt": raw_text.chars().take(800).collect::<String>(),
|
||||
}))
|
||||
.into_response_with_context(Some(&request_context)));
|
||||
}
|
||||
|
||||
let byte_stream = upstream_response
|
||||
.bytes_stream()
|
||||
.map_err(std::io::Error::other);
|
||||
let mut response = Response::new(Body::from_stream(byte_stream));
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
response.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("text/event-stream; charset=utf-8"),
|
||||
);
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));
|
||||
if let Some(log_id) = log_id.and_then(|value| HeaderValue::from_str(&value).ok()) {
|
||||
response.headers_mut().insert("x-volcengine-logid", log_id);
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn proxy_asr_websocket(socket: WebSocket, client: VolcengineSpeechClient, user_id: String) {
|
||||
let (mut browser_sender, mut browser_receiver) = socket.split();
|
||||
let Ok((upstream, response_headers)) = client.connect_asr().await else {
|
||||
let _ = browser_sender
|
||||
.send(ClientWsMessage::Text(
|
||||
json!({
|
||||
"type": "error",
|
||||
"provider": PROVIDER,
|
||||
"message": "连接火山 ASR WebSocket 失败",
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
if let Some(log_id) = response_headers.get("x-tt-logid") {
|
||||
info!(%log_id, "火山 ASR WebSocket 已连接");
|
||||
}
|
||||
let (mut upstream_sender, mut upstream_receiver) = upstream.split();
|
||||
let mut has_sent_start = false;
|
||||
let mut last_audio_sent = false;
|
||||
|
||||
let browser_to_upstream = async {
|
||||
while let Some(message) = browser_receiver.next().await {
|
||||
match message {
|
||||
Ok(ClientWsMessage::Text(text)) => {
|
||||
let value = serde_json::from_str::<Value>(text.as_str()).unwrap_or_else(|_| {
|
||||
json!({
|
||||
"request": {
|
||||
"context": text.as_str(),
|
||||
}
|
||||
})
|
||||
});
|
||||
if value
|
||||
.get("type")
|
||||
.and_then(Value::as_str)
|
||||
.is_some_and(|kind| kind.eq_ignore_ascii_case("finish"))
|
||||
{
|
||||
let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?;
|
||||
upstream_sender
|
||||
.send(UpstreamWsMessage::Binary(frame.into()))
|
||||
.await
|
||||
.map_err(map_ws_send_error)?;
|
||||
last_audio_sent = true;
|
||||
continue;
|
||||
}
|
||||
if !has_sent_start {
|
||||
let payload = default_asr_request_payload(&user_id, Some(value));
|
||||
let frame = build_asr_full_client_request(&payload)?;
|
||||
upstream_sender
|
||||
.send(UpstreamWsMessage::Binary(frame.into()))
|
||||
.await
|
||||
.map_err(map_ws_send_error)?;
|
||||
has_sent_start = true;
|
||||
}
|
||||
}
|
||||
Ok(ClientWsMessage::Binary(bytes)) => {
|
||||
if !has_sent_start {
|
||||
let payload = default_asr_request_payload(&user_id, None);
|
||||
let frame = build_asr_full_client_request(&payload)?;
|
||||
upstream_sender
|
||||
.send(UpstreamWsMessage::Binary(frame.into()))
|
||||
.await
|
||||
.map_err(map_ws_send_error)?;
|
||||
has_sent_start = true;
|
||||
}
|
||||
let frame = build_asr_frame(AsrFrameKind::Audio, &bytes)?;
|
||||
upstream_sender
|
||||
.send(UpstreamWsMessage::Binary(frame.into()))
|
||||
.await
|
||||
.map_err(map_ws_send_error)?;
|
||||
}
|
||||
Ok(ClientWsMessage::Close(_)) => break,
|
||||
Ok(ClientWsMessage::Ping(bytes)) => {
|
||||
upstream_sender
|
||||
.send(UpstreamWsMessage::Ping(bytes))
|
||||
.await
|
||||
.map_err(map_ws_send_error)?;
|
||||
}
|
||||
Ok(ClientWsMessage::Pong(_)) => {}
|
||||
Err(error) => {
|
||||
return Err(SpeechError::Upstream(format!(
|
||||
"读取浏览器 ASR WebSocket 失败:{error}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
if has_sent_start && !last_audio_sent {
|
||||
let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?;
|
||||
let _ = upstream_sender
|
||||
.send(UpstreamWsMessage::Binary(frame.into()))
|
||||
.await;
|
||||
}
|
||||
Ok::<(), SpeechError>(())
|
||||
};
|
||||
|
||||
let upstream_to_browser = async {
|
||||
while let Some(message) = upstream_receiver.next().await {
|
||||
match message {
|
||||
Ok(UpstreamWsMessage::Binary(bytes)) => {
|
||||
let parsed = parse_asr_response_frame(&bytes)?;
|
||||
let value = json!({
|
||||
"type": "asr_response",
|
||||
"sequence": parsed.sequence,
|
||||
"payload": parsed.payload,
|
||||
"errorCode": parsed.error_code,
|
||||
});
|
||||
browser_sender
|
||||
.send(ClientWsMessage::Text(value.to_string().into()))
|
||||
.await
|
||||
.map_err(map_client_ws_send_error)?;
|
||||
}
|
||||
Ok(UpstreamWsMessage::Text(text)) => {
|
||||
browser_sender
|
||||
.send(ClientWsMessage::Text(text))
|
||||
.await
|
||||
.map_err(map_client_ws_send_error)?;
|
||||
}
|
||||
Ok(UpstreamWsMessage::Close(close)) => {
|
||||
let _ = browser_sender.send(ClientWsMessage::Close(close)).await;
|
||||
break;
|
||||
}
|
||||
Ok(UpstreamWsMessage::Ping(bytes)) => {
|
||||
browser_sender
|
||||
.send(ClientWsMessage::Ping(bytes))
|
||||
.await
|
||||
.map_err(map_client_ws_send_error)?;
|
||||
}
|
||||
Ok(UpstreamWsMessage::Pong(_)) => {}
|
||||
Ok(UpstreamWsMessage::Frame(_)) => {}
|
||||
Err(error) => {
|
||||
return Err(SpeechError::Upstream(format!(
|
||||
"读取火山 ASR WebSocket 失败:{error}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok::<(), SpeechError>(())
|
||||
};
|
||||
|
||||
let mut browser_to_upstream = Box::pin(browser_to_upstream);
|
||||
let mut upstream_to_browser = Box::pin(upstream_to_browser);
|
||||
let result = tokio::select! {
|
||||
result = &mut browser_to_upstream => result,
|
||||
result = &mut upstream_to_browser => result,
|
||||
};
|
||||
if let Err(error) = result {
|
||||
warn!(error = %error, "火山 ASR WebSocket 代理中断");
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_tts_bidirection_websocket(socket: WebSocket, client: VolcengineSpeechClient) {
|
||||
let (mut browser_sender, mut browser_receiver) = socket.split();
|
||||
let Ok((upstream, response_headers)) = client.connect_tts_bidirection().await else {
|
||||
let _ = browser_sender
|
||||
.send(ClientWsMessage::Text(
|
||||
json!({
|
||||
"type": "error",
|
||||
"provider": PROVIDER,
|
||||
"message": "连接火山 TTS WebSocket 失败",
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
if let Some(log_id) = response_headers.get("x-tt-logid") {
|
||||
info!(%log_id, "火山 TTS WebSocket 已连接");
|
||||
}
|
||||
let (mut upstream_sender, mut upstream_receiver) = upstream.split();
|
||||
|
||||
let browser_to_upstream = async {
|
||||
while let Some(message) = browser_receiver.next().await {
|
||||
match message {
|
||||
Ok(ClientWsMessage::Text(text)) => {
|
||||
let event = serde_json::from_str::<TtsBidirectionClientEvent>(text.as_str())
|
||||
.map_err(|error| {
|
||||
SpeechError::InvalidFrame(format!(
|
||||
"TTS 浏览器事件 JSON 不合法:{error}"
|
||||
))
|
||||
})?;
|
||||
let frame = build_tts_bidirection_frame_from_client_event(event)?;
|
||||
upstream_sender
|
||||
.send(UpstreamWsMessage::Binary(frame.into()))
|
||||
.await
|
||||
.map_err(map_ws_send_error)?;
|
||||
}
|
||||
Ok(ClientWsMessage::Close(_)) => break,
|
||||
Ok(ClientWsMessage::Ping(bytes)) => {
|
||||
upstream_sender
|
||||
.send(UpstreamWsMessage::Ping(bytes))
|
||||
.await
|
||||
.map_err(map_ws_send_error)?;
|
||||
}
|
||||
Ok(ClientWsMessage::Binary(_)) | Ok(ClientWsMessage::Pong(_)) => {}
|
||||
Err(error) => {
|
||||
return Err(SpeechError::Upstream(format!(
|
||||
"读取浏览器 TTS WebSocket 失败:{error}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok::<(), SpeechError>(())
|
||||
};
|
||||
|
||||
let upstream_to_browser = async {
|
||||
while let Some(message) = upstream_receiver.next().await {
|
||||
match message {
|
||||
Ok(UpstreamWsMessage::Binary(bytes)) => {
|
||||
let parsed = parse_tts_response_frame(&bytes)?;
|
||||
if let Some(audio) = parsed.audio.clone() {
|
||||
browser_sender
|
||||
.send(ClientWsMessage::Binary(audio.into()))
|
||||
.await
|
||||
.map_err(map_client_ws_send_error)?;
|
||||
}
|
||||
if parsed.payload.is_some() || parsed.error_code.is_some() {
|
||||
browser_sender
|
||||
.send(ClientWsMessage::Text(
|
||||
tts_response_to_client_value(&parsed).to_string().into(),
|
||||
))
|
||||
.await
|
||||
.map_err(map_client_ws_send_error)?;
|
||||
}
|
||||
}
|
||||
Ok(UpstreamWsMessage::Text(text)) => {
|
||||
browser_sender
|
||||
.send(ClientWsMessage::Text(text))
|
||||
.await
|
||||
.map_err(map_client_ws_send_error)?;
|
||||
}
|
||||
Ok(UpstreamWsMessage::Close(close)) => {
|
||||
let _ = browser_sender.send(ClientWsMessage::Close(close)).await;
|
||||
break;
|
||||
}
|
||||
Ok(UpstreamWsMessage::Ping(bytes)) => {
|
||||
browser_sender
|
||||
.send(ClientWsMessage::Ping(bytes))
|
||||
.await
|
||||
.map_err(map_client_ws_send_error)?;
|
||||
}
|
||||
Ok(UpstreamWsMessage::Pong(_)) => {}
|
||||
Ok(UpstreamWsMessage::Frame(_)) => {}
|
||||
Err(error) => {
|
||||
return Err(SpeechError::Upstream(format!(
|
||||
"读取火山 TTS WebSocket 失败:{error}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok::<(), SpeechError>(())
|
||||
};
|
||||
|
||||
let mut browser_to_upstream = Box::pin(browser_to_upstream);
|
||||
let mut upstream_to_browser = Box::pin(upstream_to_browser);
|
||||
let result = tokio::select! {
|
||||
result = &mut browser_to_upstream => result,
|
||||
result = &mut upstream_to_browser => result,
|
||||
};
|
||||
if let Err(error) = result {
|
||||
warn!(error = %error, "火山 TTS WebSocket 代理中断");
|
||||
}
|
||||
}
|
||||
|
||||
fn build_speech_client(state: &AppState) -> Result<VolcengineSpeechClient, SpeechError> {
|
||||
Ok(VolcengineSpeechClient::new(VolcengineSpeechConfig::new(
|
||||
state.config.volcengine_speech_api_key.clone(),
|
||||
state.config.volcengine_speech_app_id.clone(),
|
||||
state.config.volcengine_speech_access_key.clone(),
|
||||
state.config.volcengine_speech_asr_resource_id.clone(),
|
||||
state.config.volcengine_speech_tts_resource_id.clone(),
|
||||
state.config.volcengine_speech_asr_ws_url.clone(),
|
||||
state
|
||||
.config
|
||||
.volcengine_speech_tts_bidirection_ws_url
|
||||
.clone(),
|
||||
state.config.volcengine_speech_tts_sse_url.clone(),
|
||||
state.config.volcengine_speech_request_timeout_ms,
|
||||
)?))
|
||||
}
|
||||
|
||||
fn public_speech_config(state: &AppState) -> PublicSpeechConfig {
|
||||
PublicSpeechConfig {
|
||||
asr_resource_id: state.config.volcengine_speech_asr_resource_id.clone(),
|
||||
tts_resource_id: state.config.volcengine_speech_tts_resource_id.clone(),
|
||||
asr_audio: AsrAudioConfig::default(),
|
||||
tts_audio: TtsAudioParams::default(),
|
||||
endpoints: PublicSpeechEndpoints {
|
||||
asr_stream: "/api/speech/volcengine/asr/stream",
|
||||
tts_bidirection: "/api/speech/volcengine/tts/bidirection",
|
||||
tts_sse: "/api/speech/volcengine/tts/sse",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn map_speech_error(error: SpeechError) -> AppError {
|
||||
match error {
|
||||
SpeechError::InvalidConfig(message) => {
|
||||
AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
|
||||
"provider": PROVIDER,
|
||||
"message": message,
|
||||
}))
|
||||
}
|
||||
SpeechError::InvalidHeader(message)
|
||||
| SpeechError::InvalidFrame(message)
|
||||
| SpeechError::Serialize(message)
|
||||
| SpeechError::Io(message)
|
||||
| SpeechError::Upstream(message) => AppError::from_status(StatusCode::BAD_GATEWAY)
|
||||
.with_details(json!({
|
||||
"provider": PROVIDER,
|
||||
"message": message,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_ws_send_error(error: tokio_tungstenite::tungstenite::Error) -> SpeechError {
|
||||
SpeechError::Upstream(format!("发送火山语音 WebSocket 帧失败:{error}"))
|
||||
}
|
||||
|
||||
fn map_client_ws_send_error(error: axum::Error) -> SpeechError {
|
||||
SpeechError::Upstream(format!("发送浏览器语音 WebSocket 帧失败:{error}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use serde_json::Value;
    use tower::ServiceExt;

    use super::*;
    use crate::{app::build_router, config::AppConfig, state::AppState};

    /// The config route must answer 401 when no credentials are supplied.
    #[tokio::test]
    async fn speech_config_route_requires_authentication() {
        let state = AppState::new(AppConfig::default()).expect("state should build");
        let app = build_router(state);
        let request = Request::builder()
            .uri("/api/speech/volcengine/config")
            .body(Body::empty())
            .expect("request should build");

        let response = app.oneshot(request).await.expect("request should complete");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// After logging in, the config route responds 200 and its body contains neither
    /// the configured API key nor an `apiKey` field, while still exposing
    /// `asrResourceId`.
    #[tokio::test]
    async fn speech_config_route_returns_no_secret_fields() {
        let mut config = AppConfig::default();
        config.volcengine_speech_api_key = Some("secret-key".to_string());
        let state = AppState::new(config).expect("state should build");
        state
            .seed_test_phone_user_with_password("13800138088", "Password123")
            .await;
        let app = build_router(state);

        // Log in first to obtain a bearer token for the protected route.
        let credentials = json!({
            "phone": "13800138088",
            "password": "Password123"
        });
        let login_request = Request::builder()
            .method("POST")
            .uri("/api/auth/entry")
            .header("content-type", "application/json")
            .body(Body::from(credentials.to_string()))
            .expect("login request should build");
        let login_response = app
            .clone()
            .oneshot(login_request)
            .await
            .expect("login should complete");
        let login_bytes = login_response
            .into_body()
            .collect()
            .await
            .expect("login body should collect")
            .to_bytes();
        let login_json: Value =
            serde_json::from_slice(&login_bytes).expect("login body should be json");
        let token = login_json["token"].as_str().expect("token should exist");

        let config_request = Request::builder()
            .uri("/api/speech/volcengine/config")
            .header("authorization", format!("Bearer {token}"))
            .body(Body::empty())
            .expect("request should build");
        let response = app
            .oneshot(config_request)
            .await
            .expect("request should complete");

        assert_eq!(response.status(), StatusCode::OK);
        let body_bytes = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let body_text = String::from_utf8_lossy(&body_bytes);
        assert!(!body_text.contains("secret-key"));
        assert!(!body_text.contains("apiKey"));
        assert!(body_text.contains("asrResourceId"));
    }
}
|
||||
Reference in New Issue
Block a user