Files
Genarrative/server-rs/crates/platform-speech/src/lib.rs
2026-05-10 13:18:46 +08:00

1206 lines
39 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::{
collections::BTreeMap,
error::Error,
fmt,
io::{Read, Write},
time::Duration,
};
use flate2::{Compression, read::GzDecoder, write::GzEncoder};
use futures_util::{SinkExt, StreamExt};
use reqwest::header::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream, connect_async,
tungstenite::{Message, client::IntoClientRequest, http::Uri},
};
use uuid::Uuid;
pub use tokio_tungstenite::tungstenite::{Error as UpstreamWsError, Message as UpstreamWsMessage};
/// Upstream WebSocket URL for streaming ASR, used when the configured URL is blank.
pub const DEFAULT_ASR_WS_URL: &str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async";
/// Upstream WebSocket URL for bidirectional TTS, used when the configured URL is blank.
pub const DEFAULT_TTS_BIDIRECTION_WS_URL: &str =
    "wss://openspeech.bytedance.com/api/v3/tts/bidirection";
/// Upstream HTTP URL for SSE (unidirectional) TTS, used when the configured URL is blank.
pub const DEFAULT_TTS_SSE_URL: &str =
    "https://openspeech.bytedance.com/api/v3/tts/unidirectional/sse";
/// `X-Api-Resource-Id` sent on ASR connections when no resource id is configured.
pub const DEFAULT_ASR_RESOURCE_ID: &str = "volc.seedasr.sauc.concurrent";
/// `X-Api-Resource-Id` sent on TTS requests when no resource id is configured.
pub const DEFAULT_TTS_RESOURCE_ID: &str = "seed-tts-2.0";
/// Suggested value for `VolcengineSpeechConfig::request_timeout_ms` (180 000 ms = 3 min).
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 180_000;
// ---- Binary frame header -------------------------------------------------------------
// Every frame built here starts with 4 header bytes (see `build_sized_frame`):
//   byte 0: protocol version (high nibble) | header size in 4-byte words (low nibble)
//   byte 1: message type (high nibble)     | flags (low nibble)
//   byte 2: serialization (high nibble)    | compression (low nibble)
//   byte 3: reserved, written as 0
const PROTOCOL_VERSION: u8 = 0b0001;
const HEADER_SIZE_FOUR_BYTES: u8 = 0b0001;
// Serialization nibble: raw bytes (audio) or JSON.
const SERIALIZATION_NONE: u8 = 0b0000;
const SERIALIZATION_JSON: u8 = 0b0001;
// Compression nibble.
const COMPRESSION_NONE: u8 = 0b0000;
const COMPRESSION_GZIP: u8 = 0b0001;
// Message-type nibble for frames this module sends.
const MESSAGE_FULL_CLIENT_REQUEST: u8 = 0b0001;
const MESSAGE_AUDIO_ONLY_REQUEST: u8 = 0b0010;
// Message-type nibble for frames this module parses. `MESSAGE_FULL_SERVER_RESPONSE` is
// the only exported message type (NOTE(review): presumably used outside this module —
// confirm before making it private).
pub const MESSAGE_FULL_SERVER_RESPONSE: u8 = 0b1001;
const MESSAGE_AUDIO_ONLY_RESPONSE: u8 = 0b1011;
const MESSAGE_ERROR: u8 = 0b1111;
// Flags nibble.
const FLAG_NONE: u8 = 0b0000;
/// A big-endian `i32` sequence number follows the header.
const FLAG_WITH_SEQUENCE: u8 = 0b0001;
/// Marks the final ASR audio frame; frames built here carry no sequence number.
const FLAG_LAST_PACKET: u8 = 0b0010;
/// A sequence number follows the header; the name suggests it is negative (final
/// packet) — confirm against the upstream protocol.
const FLAG_WITH_NEGATIVE_SEQUENCE: u8 = 0b0011;
/// A big-endian `i32` event code follows the header (bidirectional TTS frames).
const FLAG_WITH_EVENT: u8 = 0b0100;
// ---- Bidirectional TTS event codes (mapped by `TtsEvent::from_i32` / `to_i32`) ----------
// Connection-level events; server replies to these carry a connection id.
const EVENT_START_CONNECTION: i32 = 1;
const EVENT_FINISH_CONNECTION: i32 = 2;
const EVENT_CONNECTION_STARTED: i32 = 50;
const EVENT_CONNECTION_FAILED: i32 = 51;
const EVENT_CONNECTION_FINISHED: i32 = 52;
// Session-level events; frames carry a session id.
const EVENT_START_SESSION: i32 = 100;
const EVENT_CANCEL_SESSION: i32 = 101;
const EVENT_FINISH_SESSION: i32 = 102;
const EVENT_SESSION_STARTED: i32 = 150;
const EVENT_SESSION_CANCELED: i32 = 151;
const EVENT_SESSION_FINISHED: i32 = 152;
const EVENT_SESSION_FAILED: i32 = 153;
// Synthesis traffic inside a session.
const EVENT_TASK_REQUEST: i32 = 200;
const EVENT_TTS_SENTENCE_END: i32 = 351;
const EVENT_TTS_RESPONSE: i32 = 352;
const EVENT_TTS_SUBTITLE: i32 = 353;
/// WebSocket stream (plain TCP or TLS) returned by the `VolcengineSpeechClient::connect_*`
/// methods and consumed by `send_binary` / `recv_binary`.
pub type SpeechWsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Connection and credential settings for the Volcengine speech APIs (ASR and TTS).
///
/// Prefer [`VolcengineSpeechConfig::new`]. It trims secrets, substitutes the `DEFAULT_*`
/// constants for blank resource ids / URLs, and rejects configurations without usable
/// credentials.
///
/// `Debug` is implemented by hand so that `{:?}` output (logs, panic messages, and the
/// derived `Debug` of `VolcengineSpeechClient`) never contains the credentials.
#[derive(Clone, PartialEq, Eq)]
pub struct VolcengineSpeechConfig {
    /// API key sent as `X-Api-Key`; takes precedence over the legacy pair when non-empty.
    pub api_key: Option<String>,
    /// Legacy app id, sent as `X-Api-App-Key` together with `access_key`.
    pub app_id: Option<String>,
    /// Legacy access key, sent as `X-Api-Access-Key` together with `app_id`.
    pub access_key: Option<String>,
    /// `X-Api-Resource-Id` for ASR connections.
    pub asr_resource_id: String,
    /// `X-Api-Resource-Id` for TTS (bidirectional and SSE) requests.
    pub tts_resource_id: String,
    /// Upstream WebSocket URL for streaming ASR.
    pub asr_ws_url: String,
    /// Upstream WebSocket URL for bidirectional TTS.
    pub tts_bidirection_ws_url: String,
    /// Upstream HTTP URL for SSE TTS.
    pub tts_sse_url: String,
    /// Timeout in milliseconds attached to prepared SSE requests (`new` clamps it to >= 1).
    pub request_timeout_ms: u64,
}

