552 lines
22 KiB
Rust
552 lines
22 KiB
Rust
use axum::{
|
|
Json,
|
|
body::Body,
|
|
extract::{
|
|
State,
|
|
ws::{Message as ClientWsMessage, WebSocket, WebSocketUpgrade},
|
|
},
|
|
http::{HeaderValue, StatusCode, header},
|
|
response::Response,
|
|
};
|
|
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
|
use platform_speech::{
|
|
AsrAudioConfig, AsrFrameKind, PublicSpeechConfig, PublicSpeechEndpoints, SpeechError,
|
|
TtsAudioParams, TtsBidirectionClientEvent, TtsSseRequest, UpstreamWsError, UpstreamWsMessage,
|
|
VolcengineSpeechClient, VolcengineSpeechConfig, build_asr_frame, build_asr_full_client_request,
|
|
build_tts_bidirection_frame_from_client_event, default_asr_request_payload,
|
|
parse_asr_response_frame, parse_tts_response_frame, tts_response_to_client_value,
|
|
};
|
|
use serde_json::{Value, json};
|
|
use tracing::{info, warn};
|
|
|
|
use crate::{
|
|
api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
|
|
request_context::RequestContext, state::AppState,
|
|
};
|
|
|
|
const PROVIDER: &str = "volcengine-speech";
|
|
|
|
pub async fn get_volcengine_speech_config(
|
|
State(state): State<AppState>,
|
|
axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
|
|
) -> Json<Value> {
|
|
json_success_body(Some(&request_context), public_speech_config(&state))
|
|
}
|
|
|
|
pub async fn stream_volcengine_asr(
|
|
State(state): State<AppState>,
|
|
axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
|
|
ws: WebSocketUpgrade,
|
|
) -> Result<Response, Response> {
|
|
let client = build_speech_client(&state)
|
|
.map_err(|error| map_speech_error(error).into_response_with_context(None))?;
|
|
let user_id = authenticated.claims().user_id().to_string();
|
|
Ok(ws.on_upgrade(move |socket| proxy_asr_websocket(socket, client, user_id)))
|
|
}
|
|
|
|
pub async fn stream_volcengine_tts_bidirection(
|
|
State(state): State<AppState>,
|
|
ws: WebSocketUpgrade,
|
|
) -> Result<Response, Response> {
|
|
let client = build_speech_client(&state)
|
|
.map_err(|error| map_speech_error(error).into_response_with_context(None))?;
|
|
Ok(ws.on_upgrade(move |socket| proxy_tts_bidirection_websocket(socket, client)))
|
|
}
|
|
|
|
pub async fn stream_volcengine_tts_sse(
|
|
State(state): State<AppState>,
|
|
axum::extract::Extension(request_context): axum::extract::Extension<RequestContext>,
|
|
axum::extract::Extension(authenticated): axum::extract::Extension<AuthenticatedAccessToken>,
|
|
payload: Result<Json<TtsSseRequest>, axum::extract::rejection::JsonRejection>,
|
|
) -> Result<Response, Response> {
|
|
let Json(payload) = payload.map_err(|rejection| {
|
|
AppError::from_status(StatusCode::BAD_REQUEST)
|
|
.with_message(format!("请求体 JSON 不合法:{rejection}"))
|
|
.into_response_with_context(Some(&request_context))
|
|
})?;
|
|
let client = build_speech_client(&state).map_err(|error| {
|
|
map_speech_error(error).into_response_with_context(Some(&request_context))
|
|
})?;
|
|
let upstream_request = client
|
|
.build_tts_sse_upstream_request(payload, authenticated.claims().user_id())
|
|
.map_err(|error| {
|
|
map_speech_error(error).into_response_with_context(Some(&request_context))
|
|
})?;
|
|
let http_client = reqwest::Client::builder()
|
|
.timeout(upstream_request.timeout)
|
|
.build()
|
|
.map_err(|error| {
|
|
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR)
|
|
.with_details(json!({
|
|
"provider": PROVIDER,
|
|
"message": format!("构造火山语音 HTTP 客户端失败:{error}"),
|
|
}))
|
|
.into_response_with_context(Some(&request_context))
|
|
})?;
|
|
let upstream_response = http_client
|
|
.post(upstream_request.url)
|
|
.headers(upstream_request.headers)
|
|
.json(&upstream_request.body)
|
|
.send()
|
|
.await
|
|
.map_err(|error| {
|
|
AppError::from_status(StatusCode::BAD_GATEWAY)
|
|
.with_details(json!({
|
|
"provider": PROVIDER,
|
|
"message": format!("请求火山 TTS SSE 失败:{error}"),
|
|
}))
|
|
.into_response_with_context(Some(&request_context))
|
|
})?;
|
|
let status = upstream_response.status();
|
|
let log_id = upstream_response
|
|
.headers()
|
|
.get("X-Tt-Logid")
|
|
.and_then(|value| value.to_str().ok())
|
|
.map(ToOwned::to_owned);
|
|
if !status.is_success() {
|
|
let raw_text = upstream_response.text().await.unwrap_or_default();
|
|
return Err(AppError::from_status(StatusCode::BAD_GATEWAY)
|
|
.with_details(json!({
|
|
"provider": PROVIDER,
|
|
"status": status.as_u16(),
|
|
"logId": log_id,
|
|
"rawExcerpt": raw_text.chars().take(800).collect::<String>(),
|
|
}))
|
|
.into_response_with_context(Some(&request_context)));
|
|
}
|
|
|
|
let byte_stream = upstream_response
|
|
.bytes_stream()
|
|
.map_err(std::io::Error::other);
|
|
let mut response = Response::new(Body::from_stream(byte_stream));
|
|
*response.status_mut() = StatusCode::OK;
|
|
response.headers_mut().insert(
|
|
header::CONTENT_TYPE,
|
|
HeaderValue::from_static("text/event-stream; charset=utf-8"),
|
|
);
|
|
response
|
|
.headers_mut()
|
|
.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));
|
|
if let Some(log_id) = log_id.and_then(|value| HeaderValue::from_str(&value).ok()) {
|
|
response.headers_mut().insert("x-volcengine-logid", log_id);
|
|
}
|
|
Ok(response)
|
|
}
|
|
|
|
async fn proxy_asr_websocket(socket: WebSocket, client: VolcengineSpeechClient, user_id: String) {
|
|
let (mut browser_sender, mut browser_receiver) = socket.split();
|
|
let Ok((upstream, response_headers)) = client.connect_asr().await else {
|
|
let _ = browser_sender
|
|
.send(ClientWsMessage::Text(
|
|
json!({
|
|
"type": "error",
|
|
"provider": PROVIDER,
|
|
"message": "连接火山 ASR WebSocket 失败",
|
|
})
|
|
.to_string()
|
|
.into(),
|
|
))
|
|
.await;
|
|
return;
|
|
};
|
|
if let Some(log_id) = response_headers.get("x-tt-logid") {
|
|
info!(%log_id, "火山 ASR WebSocket 已连接");
|
|
}
|
|
let (mut upstream_sender, mut upstream_receiver) = upstream.split();
|
|
let mut has_sent_start = false;
|
|
let mut last_audio_sent = false;
|
|
|
|
let browser_to_upstream = async {
|
|
while let Some(message) = browser_receiver.next().await {
|
|
match message {
|
|
Ok(ClientWsMessage::Text(text)) => {
|
|
let value = serde_json::from_str::<Value>(text.as_str()).unwrap_or_else(|_| {
|
|
json!({
|
|
"request": {
|
|
"context": text.as_str(),
|
|
}
|
|
})
|
|
});
|
|
if value
|
|
.get("type")
|
|
.and_then(Value::as_str)
|
|
.is_some_and(|kind| kind.eq_ignore_ascii_case("finish"))
|
|
{
|
|
let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?;
|
|
upstream_sender
|
|
.send(UpstreamWsMessage::Binary(frame.into()))
|
|
.await
|
|
.map_err(map_ws_send_error)?;
|
|
last_audio_sent = true;
|
|
continue;
|
|
}
|
|
if !has_sent_start {
|
|
let payload = default_asr_request_payload(&user_id, Some(value));
|
|
let frame = build_asr_full_client_request(&payload)?;
|
|
upstream_sender
|
|
.send(UpstreamWsMessage::Binary(frame.into()))
|
|
.await
|
|
.map_err(map_ws_send_error)?;
|
|
has_sent_start = true;
|
|
}
|
|
}
|
|
Ok(ClientWsMessage::Binary(bytes)) => {
|
|
if !has_sent_start {
|
|
let payload = default_asr_request_payload(&user_id, None);
|
|
let frame = build_asr_full_client_request(&payload)?;
|
|
upstream_sender
|
|
.send(UpstreamWsMessage::Binary(frame.into()))
|
|
.await
|
|
.map_err(map_ws_send_error)?;
|
|
has_sent_start = true;
|
|
}
|
|
let frame = build_asr_frame(AsrFrameKind::Audio, &bytes)?;
|
|
upstream_sender
|
|
.send(UpstreamWsMessage::Binary(frame.into()))
|
|
.await
|
|
.map_err(map_ws_send_error)?;
|
|
}
|
|
Ok(ClientWsMessage::Close(_)) => break,
|
|
Ok(ClientWsMessage::Ping(bytes)) => {
|
|
upstream_sender
|
|
.send(UpstreamWsMessage::Ping(bytes))
|
|
.await
|
|
.map_err(map_ws_send_error)?;
|
|
}
|
|
Ok(ClientWsMessage::Pong(_)) => {}
|
|
Err(error) => {
|
|
return Err(SpeechError::Upstream(format!(
|
|
"读取浏览器 ASR WebSocket 失败:{error}"
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
if has_sent_start && !last_audio_sent {
|
|
let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?;
|
|
let _ = upstream_sender
|
|
.send(UpstreamWsMessage::Binary(frame.into()))
|
|
.await;
|
|
}
|
|
Ok::<(), SpeechError>(())
|
|
};
|
|
|
|
let upstream_to_browser = async {
|
|
while let Some(message) = upstream_receiver.next().await {
|
|
match message {
|
|
Ok(UpstreamWsMessage::Binary(bytes)) => {
|
|
let parsed = parse_asr_response_frame(&bytes)?;
|
|
let value = json!({
|
|
"type": "asr_response",
|
|
"sequence": parsed.sequence,
|
|
"payload": parsed.payload,
|
|
"errorCode": parsed.error_code,
|
|
});
|
|
browser_sender
|
|
.send(ClientWsMessage::Text(value.to_string().into()))
|
|
.await
|
|
.map_err(map_client_ws_send_error)?;
|
|
}
|
|
Ok(UpstreamWsMessage::Text(text)) => {
|
|
browser_sender
|
|
.send(ClientWsMessage::Text(text.to_string().into()))
|
|
.await
|
|
.map_err(map_client_ws_send_error)?;
|
|
}
|
|
Ok(UpstreamWsMessage::Close(_)) => {
|
|
let _ = browser_sender.send(ClientWsMessage::Close(None)).await;
|
|
break;
|
|
}
|
|
Ok(UpstreamWsMessage::Ping(bytes)) => {
|
|
browser_sender
|
|
.send(ClientWsMessage::Ping(bytes))
|
|
.await
|
|
.map_err(map_client_ws_send_error)?;
|
|
}
|
|
Ok(UpstreamWsMessage::Pong(_)) => {}
|
|
Ok(UpstreamWsMessage::Frame(_)) => {}
|
|
Err(error) => {
|
|
return Err(SpeechError::Upstream(format!(
|
|
"读取火山 ASR WebSocket 失败:{error}"
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
Ok::<(), SpeechError>(())
|
|
};
|
|
|
|
let mut browser_to_upstream = Box::pin(browser_to_upstream);
|
|
let mut upstream_to_browser = Box::pin(upstream_to_browser);
|
|
let result = tokio::select! {
|
|
result = &mut browser_to_upstream => result,
|
|
result = &mut upstream_to_browser => result,
|
|
};
|
|
if let Err(error) = result {
|
|
warn!(error = %error, "火山 ASR WebSocket 代理中断");
|
|
}
|
|
}
|
|
|
|
async fn proxy_tts_bidirection_websocket(socket: WebSocket, client: VolcengineSpeechClient) {
|
|
let (mut browser_sender, mut browser_receiver) = socket.split();
|
|
let Ok((upstream, response_headers)) = client.connect_tts_bidirection().await else {
|
|
let _ = browser_sender
|
|
.send(ClientWsMessage::Text(
|
|
json!({
|
|
"type": "error",
|
|
"provider": PROVIDER,
|
|
"message": "连接火山 TTS WebSocket 失败",
|
|
})
|
|
.to_string()
|
|
.into(),
|
|
))
|
|
.await;
|
|
return;
|
|
};
|
|
if let Some(log_id) = response_headers.get("x-tt-logid") {
|
|
info!(%log_id, "火山 TTS WebSocket 已连接");
|
|
}
|
|
let (mut upstream_sender, mut upstream_receiver) = upstream.split();
|
|
|
|
let browser_to_upstream = async {
|
|
while let Some(message) = browser_receiver.next().await {
|
|
match message {
|
|
Ok(ClientWsMessage::Text(text)) => {
|
|
let event = serde_json::from_str::<TtsBidirectionClientEvent>(text.as_str())
|
|
.map_err(|error| {
|
|
SpeechError::InvalidFrame(format!(
|
|
"TTS 浏览器事件 JSON 不合法:{error}"
|
|
))
|
|
})?;
|
|
let frame = build_tts_bidirection_frame_from_client_event(event)?;
|
|
upstream_sender
|
|
.send(UpstreamWsMessage::Binary(frame.into()))
|
|
.await
|
|
.map_err(map_ws_send_error)?;
|
|
}
|
|
Ok(ClientWsMessage::Close(_)) => break,
|
|
Ok(ClientWsMessage::Ping(bytes)) => {
|
|
upstream_sender
|
|
.send(UpstreamWsMessage::Ping(bytes))
|
|
.await
|
|
.map_err(map_ws_send_error)?;
|
|
}
|
|
Ok(ClientWsMessage::Binary(_)) | Ok(ClientWsMessage::Pong(_)) => {}
|
|
Err(error) => {
|
|
return Err(SpeechError::Upstream(format!(
|
|
"读取浏览器 TTS WebSocket 失败:{error}"
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
Ok::<(), SpeechError>(())
|
|
};
|
|
|
|
let upstream_to_browser = async {
|
|
while let Some(message) = upstream_receiver.next().await {
|
|
match message {
|
|
Ok(UpstreamWsMessage::Binary(bytes)) => {
|
|
let parsed = parse_tts_response_frame(&bytes)?;
|
|
if let Some(audio) = parsed.audio.clone() {
|
|
browser_sender
|
|
.send(ClientWsMessage::Binary(audio.into()))
|
|
.await
|
|
.map_err(map_client_ws_send_error)?;
|
|
}
|
|
if parsed.payload.is_some() || parsed.error_code.is_some() {
|
|
browser_sender
|
|
.send(ClientWsMessage::Text(
|
|
tts_response_to_client_value(&parsed).to_string().into(),
|
|
))
|
|
.await
|
|
.map_err(map_client_ws_send_error)?;
|
|
}
|
|
}
|
|
Ok(UpstreamWsMessage::Text(text)) => {
|
|
browser_sender
|
|
.send(ClientWsMessage::Text(text.to_string().into()))
|
|
.await
|
|
.map_err(map_client_ws_send_error)?;
|
|
}
|
|
Ok(UpstreamWsMessage::Close(_)) => {
|
|
let _ = browser_sender.send(ClientWsMessage::Close(None)).await;
|
|
break;
|
|
}
|
|
Ok(UpstreamWsMessage::Ping(bytes)) => {
|
|
browser_sender
|
|
.send(ClientWsMessage::Ping(bytes))
|
|
.await
|
|
.map_err(map_client_ws_send_error)?;
|
|
}
|
|
Ok(UpstreamWsMessage::Pong(_)) => {}
|
|
Ok(UpstreamWsMessage::Frame(_)) => {}
|
|
Err(error) => {
|
|
return Err(SpeechError::Upstream(format!(
|
|
"读取火山 TTS WebSocket 失败:{error}"
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
Ok::<(), SpeechError>(())
|
|
};
|
|
|
|
let mut browser_to_upstream = Box::pin(browser_to_upstream);
|
|
let mut upstream_to_browser = Box::pin(upstream_to_browser);
|
|
let result = tokio::select! {
|
|
result = &mut browser_to_upstream => result,
|
|
result = &mut upstream_to_browser => result,
|
|
};
|
|
if let Err(error) = result {
|
|
warn!(error = %error, "火山 TTS WebSocket 代理中断");
|
|
}
|
|
}
|
|
|
|
fn build_speech_client(state: &AppState) -> Result<VolcengineSpeechClient, SpeechError> {
|
|
Ok(VolcengineSpeechClient::new(VolcengineSpeechConfig::new(
|
|
state.config.volcengine_speech_api_key.clone(),
|
|
state.config.volcengine_speech_app_id.clone(),
|
|
state.config.volcengine_speech_access_key.clone(),
|
|
state.config.volcengine_speech_asr_resource_id.clone(),
|
|
state.config.volcengine_speech_tts_resource_id.clone(),
|
|
state.config.volcengine_speech_asr_ws_url.clone(),
|
|
state
|
|
.config
|
|
.volcengine_speech_tts_bidirection_ws_url
|
|
.clone(),
|
|
state.config.volcengine_speech_tts_sse_url.clone(),
|
|
state.config.volcengine_speech_request_timeout_ms,
|
|
)?))
|
|
}
|
|
|
|
fn public_speech_config(state: &AppState) -> PublicSpeechConfig {
|
|
PublicSpeechConfig {
|
|
asr_resource_id: state.config.volcengine_speech_asr_resource_id.clone(),
|
|
tts_resource_id: state.config.volcengine_speech_tts_resource_id.clone(),
|
|
asr_audio: AsrAudioConfig::default(),
|
|
tts_audio: TtsAudioParams::default(),
|
|
endpoints: PublicSpeechEndpoints {
|
|
asr_stream: "/api/speech/volcengine/asr/stream",
|
|
tts_bidirection: "/api/speech/volcengine/tts/bidirection",
|
|
tts_sse: "/api/speech/volcengine/tts/sse",
|
|
},
|
|
}
|
|
}
|
|
|
|
fn map_speech_error(error: SpeechError) -> AppError {
|
|
match error {
|
|
SpeechError::InvalidConfig(message) => {
|
|
AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({
|
|
"provider": PROVIDER,
|
|
"message": message,
|
|
}))
|
|
}
|
|
SpeechError::InvalidHeader(message)
|
|
| SpeechError::InvalidFrame(message)
|
|
| SpeechError::Serialize(message)
|
|
| SpeechError::Io(message)
|
|
| SpeechError::Upstream(message) => AppError::from_status(StatusCode::BAD_GATEWAY)
|
|
.with_details(json!({
|
|
"provider": PROVIDER,
|
|
"message": message,
|
|
})),
|
|
}
|
|
}
|
|
|
|
fn map_ws_send_error(error: UpstreamWsError) -> SpeechError {
|
|
SpeechError::Upstream(format!("发送火山语音 WebSocket 帧失败:{error}"))
|
|
}
|
|
|
|
fn map_client_ws_send_error(error: axum::Error) -> SpeechError {
|
|
SpeechError::Upstream(format!("发送浏览器语音 WebSocket 帧失败:{error}"))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use serde_json::Value;
    use tower::ServiceExt;

    use super::*;
    use crate::{app::build_router, config::AppConfig, state::AppState};

    /// Path of the public speech configuration route under test.
    const SPEECH_CONFIG_URI: &str = "/api/speech/volcengine/config";

    /// Builds a GET request for the speech config route, adding a bearer token if given.
    fn speech_config_request(bearer: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri(SPEECH_CONFIG_URI);
        let builder = match bearer {
            Some(token) => builder.header("authorization", format!("Bearer {token}")),
            None => builder,
        };
        builder.body(Body::empty()).expect("request should build")
    }

    #[tokio::test]
    async fn speech_config_route_requires_authentication() {
        let router = build_router(AppState::new(AppConfig::default()).expect("state should build"));

        let reply = router
            .oneshot(speech_config_request(None))
            .await
            .expect("request should complete");

        assert_eq!(reply.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn speech_config_route_returns_no_secret_fields() {
        // Configure a secret so we can prove it never leaks into the response.
        let mut app_config = AppConfig::default();
        app_config.volcengine_speech_api_key = Some("secret-key".to_string());
        let app_state = AppState::new(app_config).expect("state should build");
        app_state
            .seed_test_phone_user_with_password("13800138088", "Password123")
            .await;
        let router = build_router(app_state);

        let credentials = json!({
            "phone": "13800138088",
            "password": "Password123"
        })
        .to_string();
        let login_request = Request::builder()
            .method("POST")
            .uri("/api/auth/entry")
            .header("content-type", "application/json")
            .body(Body::from(credentials))
            .expect("login request should build");
        let login_reply = router
            .clone()
            .oneshot(login_request)
            .await
            .expect("login should complete");
        let login_bytes = login_reply
            .into_body()
            .collect()
            .await
            .expect("login body should collect")
            .to_bytes();
        let login_json: Value =
            serde_json::from_slice(&login_bytes).expect("login body should be json");
        let bearer = login_json["token"].as_str().expect("token should exist");

        let reply = router
            .oneshot(speech_config_request(Some(bearer)))
            .await
            .expect("request should complete");

        assert_eq!(reply.status(), StatusCode::OK);
        let reply_bytes = reply
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let reply_text = String::from_utf8_lossy(&reply_bytes);
        assert!(!reply_text.contains("secret-key"));
        assert!(!reply_text.contains("apiKey"));
        assert!(reply_text.contains("asrResourceId"));
    }
}
|