use std::{ collections::BTreeMap, error::Error, fmt, io::{Read, Write}, time::Duration, }; use flate2::{Compression, read::GzDecoder, write::GzEncoder}; use futures_util::{SinkExt, StreamExt}; use reqwest::header::{HeaderMap, HeaderValue}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use tokio_tungstenite::{ MaybeTlsStream, WebSocketStream, connect_async, tungstenite::{Message, client::IntoClientRequest, http::Uri}, }; use uuid::Uuid; pub use tokio_tungstenite::tungstenite::{Error as UpstreamWsError, Message as UpstreamWsMessage}; pub const DEFAULT_ASR_WS_URL: &str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"; pub const DEFAULT_TTS_BIDIRECTION_WS_URL: &str = "wss://openspeech.bytedance.com/api/v3/tts/bidirection"; pub const DEFAULT_TTS_SSE_URL: &str = "https://openspeech.bytedance.com/api/v3/tts/unidirectional/sse"; pub const DEFAULT_ASR_RESOURCE_ID: &str = "volc.seedasr.sauc.concurrent"; pub const DEFAULT_TTS_RESOURCE_ID: &str = "seed-tts-2.0"; pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 180_000; const PROTOCOL_VERSION: u8 = 0b0001; const HEADER_SIZE_FOUR_BYTES: u8 = 0b0001; const SERIALIZATION_NONE: u8 = 0b0000; const SERIALIZATION_JSON: u8 = 0b0001; const COMPRESSION_NONE: u8 = 0b0000; const COMPRESSION_GZIP: u8 = 0b0001; const MESSAGE_FULL_CLIENT_REQUEST: u8 = 0b0001; const MESSAGE_AUDIO_ONLY_REQUEST: u8 = 0b0010; pub const MESSAGE_FULL_SERVER_RESPONSE: u8 = 0b1001; const MESSAGE_AUDIO_ONLY_RESPONSE: u8 = 0b1011; const MESSAGE_ERROR: u8 = 0b1111; const FLAG_NONE: u8 = 0b0000; const FLAG_WITH_SEQUENCE: u8 = 0b0001; const FLAG_LAST_PACKET: u8 = 0b0010; const FLAG_WITH_NEGATIVE_SEQUENCE: u8 = 0b0011; const FLAG_WITH_EVENT: u8 = 0b0100; const EVENT_START_CONNECTION: i32 = 1; const EVENT_FINISH_CONNECTION: i32 = 2; const EVENT_CONNECTION_STARTED: i32 = 50; const EVENT_CONNECTION_FAILED: i32 = 51; const EVENT_CONNECTION_FINISHED: i32 = 52; const EVENT_START_SESSION: i32 = 100; const EVENT_CANCEL_SESSION: i32 = 101; const EVENT_FINISH_SESSION: i32 = 102; const EVENT_SESSION_STARTED: i32 = 150; const EVENT_SESSION_CANCELED: i32 = 151; const EVENT_SESSION_FINISHED: i32 = 152; const EVENT_SESSION_FAILED: i32 = 153; const EVENT_TASK_REQUEST: i32 = 200; const EVENT_TTS_SENTENCE_END: i32 = 351; const EVENT_TTS_RESPONSE: i32 = 352; const EVENT_TTS_SUBTITLE: i32 = 353; pub type SpeechWsStream = WebSocketStream>; #[derive(Clone, Debug, PartialEq, Eq)] pub struct VolcengineSpeechConfig { pub api_key: Option, pub app_id: Option, pub access_key: Option, pub asr_resource_id: String, pub tts_resource_id: String, pub asr_ws_url: String, pub tts_bidirection_ws_url: String, pub tts_sse_url: String, pub request_timeout_ms: u64, } #[derive(Clone, Debug, PartialEq, Eq, Serialize)] #[serde(rename_all = "camelCase")] pub struct PublicSpeechConfig { pub asr_resource_id: String, pub tts_resource_id: String, pub asr_audio: AsrAudioConfig, pub tts_audio: TtsAudioParams, pub endpoints: PublicSpeechEndpoints, } #[derive(Clone, Debug, PartialEq, Eq, Serialize)] #[serde(rename_all = "camelCase")] pub struct PublicSpeechEndpoints { pub asr_stream: &'static str, pub tts_bidirection: &'static str, pub tts_sse: &'static str, } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct AsrAudioConfig { pub format: String, pub codec: String, pub rate: u32, pub bits: u8, pub channel: u8, } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct TtsAudioParams { pub format: String, pub sample_rate: u32, #[serde(skip_serializing_if = "Option::is_none")] pub bit_rate: Option, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct VolcengineSpeechClient { config: VolcengineSpeechConfig, } #[derive(Debug)] pub enum SpeechError { InvalidConfig(String), InvalidHeader(String), InvalidFrame(String), Serialize(String), Io(String), Upstream(String), } #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum VolcengineSpeechAuthMode { ApiKey, LegacyApp, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum AsrFrameKind { FullClientRequest, Audio, LastAudio, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct SpeechFrameHeader { pub version: u8, pub header_size_words: u8, pub message_type: u8, pub flags: u8, pub serialization: u8, pub compression: u8, } #[derive(Clone, Debug, PartialEq)] pub struct ParsedAsrResponse { pub header: SpeechFrameHeader, pub sequence: Option, pub event: Option, pub payload: Value, pub error_code: Option, } #[derive(Clone, Debug, PartialEq)] pub struct ParsedTtsResponse { pub header: SpeechFrameHeader, pub event: Option, pub session_id: Option, pub connection_id: Option, pub payload: Option, pub audio: Option>, pub error_code: Option, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum TtsEvent { StartConnection, FinishConnection, ConnectionStarted, ConnectionFailed, ConnectionFinished, StartSession, CancelSession, FinishSession, SessionStarted, SessionCanceled, SessionFinished, SessionFailed, TaskRequest, TtsSentenceEnd, TtsResponse, TtsSubtitle, Unknown(i32), } #[derive(Clone, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "snake_case", tag = "type")] pub enum TtsBidirectionClientEvent { StartConnection { #[serde(default)] payload: Option, }, FinishConnection { #[serde(default)] payload: Option, }, StartSession { #[serde(rename = "sessionId")] session_id: Option, #[serde(default)] payload: Value, }, FinishSession { #[serde(rename = "sessionId")] session_id: String, #[serde(default)] payload: Option, }, CancelSession { #[serde(rename = "sessionId")] session_id: String, #[serde(default)] payload: Option, }, TaskRequest { #[serde(rename = "sessionId")] session_id: String, #[serde(default)] payload: Value, }, } #[derive(Clone, Debug, PartialEq, Eq, Deserialize)] #[serde(rename_all = "camelCase")] pub struct TtsSseRequest { pub text: String, pub speaker: String, #[serde(default)] pub model: Option, #[serde(default)] pub audio_params: Option, #[serde(default)] pub additions: Option, #[serde(default)] pub ssml: Option, } #[derive(Clone, Debug, PartialEq, Eq)] pub struct TtsSseUpstreamRequest { pub url: String, pub headers: HeaderMap, pub body: Value, pub timeout: Duration, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum SpeechCompression { None, Gzip, } impl Default for AsrAudioConfig { fn default() -> Self { Self { format: "pcm".to_string(), codec: "raw".to_string(), rate: 16_000, bits: 16, channel: 1, } } } impl Default for TtsAudioParams { fn default() -> Self { Self { format: "mp3".to_string(), sample_rate: 24_000, bit_rate: None, } } } impl VolcengineSpeechConfig { pub fn new( api_key: Option, app_id: Option, access_key: Option, asr_resource_id: String, tts_resource_id: String, asr_ws_url: String, tts_bidirection_ws_url: String, tts_sse_url: String, request_timeout_ms: u64, ) -> Result { let config = Self { api_key: normalize_optional_secret(api_key), app_id: normalize_optional_secret(app_id), access_key: normalize_optional_secret(access_key), asr_resource_id: default_if_empty(asr_resource_id, DEFAULT_ASR_RESOURCE_ID), tts_resource_id: default_if_empty(tts_resource_id, DEFAULT_TTS_RESOURCE_ID), asr_ws_url: default_if_empty(asr_ws_url, DEFAULT_ASR_WS_URL), tts_bidirection_ws_url: default_if_empty( tts_bidirection_ws_url, DEFAULT_TTS_BIDIRECTION_WS_URL, ), tts_sse_url: default_if_empty(tts_sse_url, DEFAULT_TTS_SSE_URL), request_timeout_ms: request_timeout_ms.max(1), }; config.auth_mode()?; Ok(config) } pub fn auth_mode(&self) -> Result { if self.api_key.as_ref().is_some_and(|value| !value.is_empty()) { return Ok(VolcengineSpeechAuthMode::ApiKey); } if self.app_id.as_ref().is_some_and(|value| !value.is_empty()) && self .access_key .as_ref() .is_some_and(|value| !value.is_empty()) { return Ok(VolcengineSpeechAuthMode::LegacyApp); } Err(SpeechError::InvalidConfig( "火山语音密钥未配置:需要 VOLCENGINE_SPEECH_API_KEY,或旧版 VOLCENGINE_SPEECH_APP_ID + VOLCENGINE_SPEECH_ACCESS_KEY" .to_string(), )) } pub fn public_config(&self) -> PublicSpeechConfig { PublicSpeechConfig { asr_resource_id: self.asr_resource_id.clone(), tts_resource_id: self.tts_resource_id.clone(), asr_audio: AsrAudioConfig::default(), tts_audio: TtsAudioParams::default(), endpoints: PublicSpeechEndpoints { asr_stream: "/api/speech/volcengine/asr/stream", tts_bidirection: "/api/speech/volcengine/tts/bidirection", tts_sse: "/api/speech/volcengine/tts/sse", }, } } fn auth_headers(&self, resource_id: &str) -> Result { let mut headers = HeaderMap::new(); match self.auth_mode()? { VolcengineSpeechAuthMode::ApiKey => { headers.insert( "X-Api-Key", header_value(self.api_key.as_deref().unwrap_or(""))?, ); } VolcengineSpeechAuthMode::LegacyApp => { headers.insert( "X-Api-App-Key", header_value(self.app_id.as_deref().unwrap_or(""))?, ); headers.insert( "X-Api-Access-Key", header_value(self.access_key.as_deref().unwrap_or(""))?, ); } } headers.insert("X-Api-Resource-Id", header_value(resource_id)?); Ok(headers) } } impl VolcengineSpeechClient { pub fn new(config: VolcengineSpeechConfig) -> Self { Self { config } } pub fn config(&self) -> &VolcengineSpeechConfig { &self.config } pub async fn connect_asr( &self, ) -> Result<(SpeechWsStream, BTreeMap), SpeechError> { let request_id = Uuid::new_v4().to_string(); let mut headers = self.config.auth_headers(&self.config.asr_resource_id)?; headers.insert("X-Api-Request-Id", header_value(&request_id)?); headers.insert("X-Api-Sequence", HeaderValue::from_static("-1")); self.connect_ws(&self.config.asr_ws_url, headers).await } pub async fn connect_tts_bidirection( &self, ) -> Result<(SpeechWsStream, BTreeMap), SpeechError> { let connect_id = Uuid::new_v4().to_string(); let mut headers = self.config.auth_headers(&self.config.tts_resource_id)?; headers.insert("X-Api-Connect-Id", header_value(&connect_id)?); self.connect_ws(&self.config.tts_bidirection_ws_url, headers) .await } pub fn build_tts_sse_upstream_request( &self, request: TtsSseRequest, user_id: &str, ) -> Result { let mut headers = self.config.auth_headers(&self.config.tts_resource_id)?; headers.insert( "X-Api-Request-Id", header_value(&Uuid::new_v4().to_string())?, ); headers.insert( reqwest::header::ACCEPT, HeaderValue::from_static("text/event-stream"), ); headers.insert( reqwest::header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); let body = build_tts_sse_body(request, user_id)?; Ok(TtsSseUpstreamRequest { url: self.config.tts_sse_url.clone(), headers, body, timeout: Duration::from_millis(self.config.request_timeout_ms), }) } async fn connect_ws( &self, url: &str, headers: HeaderMap, ) -> Result<(SpeechWsStream, BTreeMap), SpeechError> { let uri: Uri = url.parse().map_err(|error| { SpeechError::InvalidConfig(format!("火山语音 WebSocket URL 非法:{error}")) })?; let mut request = uri.into_client_request().map_err(|error| { SpeechError::InvalidConfig(format!("构造火山语音 WebSocket 请求失败:{error}")) })?; for (name, value) in headers { if let Some(name) = name { request.headers_mut().insert(name, value); } } let (stream, response) = connect_async(request).await.map_err(|error| { SpeechError::Upstream(format!("连接火山语音 WebSocket 失败:{error}")) })?; let response_headers = response .headers() .iter() .filter_map(|(name, value)| { value .to_str() .ok() .map(|value| (name.as_str().to_string(), value.to_string())) }) .collect(); Ok((stream, response_headers)) } } pub fn default_asr_request_payload(user_id: &str, override_payload: Option) -> Value { let mut payload = json!({ "user": { "uid": user_id, }, "audio": { "format": AsrAudioConfig::default().format, "codec": AsrAudioConfig::default().codec, "rate": AsrAudioConfig::default().rate, "bits": AsrAudioConfig::default().bits, "channel": AsrAudioConfig::default().channel, }, "request": { "model_name": "bigmodel", "enable_itn": true, "enable_punc": true, "show_utterances": true, "result_type": "full", } }); if let Some(override_payload) = override_payload { merge_json_object(&mut payload, override_payload); } payload } pub fn build_asr_frame(kind: AsrFrameKind, payload: &[u8]) -> Result, SpeechError> { match kind { AsrFrameKind::FullClientRequest => build_sized_frame( MESSAGE_FULL_CLIENT_REQUEST, FLAG_NONE, SERIALIZATION_JSON, COMPRESSION_GZIP, &gzip_bytes(payload)?, ), AsrFrameKind::Audio => build_sized_frame( MESSAGE_AUDIO_ONLY_REQUEST, FLAG_NONE, SERIALIZATION_NONE, COMPRESSION_GZIP, &gzip_bytes(payload)?, ), AsrFrameKind::LastAudio => build_sized_frame( MESSAGE_AUDIO_ONLY_REQUEST, FLAG_LAST_PACKET, SERIALIZATION_NONE, COMPRESSION_GZIP, &gzip_bytes(payload)?, ), } } pub fn build_asr_full_client_request(payload: &Value) -> Result, SpeechError> { let bytes = serde_json::to_vec(payload) .map_err(|error| SpeechError::Serialize(format!("序列化 ASR 请求失败:{error}")))?; build_asr_frame(AsrFrameKind::FullClientRequest, &bytes) } pub fn parse_asr_response_frame(bytes: &[u8]) -> Result { let header = parse_header(bytes)?; let mut offset = usize::from(header.header_size_words) * 4; let sequence = if matches!( header.flags, FLAG_WITH_SEQUENCE | FLAG_WITH_NEGATIVE_SEQUENCE ) && bytes.len() >= offset + 4 { let sequence = read_i32(bytes, &mut offset)?; Some(sequence) } else { None }; if header.message_type == MESSAGE_ERROR { let error_code = read_u32(bytes, &mut offset).ok(); let payload = read_payload_value(bytes, &mut offset, SpeechCompression::None) .unwrap_or_else(|_| json!({ "message": decode_lossy(&bytes[offset..]) })); return Ok(ParsedAsrResponse { header, sequence, event: None, payload, error_code, }); } let payload = read_payload_value(bytes, &mut offset, header.compression())?; Ok(ParsedAsrResponse { header, sequence, event: None, payload, error_code: None, }) } pub fn build_tts_bidirection_frame( event: TtsEvent, session_id: Option<&str>, payload: Option<&Value>, ) -> Result, SpeechError> { let payload_bytes = serde_json::to_vec(payload.unwrap_or(&json!({}))) .map_err(|error| SpeechError::Serialize(format!("序列化 TTS 事件失败:{error}")))?; let event_number = event.to_i32(); let message_type = match event { TtsEvent::TaskRequest => MESSAGE_FULL_CLIENT_REQUEST, _ => MESSAGE_FULL_CLIENT_REQUEST, }; build_tts_event_frame( message_type, SERIALIZATION_JSON, COMPRESSION_NONE, event_number, session_id, &payload_bytes, ) } pub fn build_tts_bidirection_frame_from_client_event( event: TtsBidirectionClientEvent, ) -> Result, SpeechError> { match event { TtsBidirectionClientEvent::StartConnection { payload } => { build_tts_bidirection_frame(TtsEvent::StartConnection, None, payload.as_ref()) } TtsBidirectionClientEvent::FinishConnection { payload } => { build_tts_bidirection_frame(TtsEvent::FinishConnection, None, payload.as_ref()) } TtsBidirectionClientEvent::StartSession { session_id, payload, } => { let session_id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string()); build_tts_bidirection_frame(TtsEvent::StartSession, Some(&session_id), Some(&payload)) } TtsBidirectionClientEvent::FinishSession { session_id, payload, } => build_tts_bidirection_frame( TtsEvent::FinishSession, Some(&session_id), payload.as_ref(), ), TtsBidirectionClientEvent::CancelSession { session_id, payload, } => build_tts_bidirection_frame( TtsEvent::CancelSession, Some(&session_id), payload.as_ref(), ), TtsBidirectionClientEvent::TaskRequest { session_id, payload, } => build_tts_bidirection_frame(TtsEvent::TaskRequest, Some(&session_id), Some(&payload)), } } pub fn parse_tts_response_frame(bytes: &[u8]) -> Result { let header = parse_header(bytes)?; let mut offset = usize::from(header.header_size_words) * 4; if header.message_type == MESSAGE_ERROR { let error_code = read_u32(bytes, &mut offset).ok(); let payload = read_payload_value(bytes, &mut offset, header.compression()) .unwrap_or_else(|_| json!({ "message": decode_lossy(&bytes[offset..]) })); return Ok(ParsedTtsResponse { header, event: None, session_id: None, connection_id: None, payload: Some(payload), audio: None, error_code, }); } let event = if header.flags == FLAG_WITH_EVENT { Some(TtsEvent::from_i32(read_i32(bytes, &mut offset)?)) } else { None }; let session_or_connection_id = read_optional_length_prefixed_string(bytes, &mut offset)?; let payload_bytes = read_payload_bytes(bytes, &mut offset)?; let payload_bytes = match header.compression() { SpeechCompression::None => payload_bytes, SpeechCompression::Gzip => ungzip_bytes(&payload_bytes)?, }; let is_audio = header.message_type == MESSAGE_AUDIO_ONLY_RESPONSE || event == Some(TtsEvent::TtsResponse) && header.serialization == SERIALIZATION_NONE; let payload = if is_audio { None } else if payload_bytes.is_empty() { Some(json!({})) } else { Some(serde_json::from_slice(&payload_bytes).map_err(|error| { SpeechError::InvalidFrame(format!("解析 TTS JSON 响应失败:{error}")) })?) }; let audio = if is_audio { Some(payload_bytes) } else { None }; let (connection_id, session_id) = match event { Some(TtsEvent::ConnectionStarted) | Some(TtsEvent::ConnectionFailed) | Some(TtsEvent::ConnectionFinished) => (session_or_connection_id, None), _ => (None, session_or_connection_id), }; Ok(ParsedTtsResponse { header, event, session_id, connection_id, payload, audio, error_code: None, }) } pub fn tts_response_to_client_value(response: &ParsedTtsResponse) -> Value { json!({ "event": response.event.map(|event| event.name()), "eventCode": response.event.map(|event| event.to_i32()), "sessionId": response.session_id, "connectionId": response.connection_id, "payload": response.payload, "audioBytes": response.audio.as_ref().map(Vec::len), "errorCode": response.error_code, }) } pub fn build_tts_sse_body(request: TtsSseRequest, user_id: &str) -> Result { let text = request.text.trim(); let speaker = request.speaker.trim(); if text.is_empty() && request.ssml.as_deref().unwrap_or("").trim().is_empty() { return Err(SpeechError::InvalidConfig("TTS 文本不能为空".to_string())); } if speaker.is_empty() { return Err(SpeechError::InvalidConfig( "TTS speaker 不能为空".to_string(), )); } let mut req_params = json!({ "text": text, "speaker": speaker, "audio_params": { "format": request.audio_params.clone().unwrap_or_default().format, "sample_rate": request.audio_params.clone().unwrap_or_default().sample_rate, }, }); if let Some(bit_rate) = request .audio_params .as_ref() .and_then(|params| params.bit_rate) { req_params["audio_params"]["bit_rate"] = json!(bit_rate); } if let Some(model) = normalize_optional_secret(request.model) { req_params["model"] = json!(model); } if let Some(ssml) = normalize_optional_secret(request.ssml) { req_params["ssml"] = json!(ssml); } if let Some(additions) = request.additions { req_params["additions"] = additions; } Ok(json!({ "user": { "uid": user_id, }, "req_params": req_params, })) } pub async fn send_binary(ws: &mut SpeechWsStream, bytes: Vec) -> Result<(), SpeechError> { ws.send(Message::Binary(bytes.into())) .await .map_err(|error| SpeechError::Upstream(format!("发送火山语音 WebSocket 帧失败:{error}"))) } pub async fn recv_binary(ws: &mut SpeechWsStream) -> Result>, SpeechError> { while let Some(message) = ws.next().await { match message { Ok(Message::Binary(bytes)) => return Ok(Some(bytes.to_vec())), Ok(Message::Text(text)) => return Ok(Some(text.as_bytes().to_vec())), Ok(Message::Close(_)) => return Ok(None), Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {} Ok(Message::Frame(_)) => {} Err(error) => { return Err(SpeechError::Upstream(format!( "读取火山语音 WebSocket 帧失败:{error}" ))); } } } Ok(None) } impl SpeechFrameHeader { fn compression(self) -> SpeechCompression { match self.compression { COMPRESSION_GZIP => SpeechCompression::Gzip, _ => SpeechCompression::None, } } } impl TtsEvent { pub fn from_i32(value: i32) -> Self { match value { EVENT_START_CONNECTION => Self::StartConnection, EVENT_FINISH_CONNECTION => Self::FinishConnection, EVENT_CONNECTION_STARTED => Self::ConnectionStarted, EVENT_CONNECTION_FAILED => Self::ConnectionFailed, EVENT_CONNECTION_FINISHED => Self::ConnectionFinished, EVENT_START_SESSION => Self::StartSession, EVENT_CANCEL_SESSION => Self::CancelSession, EVENT_FINISH_SESSION => Self::FinishSession, EVENT_SESSION_STARTED => Self::SessionStarted, EVENT_SESSION_CANCELED => Self::SessionCanceled, EVENT_SESSION_FINISHED => Self::SessionFinished, EVENT_SESSION_FAILED => Self::SessionFailed, EVENT_TASK_REQUEST => Self::TaskRequest, EVENT_TTS_SENTENCE_END => Self::TtsSentenceEnd, EVENT_TTS_RESPONSE => Self::TtsResponse, EVENT_TTS_SUBTITLE => Self::TtsSubtitle, other => Self::Unknown(other), } } pub fn to_i32(self) -> i32 { match self { Self::StartConnection => EVENT_START_CONNECTION, Self::FinishConnection => EVENT_FINISH_CONNECTION, Self::ConnectionStarted => EVENT_CONNECTION_STARTED, Self::ConnectionFailed => EVENT_CONNECTION_FAILED, Self::ConnectionFinished => EVENT_CONNECTION_FINISHED, Self::StartSession => EVENT_START_SESSION, Self::CancelSession => EVENT_CANCEL_SESSION, Self::FinishSession => EVENT_FINISH_SESSION, Self::SessionStarted => EVENT_SESSION_STARTED, Self::SessionCanceled => EVENT_SESSION_CANCELED, Self::SessionFinished => EVENT_SESSION_FINISHED, Self::SessionFailed => EVENT_SESSION_FAILED, Self::TaskRequest => EVENT_TASK_REQUEST, Self::TtsSentenceEnd => EVENT_TTS_SENTENCE_END, Self::TtsResponse => EVENT_TTS_RESPONSE, Self::TtsSubtitle => EVENT_TTS_SUBTITLE, Self::Unknown(value) => value, } } pub fn name(self) -> &'static str { match self { Self::StartConnection => "start_connection", Self::FinishConnection => "finish_connection", Self::ConnectionStarted => "connection_started", Self::ConnectionFailed => "connection_failed", Self::ConnectionFinished => "connection_finished", Self::StartSession => "start_session", Self::CancelSession => "cancel_session", Self::FinishSession => "finish_session", Self::SessionStarted => "session_started", Self::SessionCanceled => "session_canceled", Self::SessionFinished => "session_finished", Self::SessionFailed => "session_failed", Self::TaskRequest => "task_request", Self::TtsSentenceEnd => "tts_sentence_end", Self::TtsResponse => "tts_response", Self::TtsSubtitle => "tts_subtitle", Self::Unknown(_) => "unknown", } } } impl fmt::Display for SpeechError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidConfig(message) | Self::InvalidHeader(message) | Self::InvalidFrame(message) | Self::Serialize(message) | Self::Io(message) | Self::Upstream(message) => write!(f, "{message}"), } } } impl Error for SpeechError {} fn build_sized_frame( message_type: u8, flags: u8, serialization: u8, compression: u8, payload: &[u8], ) -> Result, SpeechError> { let payload_len = u32::try_from(payload.len()) .map_err(|_| SpeechError::InvalidFrame("语音帧 payload 超过 u32 上限".to_string()))?; let mut frame = Vec::with_capacity(8 + payload.len()); frame.push((PROTOCOL_VERSION << 4) | HEADER_SIZE_FOUR_BYTES); frame.push((message_type << 4) | flags); frame.push((serialization << 4) | compression); frame.push(0); frame.extend_from_slice(&payload_len.to_be_bytes()); frame.extend_from_slice(payload); Ok(frame) } fn build_tts_event_frame( message_type: u8, serialization: u8, compression: u8, event: i32, session_id: Option<&str>, payload: &[u8], ) -> Result, SpeechError> { let id = session_id.unwrap_or(""); let id_bytes = id.as_bytes(); let id_len = u32::try_from(id_bytes.len()) .map_err(|_| SpeechError::InvalidFrame("TTS session id 超过 u32 上限".to_string()))?; let payload_len = u32::try_from(payload.len()) .map_err(|_| SpeechError::InvalidFrame("TTS payload 超过 u32 上限".to_string()))?; let mut frame = Vec::with_capacity(12 + id_bytes.len() + payload.len()); frame.push((PROTOCOL_VERSION << 4) | HEADER_SIZE_FOUR_BYTES); frame.push((message_type << 4) | FLAG_WITH_EVENT); frame.push((serialization << 4) | compression); frame.push(0); frame.extend_from_slice(&event.to_be_bytes()); if session_id.is_some() { frame.extend_from_slice(&id_len.to_be_bytes()); frame.extend_from_slice(id_bytes); } frame.extend_from_slice(&payload_len.to_be_bytes()); frame.extend_from_slice(payload); Ok(frame) } fn parse_header(bytes: &[u8]) -> Result { if bytes.len() < 4 { return Err(SpeechError::InvalidFrame( "语音帧长度不足 4 字节".to_string(), )); } Ok(SpeechFrameHeader { version: bytes[0] >> 4, header_size_words: bytes[0] & 0x0f, message_type: bytes[1] >> 4, flags: bytes[1] & 0x0f, serialization: bytes[2] >> 4, compression: bytes[2] & 0x0f, }) } fn read_payload_value( bytes: &[u8], offset: &mut usize, compression: SpeechCompression, ) -> Result { let payload = read_payload_bytes(bytes, offset)?; let payload = match compression { SpeechCompression::None => payload, SpeechCompression::Gzip => ungzip_bytes(&payload)?, }; if payload.is_empty() { return Ok(json!({})); } serde_json::from_slice(&payload) .map_err(|error| SpeechError::InvalidFrame(format!("解析语音 JSON 帧失败:{error}"))) } fn read_payload_bytes(bytes: &[u8], offset: &mut usize) -> Result, SpeechError> { let payload_len = read_u32(bytes, offset)? as usize; if bytes.len() < *offset + payload_len { return Err(SpeechError::InvalidFrame( "语音帧 payload 长度超过实际数据".to_string(), )); } let payload = bytes[*offset..*offset + payload_len].to_vec(); *offset += payload_len; Ok(payload) } fn read_optional_length_prefixed_string( bytes: &[u8], offset: &mut usize, ) -> Result, SpeechError> { if bytes.len() < *offset + 4 { return Ok(None); } let saved_offset = *offset; let len = read_u32(bytes, offset)? as usize; if bytes.len() < *offset + len { *offset = saved_offset; return Ok(None); } let text = decode_lossy(&bytes[*offset..*offset + len]); *offset += len; if text.is_empty() { Ok(None) } else { Ok(Some(text)) } } fn read_i32(bytes: &[u8], offset: &mut usize) -> Result { if bytes.len() < *offset + 4 { return Err(SpeechError::InvalidFrame("语音帧缺少 i32 字段".to_string())); } let value = i32::from_be_bytes([ bytes[*offset], bytes[*offset + 1], bytes[*offset + 2], bytes[*offset + 3], ]); *offset += 4; Ok(value) } fn read_u32(bytes: &[u8], offset: &mut usize) -> Result { if bytes.len() < *offset + 4 { return Err(SpeechError::InvalidFrame("语音帧缺少 u32 字段".to_string())); } let value = u32::from_be_bytes([ bytes[*offset], bytes[*offset + 1], bytes[*offset + 2], bytes[*offset + 3], ]); *offset += 4; Ok(value) } fn gzip_bytes(payload: &[u8]) -> Result, SpeechError> { let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); encoder .write_all(payload) .map_err(|error| SpeechError::Io(format!("gzip 压缩语音帧失败:{error}")))?; encoder .finish() .map_err(|error| SpeechError::Io(format!("完成 gzip 压缩语音帧失败:{error}"))) } fn ungzip_bytes(payload: &[u8]) -> Result, SpeechError> { let mut decoder = GzDecoder::new(payload); let mut output = Vec::new(); decoder .read_to_end(&mut output) .map_err(|error| SpeechError::Io(format!("gzip 解压语音帧失败:{error}")))?; Ok(output) } fn header_value(value: &str) -> Result { HeaderValue::from_str(value) .map_err(|error| SpeechError::InvalidHeader(format!("构造火山语音请求头失败:{error}"))) } fn normalize_optional_secret(value: Option) -> Option { value .map(|value| value.trim().to_string()) .filter(|value| !value.is_empty()) } fn default_if_empty(value: String, default_value: &str) -> String { let value = value.trim(); if value.is_empty() { default_value.to_string() } else { value.to_string() } } fn merge_json_object(target: &mut Value, source: Value) { match (target, source) { (Value::Object(target), Value::Object(source)) => { for (key, value) in source { merge_json_object(target.entry(key).or_insert(Value::Null), value); } } (target, source) => *target = source, } } fn decode_lossy(bytes: &[u8]) -> String { String::from_utf8_lossy(bytes).trim().to_string() } #[cfg(test)] mod tests { use super::*; fn test_config_with_api_key() -> VolcengineSpeechConfig { VolcengineSpeechConfig::new( Some("api-key".to_string()), None, None, String::new(), String::new(), String::new(), String::new(), String::new(), DEFAULT_REQUEST_TIMEOUT_MS, ) .expect("config should build") } #[test] fn config_prefers_api_key_auth_and_exposes_no_secret_in_public_config() { let config = test_config_with_api_key(); assert_eq!( config.auth_mode().unwrap(), VolcengineSpeechAuthMode::ApiKey ); let public_config = config.public_config(); assert_eq!(public_config.asr_resource_id, DEFAULT_ASR_RESOURCE_ID); assert_eq!(public_config.tts_resource_id, DEFAULT_TTS_RESOURCE_ID); assert_eq!( public_config.endpoints.asr_stream, "/api/speech/volcengine/asr/stream" ); } #[test] fn config_accepts_legacy_auth_when_api_key_missing() { let config = VolcengineSpeechConfig::new( None, Some("app-id".to_string()), Some("access-key".to_string()), String::new(), String::new(), String::new(), String::new(), String::new(), DEFAULT_REQUEST_TIMEOUT_MS, ) .expect("legacy config should build"); assert_eq!( config.auth_mode().unwrap(), VolcengineSpeechAuthMode::LegacyApp ); } #[test] fn asr_frame_roundtrip_parses_gzip_json_response() { let payload = json!({ "result": { "text": "你好" } }); let payload_bytes = serde_json::to_vec(&payload).unwrap(); let compressed = gzip_bytes(&payload_bytes).unwrap(); let mut frame = vec![ (PROTOCOL_VERSION << 4) | HEADER_SIZE_FOUR_BYTES, (MESSAGE_FULL_SERVER_RESPONSE << 4) | FLAG_WITH_SEQUENCE, (SERIALIZATION_JSON << 4) | COMPRESSION_GZIP, 0, ]; frame.extend_from_slice(&7_i32.to_be_bytes()); frame.extend_from_slice(&(compressed.len() as u32).to_be_bytes()); frame.extend_from_slice(&compressed); let parsed = parse_asr_response_frame(&frame).expect("asr response should parse"); assert_eq!(parsed.sequence, Some(7)); assert_eq!(parsed.payload["result"]["text"], "你好"); } #[test] fn asr_full_request_frame_uses_expected_header() { let frame = build_asr_full_client_request(&default_asr_request_payload("user-1", None)) .expect("frame should build"); assert_eq!(frame[0], 0x11); assert_eq!(frame[1], 0x10); assert_eq!(frame[2], 0x11); assert!(frame.len() > 8); } #[test] fn tts_start_session_frame_contains_event_and_session_id() { let frame = build_tts_bidirection_frame( TtsEvent::StartSession, Some("session-1"), Some(&json!({ "req_params": { "speaker": "voice" } })), ) .expect("tts frame should build"); assert_eq!(frame[0], 0x11); assert_eq!(frame[1], 0x14); assert_eq!(frame[2], 0x10); assert_eq!( i32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]), 100 ); assert_eq!( u32::from_be_bytes([frame[8], frame[9], frame[10], frame[11]]), 9 ); } #[test] fn tts_response_frame_parses_json_event() { let payload = json!({ "status_code": 20000000, "message": "ok" }); let payload_bytes = serde_json::to_vec(&payload).unwrap(); let frame = build_tts_event_frame( MESSAGE_FULL_SERVER_RESPONSE, SERIALIZATION_JSON, COMPRESSION_NONE, TtsEvent::SessionFinished.to_i32(), Some("session-1"), &payload_bytes, ) .expect("response frame should build"); let parsed = parse_tts_response_frame(&frame).expect("tts response should parse"); assert_eq!(parsed.event, Some(TtsEvent::SessionFinished)); assert_eq!(parsed.session_id.as_deref(), Some("session-1")); assert_eq!(parsed.payload.unwrap()["status_code"], 20000000); } #[test] fn tts_sse_body_uses_snake_case_audio_params() { let body = build_tts_sse_body( TtsSseRequest { text: "你好".to_string(), speaker: "voice".to_string(), model: None, audio_params: Some(TtsAudioParams { format: "mp3".to_string(), sample_rate: 24_000, bit_rate: Some(64_000), }), additions: None, ssml: None, }, "user-1", ) .expect("sse body should build"); assert_eq!(body["user"]["uid"], "user-1"); assert_eq!(body["req_params"]["audio_params"]["sample_rate"], 24_000); assert_eq!(body["req_params"]["audio_params"]["bit_rate"], 64_000); } }