impl fmt::Debug for VolcengineSpeechConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Report only whether a credential is present, never its value.
        fn redact(value: Option<&String>) -> Option<&'static str> {
            value.map(|_| "<redacted>")
        }
        f.debug_struct("VolcengineSpeechConfig")
            .field("api_key", &redact(self.api_key.as_ref()))
            .field("app_id", &redact(self.app_id.as_ref()))
            .field("access_key", &redact(self.access_key.as_ref()))
            .field("asr_resource_id", &self.asr_resource_id)
            .field("tts_resource_id", &self.tts_resource_id)
            .field("asr_ws_url", &self.asr_ws_url)
            .field("tts_bidirection_ws_url", &self.tts_bidirection_ws_url)
            .field("tts_sse_url", &self.tts_sse_url)
            .field("request_timeout_ms", &self.request_timeout_ms)
            .finish()
    }
}
/// Non-secret view of the speech configuration, built by
/// `VolcengineSpeechConfig::public_config` and serialized with camelCase keys.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PublicSpeechConfig {
    /// Resource id used for ASR connections.
    pub asr_resource_id: String,
    /// Resource id used for TTS requests.
    pub tts_resource_id: String,
    /// ASR audio format (`public_config` always reports `AsrAudioConfig::default()`).
    pub asr_audio: AsrAudioConfig,
    /// TTS audio parameters (`public_config` always reports `TtsAudioParams::default()`).
    pub tts_audio: TtsAudioParams,
    /// Relative API routes of the speech endpoints.
    pub endpoints: PublicSpeechEndpoints,
}
/// Relative routes (e.g. `/api/speech/volcengine/asr/stream`) advertised to clients.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PublicSpeechEndpoints {
    /// Streaming ASR route.
    pub asr_stream: &'static str,
    /// Bidirectional TTS route.
    pub tts_bidirection: &'static str,
    /// SSE TTS route.
    pub tts_sse: &'static str,
}
/// Audio format description for ASR input (camelCase in JSON).
///
/// The default (see the `Default` impl) is raw PCM, 16 000 rate, 16 bits, 1 channel.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AsrAudioConfig {
    /// Container/format name, e.g. `"pcm"`.
    pub format: String,
    /// Codec name, e.g. `"raw"`.
    pub codec: String,
    /// Sample rate (presumably Hz; default 16 000).
    pub rate: u32,
    /// Bits per sample (default 16).
    pub bits: u8,
    /// Channel count (default 1).
    pub channel: u8,
}
/// Output audio parameters for TTS (camelCase in JSON; default mp3 at 24 000).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsAudioParams {
    /// Output format, e.g. `"mp3"`.
    pub format: String,
    /// Sample rate (presumably Hz; default 24 000).
    pub sample_rate: u32,
    /// Optional bit rate; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bit_rate: Option<u32>,
}
/// Thin client that opens upstream speech connections and prepares SSE requests from a
/// [`VolcengineSpeechConfig`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VolcengineSpeechClient {
    config: VolcengineSpeechConfig,
}
/// Errors produced by this crate; each variant carries a human-readable message.
#[derive(Debug)]
pub enum SpeechError {
    /// Missing credentials, bad URLs, or rejected request input (e.g. empty TTS text).
    InvalidConfig(String),
    /// A header value could not be constructed.
    InvalidHeader(String),
    /// A binary frame was truncated, oversized, or had an undecodable payload.
    InvalidFrame(String),
    /// JSON serialization of an outgoing request failed.
    Serialize(String),
    /// gzip compression or decompression failed.
    Io(String),
    /// Connecting to, sending to, or reading from the upstream WebSocket failed.
    Upstream(String),
}
/// Which credential headers `VolcengineSpeechConfig` sends.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum VolcengineSpeechAuthMode {
    /// Single `X-Api-Key` header.
    ApiKey,
    /// `X-Api-App-Key` + `X-Api-Access-Key` pair.
    LegacyApp,
}
/// Kind of outgoing ASR frame built by `build_asr_frame`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AsrFrameKind {
    /// JSON request/configuration frame.
    FullClientRequest,
    /// Intermediate audio chunk.
    Audio,
    /// Final audio chunk (sent with `FLAG_LAST_PACKET`).
    LastAudio,
}
/// Nibbles decoded from the first three header bytes of a binary frame.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SpeechFrameHeader {
    /// Protocol version (high nibble of byte 0).
    pub version: u8,
    /// Header length in 4-byte words (low nibble of byte 0).
    pub header_size_words: u8,
    /// Message type (high nibble of byte 1).
    pub message_type: u8,
    /// Message flags (low nibble of byte 1).
    pub flags: u8,
    /// Payload serialization (high nibble of byte 2).
    pub serialization: u8,
    /// Payload compression (low nibble of byte 2).
    pub compression: u8,
}
/// Result of `parse_asr_response_frame`.
#[derive(Clone, Debug, PartialEq)]
pub struct ParsedAsrResponse {
    /// Decoded frame header.
    pub header: SpeechFrameHeader,
    /// Sequence number, when the flags announce one and the frame is long enough.
    pub sequence: Option<i32>,
    /// Always `None` from `parse_asr_response_frame` as written.
    pub event: Option<i32>,
    /// JSON payload (`{}` when empty; for error frames possibly `{"message": ...}`).
    pub payload: Value,
    /// Error code, only for `MESSAGE_ERROR` frames.
    pub error_code: Option<u32>,
}
/// Result of `parse_tts_response_frame`.
#[derive(Clone, Debug, PartialEq)]
pub struct ParsedTtsResponse {
    /// Decoded frame header.
    pub header: SpeechFrameHeader,
    /// Event code, when the frame carries `FLAG_WITH_EVENT`.
    pub event: Option<TtsEvent>,
    /// Id string of session-level frames.
    pub session_id: Option<String>,
    /// Id string of `ConnectionStarted` / `ConnectionFailed` / `ConnectionFinished` frames.
    pub connection_id: Option<String>,
    /// JSON payload for non-audio frames.
    pub payload: Option<Value>,
    /// Raw audio bytes for audio frames.
    pub audio: Option<Vec<u8>>,
    /// Error code, only for `MESSAGE_ERROR` frames.
    pub error_code: Option<u32>,
}
/// Bidirectional TTS event; numeric codes are mapped by `TtsEvent::from_i32` / `to_i32`.
///
/// Codes that are not recognised are preserved in `Unknown`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TtsEvent {
    StartConnection,
    FinishConnection,
    ConnectionStarted,
    ConnectionFailed,
    ConnectionFinished,
    StartSession,
    CancelSession,
    FinishSession,
    SessionStarted,
    SessionCanceled,
    SessionFinished,
    SessionFailed,
    TaskRequest,
    TtsSentenceEnd,
    TtsResponse,
    TtsSubtitle,
    /// Any event code without a named variant.
    Unknown(i32),
}
/// Client-originated TTS event, deserialized from JSON shaped like
/// `{"type": "start_session", "sessionId": "...", "payload": {...}}` and turned into a
/// frame by `build_tts_bidirection_frame_from_client_event`.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum TtsBidirectionClientEvent {
    StartConnection {
        #[serde(default)]
        payload: Option<Value>,
    },
    FinishConnection {
        #[serde(default)]
        payload: Option<Value>,
    },
    StartSession {
        /// A random UUID is generated when absent.
        #[serde(rename = "sessionId")]
        session_id: Option<String>,
        // NOTE(review): a missing payload defaults to `Value::default()` (JSON `null`)
        // and is sent as such, unlike the `Option` payloads which become `{}` — confirm
        // upstream behaviour.
        #[serde(default)]
        payload: Value,
    },
    FinishSession {
        #[serde(rename = "sessionId")]
        session_id: String,
        #[serde(default)]
        payload: Option<Value>,
    },
    CancelSession {
        #[serde(rename = "sessionId")]
        session_id: String,
        #[serde(default)]
        payload: Option<Value>,
    },
    TaskRequest {
        #[serde(rename = "sessionId")]
        session_id: String,
        #[serde(default)]
        payload: Value,
    },
}
/// Incoming SSE TTS request (camelCase JSON), converted by `build_tts_sse_body`.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsSseRequest {
    /// Text to synthesize; may be blank only when `ssml` is non-blank.
    pub text: String,
    /// Voice identifier; must be non-blank.
    pub speaker: String,
    /// Optional model name (dropped when blank).
    #[serde(default)]
    pub model: Option<String>,
    /// Optional output parameters; `TtsAudioParams::default()` when absent.
    #[serde(default)]
    pub audio_params: Option<TtsAudioParams>,
    /// Extra upstream options, forwarded verbatim as `req_params.additions`.
    #[serde(default)]
    pub additions: Option<Value>,
    /// Optional SSML (dropped when blank).
    #[serde(default)]
    pub ssml: Option<String>,
}
/// Fully prepared upstream SSE TTS request; this crate builds it but does not send it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TtsSseUpstreamRequest {
    /// Upstream SSE URL.
    pub url: String,
    /// Auth, request-id, `Accept: text/event-stream`, and JSON content-type headers.
    pub headers: HeaderMap,
    /// JSON body from `build_tts_sse_body`.
    pub body: Value,
    /// Timeout taken from `request_timeout_ms`.
    pub timeout: Duration,
}
/// Payload compression decoded from a frame header; unknown values map to `None`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpeechCompression {
    None,
    Gzip,
}
impl Default for AsrAudioConfig {
    /// Raw 16-bit mono PCM at a rate of 16 000.
    fn default() -> Self {
        let format = String::from("pcm");
        let codec = String::from("raw");
        Self {
            format,
            codec,
            rate: 16_000,
            bits: 16,
            channel: 1,
        }
    }
}
impl Default for TtsAudioParams {
    /// mp3 at a sample rate of 24 000, no explicit bit rate.
    fn default() -> Self {
        let format = String::from("mp3");
        Self {
            format,
            bit_rate: None,
            sample_rate: 24_000,
        }
    }
}
impl VolcengineSpeechConfig {
pub fn new(
api_key: Option<String>,
app_id: Option<String>,
access_key: Option<String>,
asr_resource_id: String,
tts_resource_id: String,
asr_ws_url: String,
tts_bidirection_ws_url: String,
tts_sse_url: String,
request_timeout_ms: u64,
) -> Result<Self, SpeechError> {
let config = Self {
api_key: normalize_optional_secret(api_key),
app_id: normalize_optional_secret(app_id),
access_key: normalize_optional_secret(access_key),
asr_resource_id: default_if_empty(asr_resource_id, DEFAULT_ASR_RESOURCE_ID),
tts_resource_id: default_if_empty(tts_resource_id, DEFAULT_TTS_RESOURCE_ID),
asr_ws_url: default_if_empty(asr_ws_url, DEFAULT_ASR_WS_URL),
tts_bidirection_ws_url: default_if_empty(
tts_bidirection_ws_url,
DEFAULT_TTS_BIDIRECTION_WS_URL,
),
tts_sse_url: default_if_empty(tts_sse_url, DEFAULT_TTS_SSE_URL),
request_timeout_ms: request_timeout_ms.max(1),
};
config.auth_mode()?;
Ok(config)
}
pub fn auth_mode(&self) -> Result<VolcengineSpeechAuthMode, SpeechError> {
if self.api_key.as_ref().is_some_and(|value| !value.is_empty()) {
return Ok(VolcengineSpeechAuthMode::ApiKey);
}
if self.app_id.as_ref().is_some_and(|value| !value.is_empty())
&& self
.access_key
.as_ref()
.is_some_and(|value| !value.is_empty())
{
return Ok(VolcengineSpeechAuthMode::LegacyApp);
}
Err(SpeechError::InvalidConfig(
"火山语音密钥未配置:需要 VOLCENGINE_SPEECH_API_KEY或旧版 VOLCENGINE_SPEECH_APP_ID + VOLCENGINE_SPEECH_ACCESS_KEY"
.to_string(),
))
}
pub fn public_config(&self) -> PublicSpeechConfig {
PublicSpeechConfig {
asr_resource_id: self.asr_resource_id.clone(),
tts_resource_id: self.tts_resource_id.clone(),
asr_audio: AsrAudioConfig::default(),
tts_audio: TtsAudioParams::default(),
endpoints: PublicSpeechEndpoints {
asr_stream: "/api/speech/volcengine/asr/stream",
tts_bidirection: "/api/speech/volcengine/tts/bidirection",
tts_sse: "/api/speech/volcengine/tts/sse",
},
}
}
fn auth_headers(&self, resource_id: &str) -> Result<HeaderMap, SpeechError> {
let mut headers = HeaderMap::new();
match self.auth_mode()? {
VolcengineSpeechAuthMode::ApiKey => {
headers.insert(
"X-Api-Key",
header_value(self.api_key.as_deref().unwrap_or(""))?,
);
}
VolcengineSpeechAuthMode::LegacyApp => {
headers.insert(
"X-Api-App-Key",
header_value(self.app_id.as_deref().unwrap_or(""))?,
);
headers.insert(
"X-Api-Access-Key",
header_value(self.access_key.as_deref().unwrap_or(""))?,
);
}
}
headers.insert("X-Api-Resource-Id", header_value(resource_id)?);
Ok(headers)
}
}
impl VolcengineSpeechClient {
    /// Wraps a configuration; no network activity happens here.
    pub fn new(config: VolcengineSpeechConfig) -> Self {
        Self { config }
    }

