use axum::{ Json, body::Body, extract::{ State, ws::{Message as ClientWsMessage, WebSocket, WebSocketUpgrade}, }, http::{HeaderValue, StatusCode, header}, response::Response, }; use futures_util::{SinkExt, StreamExt, TryStreamExt}; use platform_speech::{ AsrAudioConfig, AsrFrameKind, PublicSpeechConfig, PublicSpeechEndpoints, SpeechError, TtsAudioParams, TtsBidirectionClientEvent, TtsSseRequest, UpstreamWsError, UpstreamWsMessage, VolcengineSpeechClient, VolcengineSpeechConfig, build_asr_frame, build_asr_full_client_request, build_tts_bidirection_frame_from_client_event, default_asr_request_payload, parse_asr_response_frame, parse_tts_response_frame, tts_response_to_client_value, }; use serde_json::{Value, json}; use tracing::{info, warn}; use crate::{ api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError, request_context::RequestContext, state::AppState, }; const PROVIDER: &str = "volcengine-speech"; pub async fn get_volcengine_speech_config( State(state): State, axum::extract::Extension(request_context): axum::extract::Extension, ) -> Json { json_success_body(Some(&request_context), public_speech_config(&state)) } pub async fn stream_volcengine_asr( State(state): State, axum::extract::Extension(authenticated): axum::extract::Extension, ws: WebSocketUpgrade, ) -> Result { let client = build_speech_client(&state) .map_err(|error| map_speech_error(error).into_response_with_context(None))?; let user_id = authenticated.claims().user_id().to_string(); Ok(ws.on_upgrade(move |socket| proxy_asr_websocket(socket, client, user_id))) } pub async fn stream_volcengine_tts_bidirection( State(state): State, ws: WebSocketUpgrade, ) -> Result { let client = build_speech_client(&state) .map_err(|error| map_speech_error(error).into_response_with_context(None))?; Ok(ws.on_upgrade(move |socket| proxy_tts_bidirection_websocket(socket, client))) } pub async fn stream_volcengine_tts_sse( State(state): State, axum::extract::Extension(request_context): axum::extract::Extension, axum::extract::Extension(authenticated): axum::extract::Extension, payload: Result, axum::extract::rejection::JsonRejection>, ) -> Result { let Json(payload) = payload.map_err(|rejection| { AppError::from_status(StatusCode::BAD_REQUEST) .with_message(format!("请求体 JSON 不合法:{rejection}")) .into_response_with_context(Some(&request_context)) })?; let client = build_speech_client(&state).map_err(|error| { map_speech_error(error).into_response_with_context(Some(&request_context)) })?; let upstream_request = client .build_tts_sse_upstream_request(payload, authenticated.claims().user_id()) .map_err(|error| { map_speech_error(error).into_response_with_context(Some(&request_context)) })?; let http_client = reqwest::Client::builder() .timeout(upstream_request.timeout) .build() .map_err(|error| { AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR) .with_details(json!({ "provider": PROVIDER, "message": format!("构造火山语音 HTTP 客户端失败:{error}"), })) .into_response_with_context(Some(&request_context)) })?; let upstream_response = http_client .post(upstream_request.url) .headers(upstream_request.headers) .json(&upstream_request.body) .send() .await .map_err(|error| { AppError::from_status(StatusCode::BAD_GATEWAY) .with_details(json!({ "provider": PROVIDER, "message": format!("请求火山 TTS SSE 失败:{error}"), })) .into_response_with_context(Some(&request_context)) })?; let status = upstream_response.status(); let log_id = upstream_response .headers() .get("X-Tt-Logid") .and_then(|value| value.to_str().ok()) .map(ToOwned::to_owned); if !status.is_success() { let raw_text = upstream_response.text().await.unwrap_or_default(); return Err(AppError::from_status(StatusCode::BAD_GATEWAY) .with_details(json!({ "provider": PROVIDER, "status": status.as_u16(), "logId": log_id, "rawExcerpt": raw_text.chars().take(800).collect::(), })) .into_response_with_context(Some(&request_context))); } let byte_stream = upstream_response .bytes_stream() .map_err(std::io::Error::other); let mut response = Response::new(Body::from_stream(byte_stream)); *response.status_mut() = StatusCode::OK; response.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static("text/event-stream; charset=utf-8"), ); response .headers_mut() .insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache")); if let Some(log_id) = log_id.and_then(|value| HeaderValue::from_str(&value).ok()) { response.headers_mut().insert("x-volcengine-logid", log_id); } Ok(response) } async fn proxy_asr_websocket(socket: WebSocket, client: VolcengineSpeechClient, user_id: String) { let (mut browser_sender, mut browser_receiver) = socket.split(); let Ok((upstream, response_headers)) = client.connect_asr().await else { let _ = browser_sender .send(ClientWsMessage::Text( json!({ "type": "error", "provider": PROVIDER, "message": "连接火山 ASR WebSocket 失败", }) .to_string() .into(), )) .await; return; }; if let Some(log_id) = response_headers.get("x-tt-logid") { info!(%log_id, "火山 ASR WebSocket 已连接"); } let (mut upstream_sender, mut upstream_receiver) = upstream.split(); let mut has_sent_start = false; let mut last_audio_sent = false; let browser_to_upstream = async { while let Some(message) = browser_receiver.next().await { match message { Ok(ClientWsMessage::Text(text)) => { let value = serde_json::from_str::(text.as_str()).unwrap_or_else(|_| { json!({ "request": { "context": text.as_str(), } }) }); if value .get("type") .and_then(Value::as_str) .is_some_and(|kind| kind.eq_ignore_ascii_case("finish")) { let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?; upstream_sender .send(UpstreamWsMessage::Binary(frame.into())) .await .map_err(map_ws_send_error)?; last_audio_sent = true; continue; } if !has_sent_start { let payload = default_asr_request_payload(&user_id, Some(value)); let frame = build_asr_full_client_request(&payload)?; upstream_sender .send(UpstreamWsMessage::Binary(frame.into())) .await .map_err(map_ws_send_error)?; has_sent_start = true; } } Ok(ClientWsMessage::Binary(bytes)) => { if !has_sent_start { let payload = default_asr_request_payload(&user_id, None); let frame = build_asr_full_client_request(&payload)?; upstream_sender .send(UpstreamWsMessage::Binary(frame.into())) .await .map_err(map_ws_send_error)?; has_sent_start = true; } let frame = build_asr_frame(AsrFrameKind::Audio, &bytes)?; upstream_sender .send(UpstreamWsMessage::Binary(frame.into())) .await .map_err(map_ws_send_error)?; } Ok(ClientWsMessage::Close(_)) => break, Ok(ClientWsMessage::Ping(bytes)) => { upstream_sender .send(UpstreamWsMessage::Ping(bytes)) .await .map_err(map_ws_send_error)?; } Ok(ClientWsMessage::Pong(_)) => {} Err(error) => { return Err(SpeechError::Upstream(format!( "读取浏览器 ASR WebSocket 失败:{error}" ))); } } } if has_sent_start && !last_audio_sent { let frame = build_asr_frame(AsrFrameKind::LastAudio, &[])?; let _ = upstream_sender .send(UpstreamWsMessage::Binary(frame.into())) .await; } Ok::<(), SpeechError>(()) }; let upstream_to_browser = async { while let Some(message) = upstream_receiver.next().await { match message { Ok(UpstreamWsMessage::Binary(bytes)) => { let parsed = parse_asr_response_frame(&bytes)?; let value = json!({ "type": "asr_response", "sequence": parsed.sequence, "payload": parsed.payload, "errorCode": parsed.error_code, }); browser_sender .send(ClientWsMessage::Text(value.to_string().into())) .await .map_err(map_client_ws_send_error)?; } Ok(UpstreamWsMessage::Text(text)) => { browser_sender .send(ClientWsMessage::Text(text.to_string().into())) .await .map_err(map_client_ws_send_error)?; } Ok(UpstreamWsMessage::Close(_)) => { let _ = browser_sender.send(ClientWsMessage::Close(None)).await; break; } Ok(UpstreamWsMessage::Ping(bytes)) => { browser_sender .send(ClientWsMessage::Ping(bytes)) .await .map_err(map_client_ws_send_error)?; } Ok(UpstreamWsMessage::Pong(_)) => {} Ok(UpstreamWsMessage::Frame(_)) => {} Err(error) => { return Err(SpeechError::Upstream(format!( "读取火山 ASR WebSocket 失败:{error}" ))); } } } Ok::<(), SpeechError>(()) }; let mut browser_to_upstream = Box::pin(browser_to_upstream); let mut upstream_to_browser = Box::pin(upstream_to_browser); let result = tokio::select! { result = &mut browser_to_upstream => result, result = &mut upstream_to_browser => result, }; if let Err(error) = result { warn!(error = %error, "火山 ASR WebSocket 代理中断"); } } async fn proxy_tts_bidirection_websocket(socket: WebSocket, client: VolcengineSpeechClient) { let (mut browser_sender, mut browser_receiver) = socket.split(); let Ok((upstream, response_headers)) = client.connect_tts_bidirection().await else { let _ = browser_sender .send(ClientWsMessage::Text( json!({ "type": "error", "provider": PROVIDER, "message": "连接火山 TTS WebSocket 失败", }) .to_string() .into(), )) .await; return; }; if let Some(log_id) = response_headers.get("x-tt-logid") { info!(%log_id, "火山 TTS WebSocket 已连接"); } let (mut upstream_sender, mut upstream_receiver) = upstream.split(); let browser_to_upstream = async { while let Some(message) = browser_receiver.next().await { match message { Ok(ClientWsMessage::Text(text)) => { let event = serde_json::from_str::(text.as_str()) .map_err(|error| { SpeechError::InvalidFrame(format!( "TTS 浏览器事件 JSON 不合法:{error}" )) })?; let frame = build_tts_bidirection_frame_from_client_event(event)?; upstream_sender .send(UpstreamWsMessage::Binary(frame.into())) .await .map_err(map_ws_send_error)?; } Ok(ClientWsMessage::Close(_)) => break, Ok(ClientWsMessage::Ping(bytes)) => { upstream_sender .send(UpstreamWsMessage::Ping(bytes)) .await .map_err(map_ws_send_error)?; } Ok(ClientWsMessage::Binary(_)) | Ok(ClientWsMessage::Pong(_)) => {} Err(error) => { return Err(SpeechError::Upstream(format!( "读取浏览器 TTS WebSocket 失败:{error}" ))); } } } Ok::<(), SpeechError>(()) }; let upstream_to_browser = async { while let Some(message) = upstream_receiver.next().await { match message { Ok(UpstreamWsMessage::Binary(bytes)) => { let parsed = parse_tts_response_frame(&bytes)?; if let Some(audio) = parsed.audio.clone() { browser_sender .send(ClientWsMessage::Binary(audio.into())) .await .map_err(map_client_ws_send_error)?; } if parsed.payload.is_some() || parsed.error_code.is_some() { browser_sender .send(ClientWsMessage::Text( tts_response_to_client_value(&parsed).to_string().into(), )) .await .map_err(map_client_ws_send_error)?; } } Ok(UpstreamWsMessage::Text(text)) => { browser_sender .send(ClientWsMessage::Text(text.to_string().into())) .await .map_err(map_client_ws_send_error)?; } Ok(UpstreamWsMessage::Close(_)) => { let _ = browser_sender.send(ClientWsMessage::Close(None)).await; break; } Ok(UpstreamWsMessage::Ping(bytes)) => { browser_sender .send(ClientWsMessage::Ping(bytes)) .await .map_err(map_client_ws_send_error)?; } Ok(UpstreamWsMessage::Pong(_)) => {} Ok(UpstreamWsMessage::Frame(_)) => {} Err(error) => { return Err(SpeechError::Upstream(format!( "读取火山 TTS WebSocket 失败:{error}" ))); } } } Ok::<(), SpeechError>(()) }; let mut browser_to_upstream = Box::pin(browser_to_upstream); let mut upstream_to_browser = Box::pin(upstream_to_browser); let result = tokio::select! { result = &mut browser_to_upstream => result, result = &mut upstream_to_browser => result, }; if let Err(error) = result { warn!(error = %error, "火山 TTS WebSocket 代理中断"); } } fn build_speech_client(state: &AppState) -> Result { Ok(VolcengineSpeechClient::new(VolcengineSpeechConfig::new( state.config.volcengine_speech_api_key.clone(), state.config.volcengine_speech_app_id.clone(), state.config.volcengine_speech_access_key.clone(), state.config.volcengine_speech_asr_resource_id.clone(), state.config.volcengine_speech_tts_resource_id.clone(), state.config.volcengine_speech_asr_ws_url.clone(), state .config .volcengine_speech_tts_bidirection_ws_url .clone(), state.config.volcengine_speech_tts_sse_url.clone(), state.config.volcengine_speech_request_timeout_ms, )?)) } fn public_speech_config(state: &AppState) -> PublicSpeechConfig { PublicSpeechConfig { asr_resource_id: state.config.volcengine_speech_asr_resource_id.clone(), tts_resource_id: state.config.volcengine_speech_tts_resource_id.clone(), asr_audio: AsrAudioConfig::default(), tts_audio: TtsAudioParams::default(), endpoints: PublicSpeechEndpoints { asr_stream: "/api/speech/volcengine/asr/stream", tts_bidirection: "/api/speech/volcengine/tts/bidirection", tts_sse: "/api/speech/volcengine/tts/sse", }, } } fn map_speech_error(error: SpeechError) -> AppError { match error { SpeechError::InvalidConfig(message) => { AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_details(json!({ "provider": PROVIDER, "message": message, })) } SpeechError::InvalidHeader(message) | SpeechError::InvalidFrame(message) | SpeechError::Serialize(message) | SpeechError::Io(message) | SpeechError::Upstream(message) => AppError::from_status(StatusCode::BAD_GATEWAY) .with_details(json!({ "provider": PROVIDER, "message": message, })), } } fn map_ws_send_error(error: UpstreamWsError) -> SpeechError { SpeechError::Upstream(format!("发送火山语音 WebSocket 帧失败:{error}")) } fn map_client_ws_send_error(error: axum::Error) -> SpeechError { SpeechError::Upstream(format!("发送浏览器语音 WebSocket 帧失败:{error}")) } #[cfg(test)] mod tests { use axum::{ body::Body, http::{Request, StatusCode}, }; use http_body_util::BodyExt; use serde_json::Value; use tower::ServiceExt; use super::*; use crate::{app::build_router, config::AppConfig, state::AppState}; #[tokio::test] async fn speech_config_route_requires_authentication() { let app = build_router(AppState::new(AppConfig::default()).expect("state should build")); let response = app .oneshot( Request::builder() .uri("/api/speech/volcengine/config") .body(Body::empty()) .expect("request should build"), ) .await .expect("request should complete"); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn speech_config_route_returns_no_secret_fields() { let mut config = AppConfig::default(); config.volcengine_speech_api_key = Some("secret-key".to_string()); let state = AppState::new(config).expect("state should build"); state .seed_test_phone_user_with_password("13800138088", "Password123") .await; let app = build_router(state); let login_response = app .clone() .oneshot( Request::builder() .method("POST") .uri("/api/auth/entry") .header("content-type", "application/json") .body(Body::from( json!({ "phone": "13800138088", "password": "Password123" }) .to_string(), )) .expect("login request should build"), ) .await .expect("login should complete"); let login_body = login_response .into_body() .collect() .await .expect("login body should collect") .to_bytes(); let login_payload: Value = serde_json::from_slice(&login_body).expect("login body should be json"); let token = login_payload["token"].as_str().expect("token should exist"); let response = app .oneshot( Request::builder() .uri("/api/speech/volcengine/config") .header("authorization", format!("Bearer {token}")) .body(Body::empty()) .expect("request should build"), ) .await .expect("request should complete"); assert_eq!(response.status(), StatusCode::OK); let body = response .into_body() .collect() .await .expect("body should collect") .to_bytes(); let payload_text = String::from_utf8_lossy(&body); assert!(!payload_text.contains("secret-key")); assert!(!payload_text.contains("apiKey")); assert!(payload_text.contains("asrResourceId")); } }