    /// Returns the configuration this client was built with.
    pub fn config(&self) -> &VolcengineSpeechConfig {
        &self.config
    }

    /// Opens the streaming ASR WebSocket.
    ///
    /// Sends the auth headers, a fresh `X-Api-Request-Id`, and `X-Api-Sequence: -1`.
    /// Returns the stream plus the handshake response headers (non-UTF-8 values are
    /// dropped).
    ///
    /// # Errors
    /// Header construction, URL, or connection failures.
    pub async fn connect_asr(
        &self,
    ) -> Result<(SpeechWsStream, BTreeMap<String, String>), SpeechError> {
        let request_id = Uuid::new_v4().to_string();
        let mut headers = self.config.auth_headers(&self.config.asr_resource_id)?;
        headers.insert("X-Api-Request-Id", header_value(&request_id)?);
        headers.insert("X-Api-Sequence", HeaderValue::from_static("-1"));
        self.connect_ws(&self.config.asr_ws_url, headers).await
    }

    /// Opens the bidirectional TTS WebSocket with auth headers and a fresh
    /// `X-Api-Connect-Id`.
    ///
    /// # Errors
    /// Header construction, URL, or connection failures.
    pub async fn connect_tts_bidirection(
        &self,
    ) -> Result<(SpeechWsStream, BTreeMap<String, String>), SpeechError> {
        let connect_id = Uuid::new_v4().to_string();
        let mut headers = self.config.auth_headers(&self.config.tts_resource_id)?;
        headers.insert("X-Api-Connect-Id", header_value(&connect_id)?);
        self.connect_ws(&self.config.tts_bidirection_ws_url, headers)
            .await
    }

    /// Prepares (but does not send) an upstream SSE TTS request for `user_id`.
    ///
    /// # Errors
    /// Credential/header failures or request validation errors from `build_tts_sse_body`.
    pub fn build_tts_sse_upstream_request(
        &self,
        request: TtsSseRequest,
        user_id: &str,
    ) -> Result<TtsSseUpstreamRequest, SpeechError> {
        let mut headers = self.config.auth_headers(&self.config.tts_resource_id)?;
        headers.insert(
            "X-Api-Request-Id",
            header_value(&Uuid::new_v4().to_string())?,
        );
        headers.insert(
            reqwest::header::ACCEPT,
            HeaderValue::from_static("text/event-stream"),
        );
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            HeaderValue::from_static("application/json"),
        );
        let body = build_tts_sse_body(request, user_id)?;
        Ok(TtsSseUpstreamRequest {
            url: self.config.tts_sse_url.clone(),
            headers,
            body,
            timeout: Duration::from_millis(self.config.request_timeout_ms),
        })
    }

    /// Performs the WebSocket handshake against `url` with the given extra headers.
    async fn connect_ws(
        &self,
        url: &str,
        headers: HeaderMap,
    ) -> Result<(SpeechWsStream, BTreeMap<String, String>), SpeechError> {
        let uri: Uri = url.parse().map_err(|error| {
            SpeechError::InvalidConfig(format!("火山语音 WebSocket URL 非法:{error}"))
        })?;
        let mut request = uri.into_client_request().map_err(|error| {
            SpeechError::InvalidConfig(format!("构造火山语音 WebSocket 请求失败:{error}"))
        })?;
        // `extend` replaces same-named headers and then appends continuation values.
        // A per-entry `insert` loop would silently drop every value after the first of a
        // multi-valued header, because continuation entries come with a `None` name.
        request.headers_mut().extend(headers);
        let (stream, response) = connect_async(request).await.map_err(|error| {
            SpeechError::Upstream(format!("连接火山语音 WebSocket 失败:{error}"))
        })?;
        let response_headers = response
            .headers()
            .iter()
            .filter_map(|(name, value)| {
                value
                    .to_str()
                    .ok()
                    .map(|value| (name.as_str().to_string(), value.to_string()))
            })
            .collect();
        Ok((stream, response_headers))
    }
}
pub fn default_asr_request_payload(user_id: &str, override_payload: Option<Value>) -> Value {
let mut payload = json!({
"user": {
"uid": user_id,
},
"audio": {
"format": AsrAudioConfig::default().format,
"codec": AsrAudioConfig::default().codec,
"rate": AsrAudioConfig::default().rate,
"bits": AsrAudioConfig::default().bits,
"channel": AsrAudioConfig::default().channel,
},
"request": {
"model_name": "bigmodel",
"enable_itn": true,
"enable_punc": true,
"show_utterances": true,
"result_type": "full",
}
});
if let Some(override_payload) = override_payload {
merge_json_object(&mut payload, override_payload);
}
payload
}
/// Encodes one outgoing ASR frame; the payload is always gzip-compressed.
///
/// `FullClientRequest` is tagged as JSON. Audio frames carry no serialization, and
/// `LastAudio` additionally sets `FLAG_LAST_PACKET`.
///
/// # Errors
/// gzip failures or a payload larger than `u32::MAX` bytes.
pub fn build_asr_frame(kind: AsrFrameKind, payload: &[u8]) -> Result<Vec<u8>, SpeechError> {
    let (message_type, flags, serialization) = match kind {
        AsrFrameKind::FullClientRequest => {
            (MESSAGE_FULL_CLIENT_REQUEST, FLAG_NONE, SERIALIZATION_JSON)
        }
        AsrFrameKind::Audio => (MESSAGE_AUDIO_ONLY_REQUEST, FLAG_NONE, SERIALIZATION_NONE),
        AsrFrameKind::LastAudio => {
            (MESSAGE_AUDIO_ONLY_REQUEST, FLAG_LAST_PACKET, SERIALIZATION_NONE)
        }
    };
    let compressed = gzip_bytes(payload)?;
    build_sized_frame(
        message_type,
        flags,
        serialization,
        COMPRESSION_GZIP,
        &compressed,
    )
}
pub fn build_asr_full_client_request(payload: &Value) -> Result<Vec<u8>, SpeechError> {
let bytes = serde_json::to_vec(payload)
.map_err(|error| SpeechError::Serialize(format!("序列化 ASR 请求失败:{error}")))?;
build_asr_frame(AsrFrameKind::FullClientRequest, &bytes)
}
pub fn parse_asr_response_frame(bytes: &[u8]) -> Result<ParsedAsrResponse, SpeechError> {
let header = parse_header(bytes)?;
let mut offset = usize::from(header.header_size_words) * 4;
let sequence = if matches!(
header.flags,
FLAG_WITH_SEQUENCE | FLAG_WITH_NEGATIVE_SEQUENCE
) && bytes.len() >= offset + 4
{
let sequence = read_i32(bytes, &mut offset)?;
Some(sequence)
} else {
None
};
if header.message_type == MESSAGE_ERROR {
let error_code = read_u32(bytes, &mut offset).ok();
let payload = read_payload_value(bytes, &mut offset, SpeechCompression::None)
.unwrap_or_else(|_| json!({ "message": decode_lossy(&bytes[offset..]) }));
return Ok(ParsedAsrResponse {
header,
sequence,
event: None,
payload,
error_code,
});
}
let payload = read_payload_value(bytes, &mut offset, header.compression())?;
Ok(ParsedAsrResponse {
header,
sequence,
event: None,
payload,
error_code: None,
})
}
pub fn build_tts_bidirection_frame(
event: TtsEvent,
session_id: Option<&str>,
payload: Option<&Value>,
) -> Result<Vec<u8>, SpeechError> {
let payload_bytes = serde_json::to_vec(payload.unwrap_or(&json!({})))
.map_err(|error| SpeechError::Serialize(format!("序列化 TTS 事件失败:{error}")))?;
let event_number = event.to_i32();
let message_type = match event {
TtsEvent::TaskRequest => MESSAGE_FULL_CLIENT_REQUEST,
_ => MESSAGE_FULL_CLIENT_REQUEST,
};
build_tts_event_frame(
message_type,
SERIALIZATION_JSON,
COMPRESSION_NONE,
event_number,
session_id,
&payload_bytes,
)
}
pub fn build_tts_bidirection_frame_from_client_event(
event: TtsBidirectionClientEvent,
) -> Result<Vec<u8>, SpeechError> {
match event {
TtsBidirectionClientEvent::StartConnection { payload } => {
build_tts_bidirection_frame(TtsEvent::StartConnection, None, payload.as_ref())
}
TtsBidirectionClientEvent::FinishConnection { payload } => {
build_tts_bidirection_frame(TtsEvent::FinishConnection, None, payload.as_ref())
}
TtsBidirectionClientEvent::StartSession {
session_id,
payload,
} => {
let session_id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
build_tts_bidirection_frame(TtsEvent::StartSession, Some(&session_id), Some(&payload))
}
TtsBidirectionClientEvent::FinishSession {
session_id,
payload,
} => build_tts_bidirection_frame(
TtsEvent::FinishSession,
Some(&session_id),
payload.as_ref(),
),
TtsBidirectionClientEvent::CancelSession {
session_id,
payload,
} => build_tts_bidirection_frame(
TtsEvent::CancelSession,
Some(&session_id),
payload.as_ref(),
),
TtsBidirectionClientEvent::TaskRequest {
session_id,
payload,
} => build_tts_bidirection_frame(TtsEvent::TaskRequest, Some(&session_id), Some(&payload)),
}
}
pub fn parse_tts_response_frame(bytes: &[u8]) -> Result<ParsedTtsResponse, SpeechError> {
let header = parse_header(bytes)?;
let mut offset = usize::from(header.header_size_words) * 4;
if header.message_type == MESSAGE_ERROR {
let error_code = read_u32(bytes, &mut offset).ok();
let payload = read_payload_value(bytes, &mut offset, header.compression())
.unwrap_or_else(|_| json!({ "message": decode_lossy(&bytes[offset..]) }));
return Ok(ParsedTtsResponse {
header,
event: None,
session_id: None,
connection_id: None,
payload: Some(payload),
audio: None,
error_code,
});
}
let event = if header.flags == FLAG_WITH_EVENT {
Some(TtsEvent::from_i32(read_i32(bytes, &mut offset)?))
} else {
None
};
let session_or_connection_id = read_optional_length_prefixed_string(bytes, &mut offset)?;
let payload_bytes = read_payload_bytes(bytes, &mut offset)?;
let payload_bytes = match header.compression() {
SpeechCompression::None => payload_bytes,
SpeechCompression::Gzip => ungzip_bytes(&payload_bytes)?,
};
let is_audio = header.message_type == MESSAGE_AUDIO_ONLY_RESPONSE
|| event == Some(TtsEvent::TtsResponse) && header.serialization == SERIALIZATION_NONE;
let payload = if is_audio {
None
} else if payload_bytes.is_empty() {
Some(json!({}))
} else {
Some(serde_json::from_slice(&payload_bytes).map_err(|error| {
SpeechError::InvalidFrame(format!("解析 TTS JSON 响应失败:{error}"))
})?)
};
let audio = if is_audio { Some(payload_bytes) } else { None };
let (connection_id, session_id) = match event {
Some(TtsEvent::ConnectionStarted)
| Some(TtsEvent::ConnectionFailed)
| Some(TtsEvent::ConnectionFinished) => (session_or_connection_id, None),
_ => (None, session_or_connection_id),
};
Ok(ParsedTtsResponse {
header,
event,
session_id,
connection_id,
payload,
audio,
error_code: None,
})
}
pub fn tts_response_to_client_value(response: &ParsedTtsResponse) -> Value {
json!({
"event": response.event.map(|event| event.name()),
"eventCode": response.event.map(|event| event.to_i32()),
"sessionId": response.session_id,
"connectionId": response.connection_id,
"payload": response.payload,
"audioBytes": response.audio.as_ref().map(Vec::len),
"errorCode": response.error_code,
})
}
/// Builds the upstream SSE TTS JSON body: `{"user": {"uid"}, "req_params": {...}}`.
///
/// `text` and `speaker` are trimmed. `model` and `ssml` are included only when
/// non-blank, and `additions` is forwarded verbatim. `audio_params` defaults to
/// `TtsAudioParams::default()`, with `bit_rate` emitted only when set.
///
/// # Errors
/// [`SpeechError::InvalidConfig`] when both text and SSML are blank, or the speaker is
/// blank.
pub fn build_tts_sse_body(request: TtsSseRequest, user_id: &str) -> Result<Value, SpeechError> {
    let text = request.text.trim();
    let speaker = request.speaker.trim();
    if text.is_empty() && request.ssml.as_deref().unwrap_or("").trim().is_empty() {
        return Err(SpeechError::InvalidConfig("TTS 文本不能为空".to_string()));
    }
    if speaker.is_empty() {
        return Err(SpeechError::InvalidConfig(
            "TTS speaker 不能为空".to_string(),
        ));
    }
    // Resolve the audio parameters once instead of cloning the Option for every field.
    let audio_params = request.audio_params.unwrap_or_default();
    let mut req_params = json!({
        "text": text,
        "speaker": speaker,
        "audio_params": {
            "format": audio_params.format,
            "sample_rate": audio_params.sample_rate,
        },
    });
    if let Some(bit_rate) = audio_params.bit_rate {
        req_params["audio_params"]["bit_rate"] = json!(bit_rate);
    }
    if let Some(model) = normalize_optional_secret(request.model) {
        req_params["model"] = json!(model);
    }
    if let Some(ssml) = normalize_optional_secret(request.ssml) {
        req_params["ssml"] = json!(ssml);
    }
    if let Some(additions) = request.additions {
        req_params["additions"] = additions;
    }
    Ok(json!({
        "user": {
            "uid": user_id,
        },
        "req_params": req_params,
    }))
}
pub async fn send_binary(ws: &mut SpeechWsStream, bytes: Vec<u8>) -> Result<(), SpeechError> {
ws.send(Message::Binary(bytes.into()))
.await
.map_err(|error| SpeechError::Upstream(format!("发送火山语音 WebSocket 帧失败:{error}")))
}
pub async fn recv_binary(ws: &mut SpeechWsStream) -> Result<Option<Vec<u8>>, SpeechError> {
while let Some(message) = ws.next().await {
match message {
Ok(Message::Binary(bytes)) => return Ok(Some(bytes.to_vec())),
Ok(Message::Text(text)) => return Ok(Some(text.as_bytes().to_vec())),
Ok(Message::Close(_)) => return Ok(None),
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
Ok(Message::Frame(_)) => {}
Err(error) => {
return Err(SpeechError::Upstream(format!(
"读取火山语音 WebSocket 帧失败:{error}"
)));
}
}
}
Ok(None)
}
impl SpeechFrameHeader {
fn compression(self) -> SpeechCompression {
match self.compression {
COMPRESSION_GZIP => SpeechCompression::Gzip,
_ => SpeechCompression::None,
}
}
}
// The three methods below are parallel exhaustive tables over the same variants; keep
// them in sync when adding an event.
impl TtsEvent {
    /// Maps a wire event code to a variant; unknown codes become `Unknown(code)`.
    pub fn from_i32(value: i32) -> Self {
        match value {
            EVENT_START_CONNECTION => Self::StartConnection,
            EVENT_FINISH_CONNECTION => Self::FinishConnection,
            EVENT_CONNECTION_STARTED => Self::ConnectionStarted,
            EVENT_CONNECTION_FAILED => Self::ConnectionFailed,
            EVENT_CONNECTION_FINISHED => Self::ConnectionFinished,
            EVENT_START_SESSION => Self::StartSession,
            EVENT_CANCEL_SESSION => Self::CancelSession,
            EVENT_FINISH_SESSION => Self::FinishSession,
            EVENT_SESSION_STARTED => Self::SessionStarted,
            EVENT_SESSION_CANCELED => Self::SessionCanceled,
            EVENT_SESSION_FINISHED => Self::SessionFinished,
            EVENT_SESSION_FAILED => Self::SessionFailed,
            EVENT_TASK_REQUEST => Self::TaskRequest,
            EVENT_TTS_SENTENCE_END => Self::TtsSentenceEnd,
            EVENT_TTS_RESPONSE => Self::TtsResponse,
            EVENT_TTS_SUBTITLE => Self::TtsSubtitle,
            other => Self::Unknown(other),
        }
    }
    /// Returns the wire event code; the inverse of `from_i32`.
    pub fn to_i32(self) -> i32 {
        match self {
            Self::StartConnection => EVENT_START_CONNECTION,
            Self::FinishConnection => EVENT_FINISH_CONNECTION,
            Self::ConnectionStarted => EVENT_CONNECTION_STARTED,
            Self::ConnectionFailed => EVENT_CONNECTION_FAILED,
            Self::ConnectionFinished => EVENT_CONNECTION_FINISHED,
            Self::StartSession => EVENT_START_SESSION,
            Self::CancelSession => EVENT_CANCEL_SESSION,
            Self::FinishSession => EVENT_FINISH_SESSION,
            Self::SessionStarted => EVENT_SESSION_STARTED,
            Self::SessionCanceled => EVENT_SESSION_CANCELED,
            Self::SessionFinished => EVENT_SESSION_FINISHED,
            Self::SessionFailed => EVENT_SESSION_FAILED,
            Self::TaskRequest => EVENT_TASK_REQUEST,
            Self::TtsSentenceEnd => EVENT_TTS_SENTENCE_END,
            Self::TtsResponse => EVENT_TTS_RESPONSE,
            Self::TtsSubtitle => EVENT_TTS_SUBTITLE,
            Self::Unknown(value) => value,
        }
    }
    /// Returns the snake_case name used in client-facing JSON; `Unknown` is `"unknown"`.
    pub fn name(self) -> &'static str {
        match self {
            Self::StartConnection => "start_connection",
            Self::FinishConnection => "finish_connection",
            Self::ConnectionStarted => "connection_started",
            Self::ConnectionFailed => "connection_failed",
            Self::ConnectionFinished => "connection_finished",
            Self::StartSession => "start_session",
            Self::CancelSession => "cancel_session",
            Self::FinishSession => "finish_session",
            Self::SessionStarted => "session_started",
            Self::SessionCanceled => "session_canceled",
            Self::SessionFinished => "session_finished",
            Self::SessionFailed => "session_failed",
            Self::TaskRequest => "task_request",
            Self::TtsSentenceEnd => "tts_sentence_end",
            Self::TtsResponse => "tts_response",
            Self::TtsSubtitle => "tts_subtitle",
            Self::Unknown(_) => "unknown",
        }
    }
}
/// Displays only the carried message, with no variant prefix.
impl fmt::Display for SpeechError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::InvalidConfig(text)
            | Self::InvalidHeader(text)
            | Self::InvalidFrame(text)
            | Self::Serialize(text)
            | Self::Io(text)
            | Self::Upstream(text) => text,
        };
        f.write_str(text)
    }
}
impl Error for SpeechError {}
/// Builds `header (4 bytes) + big-endian u32 payload length + payload`.
///
/// # Errors
/// [`SpeechError::InvalidFrame`] when the payload exceeds `u32::MAX` bytes.
fn build_sized_frame(
    message_type: u8,
    flags: u8,
    serialization: u8,
    compression: u8,
    payload: &[u8],
) -> Result<Vec<u8>, SpeechError> {
    let size = match u32::try_from(payload.len()) {
        Ok(size) => size,
        Err(_) => {
            return Err(SpeechError::InvalidFrame(
                "语音帧 payload 超过 u32 上限".to_string(),
            ))
        }
    };
    let header = [
        (PROTOCOL_VERSION << 4) | HEADER_SIZE_FOUR_BYTES,
        (message_type << 4) | flags,
        (serialization << 4) | compression,
        0,
    ];
    let mut frame = Vec::with_capacity(header.len() + 4 + payload.len());
    frame.extend_from_slice(&header);
    frame.extend_from_slice(&size.to_be_bytes());
    frame.extend_from_slice(payload);
    Ok(frame)
}
/// Builds an event frame: header with `FLAG_WITH_EVENT`, big-endian `i32` event code,
/// then (only when `session_id` is `Some`) a `u32`-length-prefixed id, then a
/// `u32`-length-prefixed payload.
///
/// # Errors
/// [`SpeechError::InvalidFrame`] when the id or payload exceeds `u32::MAX` bytes.
fn build_tts_event_frame(
    message_type: u8,
    serialization: u8,
    compression: u8,
    event: i32,
    session_id: Option<&str>,
    payload: &[u8],
) -> Result<Vec<u8>, SpeechError> {
    let id_bytes = session_id.unwrap_or_default().as_bytes();
    let id_len = u32::try_from(id_bytes.len()).map_err(|_| {
        SpeechError::InvalidFrame("TTS session id 超过 u32 上限".to_string())
    })?;
    let payload_len = u32::try_from(payload.len())
        .map_err(|_| SpeechError::InvalidFrame("TTS payload 超过 u32 上限".to_string()))?;

    let mut frame = Vec::with_capacity(12 + id_bytes.len() + payload.len());
    frame.extend_from_slice(&[
        (PROTOCOL_VERSION << 4) | HEADER_SIZE_FOUR_BYTES,
        (message_type << 4) | FLAG_WITH_EVENT,
        (serialization << 4) | compression,
        0,
    ]);
    frame.extend_from_slice(&event.to_be_bytes());
    if session_id.is_some() {
        frame.extend_from_slice(&id_len.to_be_bytes());
        frame.extend_from_slice(id_bytes);
    }
    frame.extend_from_slice(&payload_len.to_be_bytes());
    frame.extend_from_slice(payload);
    Ok(frame)
}
/// Decodes the 4-byte frame header and validates its declared length.
///
/// # Errors
/// [`SpeechError::InvalidFrame`] when the frame is shorter than 4 bytes, when it declares
/// a header size of 0 words, or when it declares more header bytes than it contains.
/// Without these checks, callers that start reading at `header_size_words * 4` would
/// either reinterpret the header as payload or index past the end of the frame.
fn parse_header(bytes: &[u8]) -> Result<SpeechFrameHeader, SpeechError> {
    if bytes.len() < 4 {
        return Err(SpeechError::InvalidFrame(
            "语音帧长度不足 4 字节".to_string(),
        ));
    }
    let header = SpeechFrameHeader {
        version: bytes[0] >> 4,
        header_size_words: bytes[0] & 0x0f,
        message_type: bytes[1] >> 4,
        flags: bytes[1] & 0x0f,
        serialization: bytes[2] >> 4,
        compression: bytes[2] & 0x0f,
    };
    let declared_header_len = usize::from(header.header_size_words) * 4;
    if header.header_size_words == 0 || bytes.len() < declared_header_len {
        return Err(SpeechError::InvalidFrame(format!(
            "语音帧头长度非法:声明 {} 个 4 字节字,实际帧长 {} 字节",
            header.header_size_words,
            bytes.len()
        )));
    }
    Ok(header)
}
fn read_payload_value(
bytes: &[u8],
offset: &mut usize,
compression: SpeechCompression,
) -> Result<Value, SpeechError> {
let payload = read_payload_bytes(bytes, offset)?;
let payload = match compression {
SpeechCompression::None => payload,
SpeechCompression::Gzip => ungzip_bytes(&payload)?,
};
if payload.is_empty() {
return Ok(json!({}));
}
serde_json::from_slice(&payload)
.map_err(|error| SpeechError::InvalidFrame(format!("解析语音 JSON 帧失败:{error}")))
}
/// Reads a big-endian `u32` length at `offset` followed by that many bytes, and advances
/// `offset` past both.
///
/// The end position is computed with `checked_add`. The length comes from the peer, and
/// on 32-bit targets `offset + len` could otherwise overflow: that panics in debug
/// builds, or wraps and then panics on the slice in release builds.
///
/// # Errors
/// [`SpeechError::InvalidFrame`] when the length field or the payload is truncated. The
/// offset has then already moved past a successfully read length field, as before.
fn read_payload_bytes(bytes: &[u8], offset: &mut usize) -> Result<Vec<u8>, SpeechError> {
    let payload_len = read_u32(bytes, offset)? as usize;
    let start = *offset;
    let payload = start
        .checked_add(payload_len)
        .and_then(|end| bytes.get(start..end))
        .ok_or_else(|| {
            SpeechError::InvalidFrame("语音帧 payload 长度超过实际数据".to_string())
        })?
        .to_vec();
    *offset = start + payload_len;
    Ok(payload)
}
/// Best-effort read of a `u32`-length-prefixed string at `offset`.
///
/// Returns `Ok(None)` without moving `offset` when fewer than 4 bytes remain or the
/// declared length runs past the frame. Otherwise it consumes the field and returns the
/// trimmed, lossily decoded text, or `None` if that text is empty. The end position uses
/// `checked_add` because the length is peer-controlled and could overflow `usize` on
/// 32-bit targets.
fn read_optional_length_prefixed_string(
    bytes: &[u8],
    offset: &mut usize,
) -> Result<Option<String>, SpeechError> {
    let saved_offset = *offset;
    if bytes.len().saturating_sub(saved_offset) < 4 {
        return Ok(None);
    }
    let len = read_u32(bytes, offset)? as usize;
    let body_start = *offset;
    match body_start
        .checked_add(len)
        .and_then(|end| bytes.get(body_start..end))
    {
        Some(raw) => {
            *offset = body_start + len;
            let text = decode_lossy(raw);
            Ok((!text.is_empty()).then_some(text))
        }
        None => {
            // Not a length-prefixed field after all; leave the bytes for the payload read.
            *offset = saved_offset;
            Ok(None)
        }
    }
}
/// Reads a big-endian `i32` at `offset` and advances it by 4.
///
/// # Errors
/// [`SpeechError::InvalidFrame`] (offset unchanged) when fewer than 4 bytes remain.
fn read_i32(bytes: &[u8], offset: &mut usize) -> Result<i32, SpeechError> {
    let start = *offset;
    let end = start + 4;
    let word = bytes
        .get(start..end)
        .and_then(|chunk| <[u8; 4]>::try_from(chunk).ok())
        .ok_or_else(|| SpeechError::InvalidFrame("语音帧缺少 i32 字段".to_string()))?;
    *offset = end;
    Ok(i32::from_be_bytes(word))
}
/// Reads a big-endian `u32` at `offset` and advances it by 4.
///
/// # Errors
/// [`SpeechError::InvalidFrame`] (offset unchanged) when fewer than 4 bytes remain.
fn read_u32(bytes: &[u8], offset: &mut usize) -> Result<u32, SpeechError> {
    match bytes.get(*offset..*offset + 4) {
        Some(&[b0, b1, b2, b3]) => {
            *offset += 4;
            Ok(u32::from_be_bytes([b0, b1, b2, b3]))
        }
        _ => Err(SpeechError::InvalidFrame("语音帧缺少 u32 字段".to_string())),
    }
}
fn gzip_bytes(payload: &[u8]) -> Result<Vec<u8>, SpeechError> {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder
.write_all(payload)
.map_err(|error| SpeechError::Io(format!("gzip 压缩语音帧失败:{error}")))?;
encoder
.finish()
.map_err(|error| SpeechError::Io(format!("完成 gzip 压缩语音帧失败:{error}")))
}
fn ungzip_bytes(payload: &[u8]) -> Result<Vec<u8>, SpeechError> {
let mut decoder = GzDecoder::new(payload);
let mut output = Vec::new();
decoder
.read_to_end(&mut output)
.map_err(|error| SpeechError::Io(format!("gzip 解压语音帧失败:{error}")))?;
Ok(output)
}
fn header_value(value: &str) -> Result<HeaderValue, SpeechError> {
HeaderValue::from_str(value)
.map_err(|error| SpeechError::InvalidHeader(format!("构造火山语音请求头失败:{error}")))
}
/// Trims an optional string, turning `None` or a blank value into `None`.
fn normalize_optional_secret(value: Option<String>) -> Option<String> {
    let trimmed = value?.trim().to_owned();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed)
    }
}
/// Returns `value` trimmed, or `default_value` when the trimmed value is empty.
fn default_if_empty(value: String, default_value: &str) -> String {
    match value.trim() {
        "" => default_value.to_owned(),
        trimmed => trimmed.to_owned(),
    }
}
/// Deep-merges `source` into `target`.
///
/// When both are JSON objects, keys are merged recursively; missing keys start as
/// `null` and are then overwritten. In every other combination `source` replaces
/// `target`.
fn merge_json_object(target: &mut Value, source: Value) {
    match source {
        Value::Object(overrides) if target.is_object() => {
            if let Value::Object(base) = target {
                for (key, value) in overrides {
                    let slot = base.entry(key).or_insert(Value::Null);
                    merge_json_object(slot, value);
                }
            }
        }
        replacement => *target = replacement,
    }
}
/// Decodes bytes as UTF-8, replacing invalid sequences with U+FFFD, and trims
/// surrounding whitespace.
fn decode_lossy(bytes: &[u8]) -> String {
    let text = String::from_utf8_lossy(bytes);
    text.trim().to_owned()
}
// Unit tests for configuration handling and the binary frame encoders/decoders.
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal valid config: API-key auth, every other field left blank so the
    // `DEFAULT_*` fallbacks apply.
    fn test_config_with_api_key() -> VolcengineSpeechConfig {
        VolcengineSpeechConfig::new(
            Some("api-key".to_string()),
            None,
            None,
            String::new(),
            String::new(),
            String::new(),
            String::new(),
            String::new(),
            DEFAULT_REQUEST_TIMEOUT_MS,
        )
        .expect("config should build")
    }

    #[test]
    fn config_prefers_api_key_auth_and_exposes_no_secret_in_public_config() {
        let config = test_config_with_api_key();
        assert_eq!(
            config.auth_mode().unwrap(),
            VolcengineSpeechAuthMode::ApiKey
        );
        let public_config = config.public_config();
        // Blank resource ids fall back to the defaults.
        assert_eq!(public_config.asr_resource_id, DEFAULT_ASR_RESOURCE_ID);
        assert_eq!(public_config.tts_resource_id, DEFAULT_TTS_RESOURCE_ID);
        assert_eq!(
            public_config.endpoints.asr_stream,
            "/api/speech/volcengine/asr/stream"
        );
    }

    #[test]
    fn config_accepts_legacy_auth_when_api_key_missing() {
        let config = VolcengineSpeechConfig::new(
            None,
            Some("app-id".to_string()),
            Some("access-key".to_string()),
            String::new(),
            String::new(),
            String::new(),
            String::new(),
            String::new(),
            DEFAULT_REQUEST_TIMEOUT_MS,
        )
        .expect("legacy config should build");
        assert_eq!(
            config.auth_mode().unwrap(),
            VolcengineSpeechAuthMode::LegacyApp
        );
    }

    // Hand-built server response: header, sequence number 7, gzip JSON payload.
    #[test]
    fn asr_frame_roundtrip_parses_gzip_json_response() {
        let payload = json!({ "result": { "text": "你好" } });
        let payload_bytes = serde_json::to_vec(&payload).unwrap();
        let compressed = gzip_bytes(&payload_bytes).unwrap();
        let mut frame = vec![
            (PROTOCOL_VERSION << 4) | HEADER_SIZE_FOUR_BYTES,
            (MESSAGE_FULL_SERVER_RESPONSE << 4) | FLAG_WITH_SEQUENCE,
            (SERIALIZATION_JSON << 4) | COMPRESSION_GZIP,
            0,
        ];
        frame.extend_from_slice(&7_i32.to_be_bytes());
        frame.extend_from_slice(&(compressed.len() as u32).to_be_bytes());
        frame.extend_from_slice(&compressed);
        let parsed = parse_asr_response_frame(&frame).expect("asr response should parse");
        assert_eq!(parsed.sequence, Some(7));
        assert_eq!(parsed.payload["result"]["text"], "你好");
    }

    // 0x11: version 1 / 1-word header; 0x10: full client request, no flags;
    // 0x11: JSON / gzip.
    #[test]
    fn asr_full_request_frame_uses_expected_header() {
        let frame = build_asr_full_client_request(&default_asr_request_payload("user-1", None))
            .expect("frame should build");
        assert_eq!(frame[0], 0x11);
        assert_eq!(frame[1], 0x10);
        assert_eq!(frame[2], 0x11);
        assert!(frame.len() > 8);
    }

    // 0x14: full client request with FLAG_WITH_EVENT; 0x10: JSON, uncompressed.
    // Bytes 4..8 hold event 100 (StartSession); bytes 8..12 hold len("session-1") = 9.
    #[test]
    fn tts_start_session_frame_contains_event_and_session_id() {
        let frame = build_tts_bidirection_frame(
            TtsEvent::StartSession,
            Some("session-1"),
            Some(&json!({ "req_params": { "speaker": "voice" } })),
        )
        .expect("tts frame should build");
        assert_eq!(frame[0], 0x11);
        assert_eq!(frame[1], 0x14);
        assert_eq!(frame[2], 0x10);
        assert_eq!(
            i32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]),
            100
        );
        assert_eq!(
            u32::from_be_bytes([frame[8], frame[9], frame[10], frame[11]]),
            9
        );
    }

    // Reuses the event-frame encoder to fabricate a server `SessionFinished` frame.
    #[test]
    fn tts_response_frame_parses_json_event() {
        let payload = json!({ "status_code": 20000000, "message": "ok" });
        let payload_bytes = serde_json::to_vec(&payload).unwrap();
        let frame = build_tts_event_frame(
            MESSAGE_FULL_SERVER_RESPONSE,
            SERIALIZATION_JSON,
            COMPRESSION_NONE,
            TtsEvent::SessionFinished.to_i32(),
            Some("session-1"),
            &payload_bytes,
        )
        .expect("response frame should build");
        let parsed = parse_tts_response_frame(&frame).expect("tts response should parse");
        assert_eq!(parsed.event, Some(TtsEvent::SessionFinished));
        assert_eq!(parsed.session_id.as_deref(), Some("session-1"));
        assert_eq!(parsed.payload.unwrap()["status_code"], 20000000);
    }

    // The upstream body uses snake_case keys even though the incoming request is camelCase.
    #[test]
    fn tts_sse_body_uses_snake_case_audio_params() {
        let body = build_tts_sse_body(
            TtsSseRequest {
                text: "你好".to_string(),
                speaker: "voice".to_string(),
                model: None,
                audio_params: Some(TtsAudioParams {
                    format: "mp3".to_string(),
                    sample_rate: 24_000,
                    bit_rate: Some(64_000),
                }),
                additions: None,
                ssml: None,
            },
            "user-1",
        )
        .expect("sse body should build");
        assert_eq!(body["user"]["uid"], "user-1");
        assert_eq!(body["req_params"]["audio_params"]["sample_rate"], 24_000);
        assert_eq!(body["req_params"]["audio_params"]["bit_rate"], 64_000);
    }
}