1206 lines
39 KiB
Rust
1206 lines
39 KiB
Rust
use std::{
|
||
collections::BTreeMap,
|
||
error::Error,
|
||
fmt,
|
||
io::{Read, Write},
|
||
time::Duration,
|
||
};
|
||
|
||
use flate2::{Compression, read::GzDecoder, write::GzEncoder};
|
||
use futures_util::{SinkExt, StreamExt};
|
||
use reqwest::header::{HeaderMap, HeaderValue};
|
||
use serde::{Deserialize, Serialize};
|
||
use serde_json::{Value, json};
|
||
use tokio_tungstenite::{
|
||
MaybeTlsStream, WebSocketStream, connect_async,
|
||
tungstenite::{Message, client::IntoClientRequest, http::Uri},
|
||
};
|
||
use uuid::Uuid;
|
||
|
||
pub use tokio_tungstenite::tungstenite::{Error as UpstreamWsError, Message as UpstreamWsMessage};
|
||
|
||
/// Default WebSocket endpoint for streaming ASR, used when the configured URL is blank.
pub const DEFAULT_ASR_WS_URL: &str = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async";

/// Default WebSocket endpoint for bidirectional TTS, used when the configured URL is blank.
pub const DEFAULT_TTS_BIDIRECTION_WS_URL: &str =
    "wss://openspeech.bytedance.com/api/v3/tts/bidirection";

/// Default HTTPS endpoint for TTS streamed as server-sent events, used when the configured
/// URL is blank.
pub const DEFAULT_TTS_SSE_URL: &str =
    "https://openspeech.bytedance.com/api/v3/tts/unidirectional/sse";

/// Default value of the `X-Api-Resource-Id` header for ASR connections.
pub const DEFAULT_ASR_RESOURCE_ID: &str = "volc.seedasr.sauc.concurrent";

/// Default value of the `X-Api-Resource-Id` header for TTS connections and SSE requests.
pub const DEFAULT_TTS_RESOURCE_ID: &str = "seed-tts-2.0";

/// Default request timeout in milliseconds (180 s).
pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 180_000;
|
||
|
||
// 4-bit fields of the binary frame header. Byte 0 packs (version << 4 | header size in
// 4-byte words), byte 1 packs (message type << 4 | flags), byte 2 packs
// (serialization << 4 | compression). Byte 3 is written as 0 and ignored when parsing.

/// Protocol version nibble written into every outgoing frame.
const PROTOCOL_VERSION: u8 = 0b0001;

/// Header-size nibble: one 4-byte word, i.e. a header with no extension bytes.
const HEADER_SIZE_FOUR_BYTES: u8 = 0b0001;

/// Serialization nibble: raw bytes (used for audio frames).
const SERIALIZATION_NONE: u8 = 0b0000;

/// Serialization nibble: JSON payload.
const SERIALIZATION_JSON: u8 = 0b0001;

/// Compression nibble: uncompressed payload.
const COMPRESSION_NONE: u8 = 0b0000;

/// Compression nibble: gzip-compressed payload.
const COMPRESSION_GZIP: u8 = 0b0001;

/// Message type of a client frame carrying a JSON body (ASR request or TTS event).
const MESSAGE_FULL_CLIENT_REQUEST: u8 = 0b0001;

/// Message type of a client frame carrying only audio bytes.
const MESSAGE_AUDIO_ONLY_REQUEST: u8 = 0b0010;

/// Message type of a full server response frame.
pub const MESSAGE_FULL_SERVER_RESPONSE: u8 = 0b1001;

/// Message type of a server frame carrying audio bytes.
const MESSAGE_AUDIO_ONLY_RESPONSE: u8 = 0b1011;

/// Message type of a server error frame (a `u32` error code follows the header).
const MESSAGE_ERROR: u8 = 0b1111;

/// Flags: no optional fields.
const FLAG_NONE: u8 = 0b0000;

/// Flags: a 4-byte sequence number follows the header.
const FLAG_WITH_SEQUENCE: u8 = 0b0001;

/// Flags: marks the final audio packet of an ASR stream.
const FLAG_LAST_PACKET: u8 = 0b0010;

/// Flags: a sequence number follows the header. NOTE(review): presumably a negative sequence
/// marking the last response; the parser treats it like `FLAG_WITH_SEQUENCE`. Confirm.
const FLAG_WITH_NEGATIVE_SEQUENCE: u8 = 0b0011;

/// Flags: a 4-byte event number follows the header (used by every TTS bidirection frame).
const FLAG_WITH_EVENT: u8 = 0b0100;
|
||
|
||
// Event numbers carried in event-flagged TTS frames, converted to and from `TtsEvent` by
// `TtsEvent::from_i32` / `TtsEvent::to_i32`.

// Connection lifecycle. Start/Finish are sent by the client; the rest are received.
const EVENT_START_CONNECTION: i32 = 1;

const EVENT_FINISH_CONNECTION: i32 = 2;

const EVENT_CONNECTION_STARTED: i32 = 50;

const EVENT_CONNECTION_FAILED: i32 = 51;

const EVENT_CONNECTION_FINISHED: i32 = 52;

// Session lifecycle. Start/Cancel/Finish are sent by the client; the rest are received.
const EVENT_START_SESSION: i32 = 100;

const EVENT_CANCEL_SESSION: i32 = 101;

const EVENT_FINISH_SESSION: i32 = 102;

const EVENT_SESSION_STARTED: i32 = 150;

const EVENT_SESSION_CANCELED: i32 = 151;

const EVENT_SESSION_FINISHED: i32 = 152;

const EVENT_SESSION_FAILED: i32 = 153;

// Client request carrying text to synthesize within a session.
const EVENT_TASK_REQUEST: i32 = 200;

// Synthesis output events. `TtsResponse` frames with raw serialization carry audio.
const EVENT_TTS_SENTENCE_END: i32 = 351;

const EVENT_TTS_RESPONSE: i32 = 352;

const EVENT_TTS_SUBTITLE: i32 = 353;
|
||
|
||
/// WebSocket stream returned by `connect_async` for the upstream speech endpoints (plain or TLS).
pub type SpeechWsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
|
||
|
||
/// Connection and credential settings for the Volcengine speech (ASR/TTS) upstream.
///
/// Prefer `VolcengineSpeechConfig::new`. It trims credentials, substitutes the `DEFAULT_*`
/// constants for blank resource ids and URLs, and rejects configurations without usable
/// credentials.
///
/// NOTE(review): the derived `Debug` prints `api_key` and `access_key` in clear text, so avoid
/// logging this struct with `{:?}`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VolcengineSpeechConfig {
    /// API key sent as `X-Api-Key`. Takes precedence over the legacy app-id/access-key pair.
    pub api_key: Option<String>,
    /// Legacy app id sent as `X-Api-App-Key`. Used only when no API key is set.
    pub app_id: Option<String>,
    /// Legacy access key sent as `X-Api-Access-Key`. Used only when no API key is set.
    pub access_key: Option<String>,
    /// `X-Api-Resource-Id` value for ASR connections.
    pub asr_resource_id: String,
    /// `X-Api-Resource-Id` value for TTS WebSocket connections and SSE requests.
    pub tts_resource_id: String,
    /// WebSocket URL for streaming ASR.
    pub asr_ws_url: String,
    /// WebSocket URL for bidirectional TTS.
    pub tts_bidirection_ws_url: String,
    /// HTTPS URL for SSE-streamed TTS.
    pub tts_sse_url: String,
    /// Timeout in milliseconds attached to SSE TTS requests. `new` clamps it to at least 1.
    pub request_timeout_ms: u64,
}
|
||
|
||
/// Credential-free view of the speech configuration, serialized with camelCase keys.
/// Produced by `VolcengineSpeechConfig::public_config`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PublicSpeechConfig {
    /// Configured ASR resource id.
    pub asr_resource_id: String,
    /// Configured TTS resource id.
    pub tts_resource_id: String,
    /// Audio format expected for ASR input.
    pub asr_audio: AsrAudioConfig,
    /// Default TTS output audio parameters.
    pub tts_audio: TtsAudioParams,
    /// Relative routes advertised to clients.
    pub endpoints: PublicSpeechEndpoints,
}
|
||
|
||
/// Relative paths advertised to clients by `VolcengineSpeechConfig::public_config`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PublicSpeechEndpoints {
    /// Streaming ASR route.
    pub asr_stream: &'static str,
    /// Bidirectional TTS route.
    pub tts_bidirection: &'static str,
    /// SSE TTS route.
    pub tts_sse: &'static str,
}
|
||
|
||
/// Audio input description sent in the ASR request's `audio` object.
/// The default is `pcm`/`raw`, rate 16000, 16 bits, 1 channel.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AsrAudioConfig {
    /// Container format, e.g. `"pcm"`.
    pub format: String,
    /// Codec name, e.g. `"raw"`.
    pub codec: String,
    /// Sample rate (presumably in Hz; the default is 16000).
    pub rate: u32,
    /// Bits per sample.
    pub bits: u8,
    /// Channel count.
    pub channel: u8,
}
|
||
|
||
/// Requested TTS output audio parameters. The default is `mp3` at sample rate 24000 with no bit rate.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsAudioParams {
    /// Output format, e.g. `"mp3"`.
    pub format: String,
    /// Output sample rate.
    pub sample_rate: u32,
    /// Optional bit rate. Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bit_rate: Option<u32>,
}
|
||
|
||
/// Thin client that opens upstream speech connections using a `VolcengineSpeechConfig`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VolcengineSpeechClient {
    // Configuration used for every connection and request this client builds.
    config: VolcengineSpeechConfig,
}
|
||
|
||
/// Errors raised by the Volcengine speech helpers. Each variant carries a human-readable
/// message, which is exactly what its `Display` implementation prints.
#[derive(Debug)]
pub enum SpeechError {
    /// Missing or invalid configuration or request input (credentials, URLs, empty TTS text).
    InvalidConfig(String),
    /// A value could not be converted into an HTTP header value.
    InvalidHeader(String),
    /// A binary frame was malformed, or could not be built or decoded.
    InvalidFrame(String),
    /// JSON serialization of an outgoing request or event failed.
    Serialize(String),
    /// gzip compression or decompression failed.
    Io(String),
    /// Connecting to, sending on, or reading from the upstream WebSocket failed.
    Upstream(String),
}
|
||
|
||
/// Which credential headers are sent upstream (see `VolcengineSpeechConfig::auth_mode`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum VolcengineSpeechAuthMode {
    /// Single `X-Api-Key` header.
    ApiKey,
    /// Legacy `X-Api-App-Key` plus `X-Api-Access-Key` headers.
    LegacyApp,
}
|
||
|
||
/// Kind of client frame produced by `build_asr_frame`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AsrFrameKind {
    /// Initial gzip-compressed JSON request describing the session.
    FullClientRequest,
    /// Intermediate gzip-compressed audio chunk.
    Audio,
    /// Final audio chunk, flagged with `FLAG_LAST_PACKET`.
    LastAudio,
}
|
||
|
||
/// Decoded leading header bytes of a binary speech frame (see `parse_header`).
/// Every field is a raw 4-bit value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SpeechFrameHeader {
    /// Protocol version (high nibble of byte 0).
    pub version: u8,
    /// Header length in 4-byte words (low nibble of byte 0). The body starts at
    /// `header_size_words * 4`.
    pub header_size_words: u8,
    /// Message type (high nibble of byte 1). Compare against the `MESSAGE_*` constants.
    pub message_type: u8,
    /// Message-type-specific flags (low nibble of byte 1). See the `FLAG_*` constants.
    pub flags: u8,
    /// Payload serialization (high nibble of byte 2). See the `SERIALIZATION_*` constants.
    pub serialization: u8,
    /// Payload compression (low nibble of byte 2). See the `COMPRESSION_*` constants.
    pub compression: u8,
}
|
||
|
||
/// A server frame received on the streaming-ASR WebSocket, as decoded by
/// `parse_asr_response_frame`.
#[derive(Clone, Debug, PartialEq)]
pub struct ParsedAsrResponse {
    /// Decoded frame header.
    pub header: SpeechFrameHeader,
    /// Sequence number, when the flags announce one and the frame is long enough to hold it.
    pub sequence: Option<i32>,
    /// Event number. `parse_asr_response_frame` always leaves this `None`.
    pub event: Option<i32>,
    /// JSON body. For error frames, `{"message": …}` is used when the body is not JSON.
    pub payload: Value,
    /// Error code of a `MESSAGE_ERROR` frame, otherwise `None`.
    pub error_code: Option<u32>,
}
|
||
|
||
/// A server frame received on the bidirectional-TTS WebSocket, as decoded by
/// `parse_tts_response_frame`.
#[derive(Clone, Debug, PartialEq)]
pub struct ParsedTtsResponse {
    /// Decoded frame header.
    pub header: SpeechFrameHeader,
    /// Event, present only when the frame is flagged `FLAG_WITH_EVENT` (never for error frames).
    pub event: Option<TtsEvent>,
    /// Session id for session-scoped events.
    pub session_id: Option<String>,
    /// Connection id for `Connection*` events.
    pub connection_id: Option<String>,
    /// JSON body. `None` for audio frames.
    pub payload: Option<Value>,
    /// Raw (decompressed) audio bytes. `Some` only for audio frames.
    pub audio: Option<Vec<u8>>,
    /// Error code of a `MESSAGE_ERROR` frame, otherwise `None`.
    pub error_code: Option<u32>,
}
|
||
|
||
/// Bidirectional-TTS event. Maps to the `EVENT_*` wire numbers via `from_i32` / `to_i32`,
/// and serializes with snake_case variant names.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TtsEvent {
    // Connection lifecycle.
    StartConnection,

    FinishConnection,

    ConnectionStarted,

    ConnectionFailed,

    ConnectionFinished,

    // Session lifecycle.
    StartSession,

    CancelSession,

    FinishSession,

    SessionStarted,

    SessionCanceled,

    SessionFinished,

    SessionFailed,

    // Text submitted for synthesis within a session.
    TaskRequest,

    // Synthesis output.
    TtsSentenceEnd,

    TtsResponse,

    TtsSubtitle,

    /// Any event number not listed above, preserved verbatim.
    Unknown(i32),
}
|
||
|
||
/// Control messages accepted from clients for bidirectional TTS, deserialized from JSON.
///
/// Internally tagged by `"type"` (snake_case variant name), for example
/// `{"type": "task_request", "sessionId": "…", "payload": {…}}`.
/// `build_tts_bidirection_frame_from_client_event` converts each one into an upstream frame.
///
/// NOTE(review): for the `Value` payloads (`StartSession`, `TaskRequest`), a missing
/// `payload` deserializes to `Value::Null` and is forwarded upstream as JSON `null`, not `{}`.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum TtsBidirectionClientEvent {
    /// Opens the upstream connection. A missing payload is sent as `{}`.
    StartConnection {
        #[serde(default)]
        payload: Option<Value>,
    },

    /// Closes the upstream connection. A missing payload is sent as `{}`.
    FinishConnection {
        #[serde(default)]
        payload: Option<Value>,
    },

    /// Starts a session. When `sessionId` is absent, a UUID v4 is generated.
    StartSession {
        #[serde(rename = "sessionId")]
        session_id: Option<String>,

        #[serde(default)]
        payload: Value,
    },

    /// Finishes the given session.
    FinishSession {
        #[serde(rename = "sessionId")]
        session_id: String,

        #[serde(default)]
        payload: Option<Value>,
    },

    /// Cancels the given session.
    CancelSession {
        #[serde(rename = "sessionId")]
        session_id: String,

        #[serde(default)]
        payload: Option<Value>,
    },

    /// Submits synthesis input for the given session.
    TaskRequest {
        #[serde(rename = "sessionId")]
        session_id: String,

        #[serde(default)]
        payload: Value,
    },
}
|
||
|
||
/// Client request for SSE TTS, deserialized from camelCase JSON (`audioParams`, …).
/// Validated and converted into the upstream body by `build_tts_sse_body`.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsSseRequest {
    /// Text to synthesize (trimmed). May be blank only when `ssml` is provided.
    pub text: String,
    /// Voice/speaker id (trimmed, must not be blank).
    pub speaker: String,
    /// Optional model name. Blank values are dropped.
    #[serde(default)]
    pub model: Option<String>,
    /// Output audio parameters. `TtsAudioParams::default()` is used when absent.
    #[serde(default)]
    pub audio_params: Option<TtsAudioParams>,
    /// Extra upstream options, forwarded verbatim as `additions`.
    #[serde(default)]
    pub additions: Option<Value>,
    /// Optional SSML input. Blank values are dropped.
    #[serde(default)]
    pub ssml: Option<String>,
}
|
||
|
||
/// Fully prepared (not yet sent) upstream SSE TTS request.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TtsSseUpstreamRequest {
    /// Target URL (`tts_sse_url`).
    pub url: String,
    /// Auth, resource, request-id, `Accept` and `Content-Type` headers.
    pub headers: HeaderMap,
    /// JSON request body.
    pub body: Value,
    /// Timeout to apply to the request.
    pub timeout: Duration,
}
|
||
|
||
/// Decoded payload compression. `SpeechFrameHeader::compression` treats any nibble other
/// than `COMPRESSION_GZIP` as `None`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SpeechCompression {
    None,

    Gzip,
}
|
||
|
||
impl Default for AsrAudioConfig {
|
||
fn default() -> Self {
|
||
Self {
|
||
format: "pcm".to_string(),
|
||
codec: "raw".to_string(),
|
||
rate: 16_000,
|
||
bits: 16,
|
||
channel: 1,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Default for TtsAudioParams {
|
||
fn default() -> Self {
|
||
Self {
|
||
format: "mp3".to_string(),
|
||
sample_rate: 24_000,
|
||
bit_rate: None,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl VolcengineSpeechConfig {
|
||
pub fn new(
|
||
api_key: Option<String>,
|
||
app_id: Option<String>,
|
||
access_key: Option<String>,
|
||
asr_resource_id: String,
|
||
tts_resource_id: String,
|
||
asr_ws_url: String,
|
||
tts_bidirection_ws_url: String,
|
||
tts_sse_url: String,
|
||
request_timeout_ms: u64,
|
||
) -> Result<Self, SpeechError> {
|
||
let config = Self {
|
||
api_key: normalize_optional_secret(api_key),
|
||
app_id: normalize_optional_secret(app_id),
|
||
access_key: normalize_optional_secret(access_key),
|
||
asr_resource_id: default_if_empty(asr_resource_id, DEFAULT_ASR_RESOURCE_ID),
|
||
tts_resource_id: default_if_empty(tts_resource_id, DEFAULT_TTS_RESOURCE_ID),
|
||
asr_ws_url: default_if_empty(asr_ws_url, DEFAULT_ASR_WS_URL),
|
||
tts_bidirection_ws_url: default_if_empty(
|
||
tts_bidirection_ws_url,
|
||
DEFAULT_TTS_BIDIRECTION_WS_URL,
|
||
),
|
||
tts_sse_url: default_if_empty(tts_sse_url, DEFAULT_TTS_SSE_URL),
|
||
request_timeout_ms: request_timeout_ms.max(1),
|
||
};
|
||
config.auth_mode()?;
|
||
Ok(config)
|
||
}
|
||
|
||
pub fn auth_mode(&self) -> Result<VolcengineSpeechAuthMode, SpeechError> {
|
||
if self.api_key.as_ref().is_some_and(|value| !value.is_empty()) {
|
||
return Ok(VolcengineSpeechAuthMode::ApiKey);
|
||
}
|
||
|
||
if self.app_id.as_ref().is_some_and(|value| !value.is_empty())
|
||
&& self
|
||
.access_key
|
||
.as_ref()
|
||
.is_some_and(|value| !value.is_empty())
|
||
{
|
||
return Ok(VolcengineSpeechAuthMode::LegacyApp);
|
||
}
|
||
|
||
Err(SpeechError::InvalidConfig(
|
||
"火山语音密钥未配置:需要 VOLCENGINE_SPEECH_API_KEY,或旧版 VOLCENGINE_SPEECH_APP_ID + VOLCENGINE_SPEECH_ACCESS_KEY"
|
||
.to_string(),
|
||
))
|
||
}
|
||
|
||
pub fn public_config(&self) -> PublicSpeechConfig {
|
||
PublicSpeechConfig {
|
||
asr_resource_id: self.asr_resource_id.clone(),
|
||
tts_resource_id: self.tts_resource_id.clone(),
|
||
asr_audio: AsrAudioConfig::default(),
|
||
tts_audio: TtsAudioParams::default(),
|
||
endpoints: PublicSpeechEndpoints {
|
||
asr_stream: "/api/speech/volcengine/asr/stream",
|
||
tts_bidirection: "/api/speech/volcengine/tts/bidirection",
|
||
tts_sse: "/api/speech/volcengine/tts/sse",
|
||
},
|
||
}
|
||
}
|
||
|
||
fn auth_headers(&self, resource_id: &str) -> Result<HeaderMap, SpeechError> {
|
||
let mut headers = HeaderMap::new();
|
||
match self.auth_mode()? {
|
||
VolcengineSpeechAuthMode::ApiKey => {
|
||
headers.insert(
|
||
"X-Api-Key",
|
||
header_value(self.api_key.as_deref().unwrap_or(""))?,
|
||
);
|
||
}
|
||
VolcengineSpeechAuthMode::LegacyApp => {
|
||
headers.insert(
|
||
"X-Api-App-Key",
|
||
header_value(self.app_id.as_deref().unwrap_or(""))?,
|
||
);
|
||
headers.insert(
|
||
"X-Api-Access-Key",
|
||
header_value(self.access_key.as_deref().unwrap_or(""))?,
|
||
);
|
||
}
|
||
}
|
||
headers.insert("X-Api-Resource-Id", header_value(resource_id)?);
|
||
Ok(headers)
|
||
}
|
||
}
|
||
|
||
impl VolcengineSpeechClient {
|
||
pub fn new(config: VolcengineSpeechConfig) -> Self {
|
||
Self { config }
|
||
}
|
||
|
||
pub fn config(&self) -> &VolcengineSpeechConfig {
|
||
&self.config
|
||
}
|
||
|
||
pub async fn connect_asr(
|
||
&self,
|
||
) -> Result<(SpeechWsStream, BTreeMap<String, String>), SpeechError> {
|
||
let request_id = Uuid::new_v4().to_string();
|
||
let mut headers = self.config.auth_headers(&self.config.asr_resource_id)?;
|
||
headers.insert("X-Api-Request-Id", header_value(&request_id)?);
|
||
headers.insert("X-Api-Sequence", HeaderValue::from_static("-1"));
|
||
self.connect_ws(&self.config.asr_ws_url, headers).await
|
||
}
|
||
|
||
pub async fn connect_tts_bidirection(
|
||
&self,
|
||
) -> Result<(SpeechWsStream, BTreeMap<String, String>), SpeechError> {
|
||
let connect_id = Uuid::new_v4().to_string();
|
||
let mut headers = self.config.auth_headers(&self.config.tts_resource_id)?;
|
||
headers.insert("X-Api-Connect-Id", header_value(&connect_id)?);
|
||
self.connect_ws(&self.config.tts_bidirection_ws_url, headers)
|
||
.await
|
||
}
|
||
|
||
pub fn build_tts_sse_upstream_request(
|
||
&self,
|
||
request: TtsSseRequest,
|
||
user_id: &str,
|
||
) -> Result<TtsSseUpstreamRequest, SpeechError> {
|
||
let mut headers = self.config.auth_headers(&self.config.tts_resource_id)?;
|
||
headers.insert(
|
||
"X-Api-Request-Id",
|
||
header_value(&Uuid::new_v4().to_string())?,
|
||
);
|
||
headers.insert(
|
||
reqwest::header::ACCEPT,
|
||
HeaderValue::from_static("text/event-stream"),
|
||
);
|
||
headers.insert(
|
||
reqwest::header::CONTENT_TYPE,
|
||
HeaderValue::from_static("application/json"),
|
||
);
|
||
let body = build_tts_sse_body(request, user_id)?;
|
||
Ok(TtsSseUpstreamRequest {
|
||
url: self.config.tts_sse_url.clone(),
|
||
headers,
|
||
body,
|
||
timeout: Duration::from_millis(self.config.request_timeout_ms),
|
||
})
|
||
}
|
||
|
||
async fn connect_ws(
|
||
&self,
|
||
url: &str,
|
||
headers: HeaderMap,
|
||
) -> Result<(SpeechWsStream, BTreeMap<String, String>), SpeechError> {
|
||
let uri: Uri = url.parse().map_err(|error| {
|
||
SpeechError::InvalidConfig(format!("火山语音 WebSocket URL 非法:{error}"))
|
||
})?;
|
||
let mut request = uri.into_client_request().map_err(|error| {
|
||
SpeechError::InvalidConfig(format!("构造火山语音 WebSocket 请求失败:{error}"))
|
||
})?;
|
||
for (name, value) in headers {
|
||
if let Some(name) = name {
|
||
request.headers_mut().insert(name, value);
|
||
}
|
||
}
|
||
|
||
let (stream, response) = connect_async(request).await.map_err(|error| {
|
||
SpeechError::Upstream(format!("连接火山语音 WebSocket 失败:{error}"))
|
||
})?;
|
||
let response_headers = response
|
||
.headers()
|
||
.iter()
|
||
.filter_map(|(name, value)| {
|
||
value
|
||
.to_str()
|
||
.ok()
|
||
.map(|value| (name.as_str().to_string(), value.to_string()))
|
||
})
|
||
.collect();
|
||
Ok((stream, response_headers))
|
||
}
|
||
}
|
||
|
||
pub fn default_asr_request_payload(user_id: &str, override_payload: Option<Value>) -> Value {
|
||
let mut payload = json!({
|
||
"user": {
|
||
"uid": user_id,
|
||
},
|
||
"audio": {
|
||
"format": AsrAudioConfig::default().format,
|
||
"codec": AsrAudioConfig::default().codec,
|
||
"rate": AsrAudioConfig::default().rate,
|
||
"bits": AsrAudioConfig::default().bits,
|
||
"channel": AsrAudioConfig::default().channel,
|
||
},
|
||
"request": {
|
||
"model_name": "bigmodel",
|
||
"enable_itn": true,
|
||
"enable_punc": true,
|
||
"show_utterances": true,
|
||
"result_type": "full",
|
||
}
|
||
});
|
||
if let Some(override_payload) = override_payload {
|
||
merge_json_object(&mut payload, override_payload);
|
||
}
|
||
payload
|
||
}
|
||
|
||
pub fn build_asr_frame(kind: AsrFrameKind, payload: &[u8]) -> Result<Vec<u8>, SpeechError> {
|
||
match kind {
|
||
AsrFrameKind::FullClientRequest => build_sized_frame(
|
||
MESSAGE_FULL_CLIENT_REQUEST,
|
||
FLAG_NONE,
|
||
SERIALIZATION_JSON,
|
||
COMPRESSION_GZIP,
|
||
&gzip_bytes(payload)?,
|
||
),
|
||
AsrFrameKind::Audio => build_sized_frame(
|
||
MESSAGE_AUDIO_ONLY_REQUEST,
|
||
FLAG_NONE,
|
||
SERIALIZATION_NONE,
|
||
COMPRESSION_GZIP,
|
||
&gzip_bytes(payload)?,
|
||
),
|
||
AsrFrameKind::LastAudio => build_sized_frame(
|
||
MESSAGE_AUDIO_ONLY_REQUEST,
|
||
FLAG_LAST_PACKET,
|
||
SERIALIZATION_NONE,
|
||
COMPRESSION_GZIP,
|
||
&gzip_bytes(payload)?,
|
||
),
|
||
}
|
||
}
|
||
|
||
pub fn build_asr_full_client_request(payload: &Value) -> Result<Vec<u8>, SpeechError> {
|
||
let bytes = serde_json::to_vec(payload)
|
||
.map_err(|error| SpeechError::Serialize(format!("序列化 ASR 请求失败:{error}")))?;
|
||
build_asr_frame(AsrFrameKind::FullClientRequest, &bytes)
|
||
}
|
||
|
||
pub fn parse_asr_response_frame(bytes: &[u8]) -> Result<ParsedAsrResponse, SpeechError> {
|
||
let header = parse_header(bytes)?;
|
||
let mut offset = usize::from(header.header_size_words) * 4;
|
||
let sequence = if matches!(
|
||
header.flags,
|
||
FLAG_WITH_SEQUENCE | FLAG_WITH_NEGATIVE_SEQUENCE
|
||
) && bytes.len() >= offset + 4
|
||
{
|
||
let sequence = read_i32(bytes, &mut offset)?;
|
||
Some(sequence)
|
||
} else {
|
||
None
|
||
};
|
||
|
||
if header.message_type == MESSAGE_ERROR {
|
||
let error_code = read_u32(bytes, &mut offset).ok();
|
||
let payload = read_payload_value(bytes, &mut offset, SpeechCompression::None)
|
||
.unwrap_or_else(|_| json!({ "message": decode_lossy(&bytes[offset..]) }));
|
||
return Ok(ParsedAsrResponse {
|
||
header,
|
||
sequence,
|
||
event: None,
|
||
payload,
|
||
error_code,
|
||
});
|
||
}
|
||
|
||
let payload = read_payload_value(bytes, &mut offset, header.compression())?;
|
||
Ok(ParsedAsrResponse {
|
||
header,
|
||
sequence,
|
||
event: None,
|
||
payload,
|
||
error_code: None,
|
||
})
|
||
}
|
||
|
||
pub fn build_tts_bidirection_frame(
|
||
event: TtsEvent,
|
||
session_id: Option<&str>,
|
||
payload: Option<&Value>,
|
||
) -> Result<Vec<u8>, SpeechError> {
|
||
let payload_bytes = serde_json::to_vec(payload.unwrap_or(&json!({})))
|
||
.map_err(|error| SpeechError::Serialize(format!("序列化 TTS 事件失败:{error}")))?;
|
||
let event_number = event.to_i32();
|
||
let message_type = match event {
|
||
TtsEvent::TaskRequest => MESSAGE_FULL_CLIENT_REQUEST,
|
||
_ => MESSAGE_FULL_CLIENT_REQUEST,
|
||
};
|
||
build_tts_event_frame(
|
||
message_type,
|
||
SERIALIZATION_JSON,
|
||
COMPRESSION_NONE,
|
||
event_number,
|
||
session_id,
|
||
&payload_bytes,
|
||
)
|
||
}
|
||
|
||
pub fn build_tts_bidirection_frame_from_client_event(
|
||
event: TtsBidirectionClientEvent,
|
||
) -> Result<Vec<u8>, SpeechError> {
|
||
match event {
|
||
TtsBidirectionClientEvent::StartConnection { payload } => {
|
||
build_tts_bidirection_frame(TtsEvent::StartConnection, None, payload.as_ref())
|
||
}
|
||
TtsBidirectionClientEvent::FinishConnection { payload } => {
|
||
build_tts_bidirection_frame(TtsEvent::FinishConnection, None, payload.as_ref())
|
||
}
|
||
TtsBidirectionClientEvent::StartSession {
|
||
session_id,
|
||
payload,
|
||
} => {
|
||
let session_id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
|
||
build_tts_bidirection_frame(TtsEvent::StartSession, Some(&session_id), Some(&payload))
|
||
}
|
||
TtsBidirectionClientEvent::FinishSession {
|
||
session_id,
|
||
payload,
|
||
} => build_tts_bidirection_frame(
|
||
TtsEvent::FinishSession,
|
||
Some(&session_id),
|
||
payload.as_ref(),
|
||
),
|
||
TtsBidirectionClientEvent::CancelSession {
|
||
session_id,
|
||
payload,
|
||
} => build_tts_bidirection_frame(
|
||
TtsEvent::CancelSession,
|
||
Some(&session_id),
|
||
payload.as_ref(),
|
||
),
|
||
TtsBidirectionClientEvent::TaskRequest {
|
||
session_id,
|
||
payload,
|
||
} => build_tts_bidirection_frame(TtsEvent::TaskRequest, Some(&session_id), Some(&payload)),
|
||
}
|
||
}
|
||
|
||
pub fn parse_tts_response_frame(bytes: &[u8]) -> Result<ParsedTtsResponse, SpeechError> {
|
||
let header = parse_header(bytes)?;
|
||
let mut offset = usize::from(header.header_size_words) * 4;
|
||
if header.message_type == MESSAGE_ERROR {
|
||
let error_code = read_u32(bytes, &mut offset).ok();
|
||
let payload = read_payload_value(bytes, &mut offset, header.compression())
|
||
.unwrap_or_else(|_| json!({ "message": decode_lossy(&bytes[offset..]) }));
|
||
return Ok(ParsedTtsResponse {
|
||
header,
|
||
event: None,
|
||
session_id: None,
|
||
connection_id: None,
|
||
payload: Some(payload),
|
||
audio: None,
|
||
error_code,
|
||
});
|
||
}
|
||
|
||
let event = if header.flags == FLAG_WITH_EVENT {
|
||
Some(TtsEvent::from_i32(read_i32(bytes, &mut offset)?))
|
||
} else {
|
||
None
|
||
};
|
||
let session_or_connection_id = read_optional_length_prefixed_string(bytes, &mut offset)?;
|
||
let payload_bytes = read_payload_bytes(bytes, &mut offset)?;
|
||
let payload_bytes = match header.compression() {
|
||
SpeechCompression::None => payload_bytes,
|
||
SpeechCompression::Gzip => ungzip_bytes(&payload_bytes)?,
|
||
};
|
||
let is_audio = header.message_type == MESSAGE_AUDIO_ONLY_RESPONSE
|
||
|| event == Some(TtsEvent::TtsResponse) && header.serialization == SERIALIZATION_NONE;
|
||
let payload = if is_audio {
|
||
None
|
||
} else if payload_bytes.is_empty() {
|
||
Some(json!({}))
|
||
} else {
|
||
Some(serde_json::from_slice(&payload_bytes).map_err(|error| {
|
||
SpeechError::InvalidFrame(format!("解析 TTS JSON 响应失败:{error}"))
|
||
})?)
|
||
};
|
||
let audio = if is_audio { Some(payload_bytes) } else { None };
|
||
let (connection_id, session_id) = match event {
|
||
Some(TtsEvent::ConnectionStarted)
|
||
| Some(TtsEvent::ConnectionFailed)
|
||
| Some(TtsEvent::ConnectionFinished) => (session_or_connection_id, None),
|
||
_ => (None, session_or_connection_id),
|
||
};
|
||
|
||
Ok(ParsedTtsResponse {
|
||
header,
|
||
event,
|
||
session_id,
|
||
connection_id,
|
||
payload,
|
||
audio,
|
||
error_code: None,
|
||
})
|
||
}
|
||
|
||
pub fn tts_response_to_client_value(response: &ParsedTtsResponse) -> Value {
|
||
json!({
|
||
"event": response.event.map(|event| event.name()),
|
||
"eventCode": response.event.map(|event| event.to_i32()),
|
||
"sessionId": response.session_id,
|
||
"connectionId": response.connection_id,
|
||
"payload": response.payload,
|
||
"audioBytes": response.audio.as_ref().map(Vec::len),
|
||
"errorCode": response.error_code,
|
||
})
|
||
}
|
||
|
||
pub fn build_tts_sse_body(request: TtsSseRequest, user_id: &str) -> Result<Value, SpeechError> {
|
||
let text = request.text.trim();
|
||
let speaker = request.speaker.trim();
|
||
if text.is_empty() && request.ssml.as_deref().unwrap_or("").trim().is_empty() {
|
||
return Err(SpeechError::InvalidConfig("TTS 文本不能为空".to_string()));
|
||
}
|
||
if speaker.is_empty() {
|
||
return Err(SpeechError::InvalidConfig(
|
||
"TTS speaker 不能为空".to_string(),
|
||
));
|
||
}
|
||
|
||
let mut req_params = json!({
|
||
"text": text,
|
||
"speaker": speaker,
|
||
"audio_params": {
|
||
"format": request.audio_params.clone().unwrap_or_default().format,
|
||
"sample_rate": request.audio_params.clone().unwrap_or_default().sample_rate,
|
||
},
|
||
});
|
||
if let Some(bit_rate) = request
|
||
.audio_params
|
||
.as_ref()
|
||
.and_then(|params| params.bit_rate)
|
||
{
|
||
req_params["audio_params"]["bit_rate"] = json!(bit_rate);
|
||
}
|
||
if let Some(model) = normalize_optional_secret(request.model) {
|
||
req_params["model"] = json!(model);
|
||
}
|
||
if let Some(ssml) = normalize_optional_secret(request.ssml) {
|
||
req_params["ssml"] = json!(ssml);
|
||
}
|
||
if let Some(additions) = request.additions {
|
||
req_params["additions"] = additions;
|
||
}
|
||
|
||
Ok(json!({
|
||
"user": {
|
||
"uid": user_id,
|
||
},
|
||
"req_params": req_params,
|
||
}))
|
||
}
|
||
|
||
pub async fn send_binary(ws: &mut SpeechWsStream, bytes: Vec<u8>) -> Result<(), SpeechError> {
|
||
ws.send(Message::Binary(bytes.into()))
|
||
.await
|
||
.map_err(|error| SpeechError::Upstream(format!("发送火山语音 WebSocket 帧失败:{error}")))
|
||
}
|
||
|
||
pub async fn recv_binary(ws: &mut SpeechWsStream) -> Result<Option<Vec<u8>>, SpeechError> {
|
||
while let Some(message) = ws.next().await {
|
||
match message {
|
||
Ok(Message::Binary(bytes)) => return Ok(Some(bytes.to_vec())),
|
||
Ok(Message::Text(text)) => return Ok(Some(text.as_bytes().to_vec())),
|
||
Ok(Message::Close(_)) => return Ok(None),
|
||
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
|
||
Ok(Message::Frame(_)) => {}
|
||
Err(error) => {
|
||
return Err(SpeechError::Upstream(format!(
|
||
"读取火山语音 WebSocket 帧失败:{error}"
|
||
)));
|
||
}
|
||
}
|
||
}
|
||
Ok(None)
|
||
}
|
||
|
||
impl SpeechFrameHeader {
|
||
fn compression(self) -> SpeechCompression {
|
||
match self.compression {
|
||
COMPRESSION_GZIP => SpeechCompression::Gzip,
|
||
_ => SpeechCompression::None,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl TtsEvent {
|
||
pub fn from_i32(value: i32) -> Self {
|
||
match value {
|
||
EVENT_START_CONNECTION => Self::StartConnection,
|
||
EVENT_FINISH_CONNECTION => Self::FinishConnection,
|
||
EVENT_CONNECTION_STARTED => Self::ConnectionStarted,
|
||
EVENT_CONNECTION_FAILED => Self::ConnectionFailed,
|
||
EVENT_CONNECTION_FINISHED => Self::ConnectionFinished,
|
||
EVENT_START_SESSION => Self::StartSession,
|
||
EVENT_CANCEL_SESSION => Self::CancelSession,
|
||
EVENT_FINISH_SESSION => Self::FinishSession,
|
||
EVENT_SESSION_STARTED => Self::SessionStarted,
|
||
EVENT_SESSION_CANCELED => Self::SessionCanceled,
|
||
EVENT_SESSION_FINISHED => Self::SessionFinished,
|
||
EVENT_SESSION_FAILED => Self::SessionFailed,
|
||
EVENT_TASK_REQUEST => Self::TaskRequest,
|
||
EVENT_TTS_SENTENCE_END => Self::TtsSentenceEnd,
|
||
EVENT_TTS_RESPONSE => Self::TtsResponse,
|
||
EVENT_TTS_SUBTITLE => Self::TtsSubtitle,
|
||
other => Self::Unknown(other),
|
||
}
|
||
}
|
||
|
||
pub fn to_i32(self) -> i32 {
|
||
match self {
|
||
Self::StartConnection => EVENT_START_CONNECTION,
|
||
Self::FinishConnection => EVENT_FINISH_CONNECTION,
|
||
Self::ConnectionStarted => EVENT_CONNECTION_STARTED,
|
||
Self::ConnectionFailed => EVENT_CONNECTION_FAILED,
|
||
Self::ConnectionFinished => EVENT_CONNECTION_FINISHED,
|
||
Self::StartSession => EVENT_START_SESSION,
|
||
Self::CancelSession => EVENT_CANCEL_SESSION,
|
||
Self::FinishSession => EVENT_FINISH_SESSION,
|
||
Self::SessionStarted => EVENT_SESSION_STARTED,
|
||
Self::SessionCanceled => EVENT_SESSION_CANCELED,
|
||
Self::SessionFinished => EVENT_SESSION_FINISHED,
|
||
Self::SessionFailed => EVENT_SESSION_FAILED,
|
||
Self::TaskRequest => EVENT_TASK_REQUEST,
|
||
Self::TtsSentenceEnd => EVENT_TTS_SENTENCE_END,
|
||
Self::TtsResponse => EVENT_TTS_RESPONSE,
|
||
Self::TtsSubtitle => EVENT_TTS_SUBTITLE,
|
||
Self::Unknown(value) => value,
|
||
}
|
||
}
|
||
|
||
pub fn name(self) -> &'static str {
|
||
match self {
|
||
Self::StartConnection => "start_connection",
|
||
Self::FinishConnection => "finish_connection",
|
||
Self::ConnectionStarted => "connection_started",
|
||
Self::ConnectionFailed => "connection_failed",
|
||
Self::ConnectionFinished => "connection_finished",
|
||
Self::StartSession => "start_session",
|
||
Self::CancelSession => "cancel_session",
|
||
Self::FinishSession => "finish_session",
|
||
Self::SessionStarted => "session_started",
|
||
Self::SessionCanceled => "session_canceled",
|
||
Self::SessionFinished => "session_finished",
|
||
Self::SessionFailed => "session_failed",
|
||
Self::TaskRequest => "task_request",
|
||
Self::TtsSentenceEnd => "tts_sentence_end",
|
||
Self::TtsResponse => "tts_response",
|
||
Self::TtsSubtitle => "tts_subtitle",
|
||
Self::Unknown(_) => "unknown",
|
||
}
|
||
}
|
||
}
|
||
|
||
impl fmt::Display for SpeechError {
|
||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||
match self {
|
||
Self::InvalidConfig(message)
|
||
| Self::InvalidHeader(message)
|
||
| Self::InvalidFrame(message)
|
||
| Self::Serialize(message)
|
||
| Self::Io(message)
|
||
| Self::Upstream(message) => write!(f, "{message}"),
|
||
}
|
||
}
|
||
}
|
||
|
||
// Marker impl: `Display` supplies the message, and `source()` keeps its default (`None`)
// because every variant stores only a formatted string.
impl Error for SpeechError {}
|
||
|
||
fn build_sized_frame(
|
||
message_type: u8,
|
||
flags: u8,
|
||
serialization: u8,
|
||
compression: u8,
|
||
payload: &[u8],
|
||
) -> Result<Vec<u8>, SpeechError> {
|
||
let payload_len = u32::try_from(payload.len())
|
||
.map_err(|_| SpeechError::InvalidFrame("语音帧 payload 超过 u32 上限".to_string()))?;
|
||
let mut frame = Vec::with_capacity(8 + payload.len());
|
||
frame.push((PROTOCOL_VERSION << 4) | HEADER_SIZE_FOUR_BYTES);
|
||
frame.push((message_type << 4) | flags);
|
||
frame.push((serialization << 4) | compression);
|
||
frame.push(0);
|
||
frame.extend_from_slice(&payload_len.to_be_bytes());
|
||
frame.extend_from_slice(payload);
|
||
Ok(frame)
|
||
}
|
||
|
||
fn build_tts_event_frame(
|
||
message_type: u8,
|
||
serialization: u8,
|
||
compression: u8,
|
||
event: i32,
|
||
session_id: Option<&str>,
|
||
payload: &[u8],
|
||
) -> Result<Vec<u8>, SpeechError> {
|
||
let id = session_id.unwrap_or("");
|
||
let id_bytes = id.as_bytes();
|
||
let id_len = u32::try_from(id_bytes.len())
|
||
.map_err(|_| SpeechError::InvalidFrame("TTS session id 超过 u32 上限".to_string()))?;
|
||
let payload_len = u32::try_from(payload.len())
|
||
.map_err(|_| SpeechError::InvalidFrame("TTS payload 超过 u32 上限".to_string()))?;
|
||
let mut frame = Vec::with_capacity(12 + id_bytes.len() + payload.len());
|
||
frame.push((PROTOCOL_VERSION << 4) | HEADER_SIZE_FOUR_BYTES);
|
||
frame.push((message_type << 4) | FLAG_WITH_EVENT);
|
||
frame.push((serialization << 4) | compression);
|
||
frame.push(0);
|
||
frame.extend_from_slice(&event.to_be_bytes());
|
||
if session_id.is_some() {
|
||
frame.extend_from_slice(&id_len.to_be_bytes());
|
||
frame.extend_from_slice(id_bytes);
|
||
}
|
||
frame.extend_from_slice(&payload_len.to_be_bytes());
|
||
frame.extend_from_slice(payload);
|
||
Ok(frame)
|
||
}
|
||
|
||
fn parse_header(bytes: &[u8]) -> Result<SpeechFrameHeader, SpeechError> {
|
||
if bytes.len() < 4 {
|
||
return Err(SpeechError::InvalidFrame(
|
||
"语音帧长度不足 4 字节".to_string(),
|
||
));
|
||
}
|
||
Ok(SpeechFrameHeader {
|
||
version: bytes[0] >> 4,
|
||
header_size_words: bytes[0] & 0x0f,
|
||
message_type: bytes[1] >> 4,
|
||
flags: bytes[1] & 0x0f,
|
||
serialization: bytes[2] >> 4,
|
||
compression: bytes[2] & 0x0f,
|
||
})
|
||
}
|
||
|
||
fn read_payload_value(
|
||
bytes: &[u8],
|
||
offset: &mut usize,
|
||
compression: SpeechCompression,
|
||
) -> Result<Value, SpeechError> {
|
||
let payload = read_payload_bytes(bytes, offset)?;
|
||
let payload = match compression {
|
||
SpeechCompression::None => payload,
|
||
SpeechCompression::Gzip => ungzip_bytes(&payload)?,
|
||
};
|
||
if payload.is_empty() {
|
||
return Ok(json!({}));
|
||
}
|
||
serde_json::from_slice(&payload)
|
||
.map_err(|error| SpeechError::InvalidFrame(format!("解析语音 JSON 帧失败:{error}")))
|
||
}
|
||
|
||
fn read_payload_bytes(bytes: &[u8], offset: &mut usize) -> Result<Vec<u8>, SpeechError> {
|
||
let payload_len = read_u32(bytes, offset)? as usize;
|
||
if bytes.len() < *offset + payload_len {
|
||
return Err(SpeechError::InvalidFrame(
|
||
"语音帧 payload 长度超过实际数据".to_string(),
|
||
));
|
||
}
|
||
let payload = bytes[*offset..*offset + payload_len].to_vec();
|
||
*offset += payload_len;
|
||
Ok(payload)
|
||
}
|
||
|
||
fn read_optional_length_prefixed_string(
|
||
bytes: &[u8],
|
||
offset: &mut usize,
|
||
) -> Result<Option<String>, SpeechError> {
|
||
if bytes.len() < *offset + 4 {
|
||
return Ok(None);
|
||
}
|
||
let saved_offset = *offset;
|
||
let len = read_u32(bytes, offset)? as usize;
|
||
if bytes.len() < *offset + len {
|
||
*offset = saved_offset;
|
||
return Ok(None);
|
||
}
|
||
let text = decode_lossy(&bytes[*offset..*offset + len]);
|
||
*offset += len;
|
||
if text.is_empty() {
|
||
Ok(None)
|
||
} else {
|
||
Ok(Some(text))
|
||
}
|
||
}
|
||
|
||
fn read_i32(bytes: &[u8], offset: &mut usize) -> Result<i32, SpeechError> {
|
||
if bytes.len() < *offset + 4 {
|
||
return Err(SpeechError::InvalidFrame("语音帧缺少 i32 字段".to_string()));
|
||
}
|
||
let value = i32::from_be_bytes([
|
||
bytes[*offset],
|
||
bytes[*offset + 1],
|
||
bytes[*offset + 2],
|
||
bytes[*offset + 3],
|
||
]);
|
||
*offset += 4;
|
||
Ok(value)
|
||
}
|
||
|
||
fn read_u32(bytes: &[u8], offset: &mut usize) -> Result<u32, SpeechError> {
|
||
if bytes.len() < *offset + 4 {
|
||
return Err(SpeechError::InvalidFrame("语音帧缺少 u32 字段".to_string()));
|
||
}
|
||
let value = u32::from_be_bytes([
|
||
bytes[*offset],
|
||
bytes[*offset + 1],
|
||
bytes[*offset + 2],
|
||
bytes[*offset + 3],
|
||
]);
|
||
*offset += 4;
|
||
Ok(value)
|
||
}
|
||
|
||
fn gzip_bytes(payload: &[u8]) -> Result<Vec<u8>, SpeechError> {
|
||
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
|
||
encoder
|
||
.write_all(payload)
|
||
.map_err(|error| SpeechError::Io(format!("gzip 压缩语音帧失败:{error}")))?;
|
||
encoder
|
||
.finish()
|
||
.map_err(|error| SpeechError::Io(format!("完成 gzip 压缩语音帧失败:{error}")))
|
||
}
|
||
|
||
fn ungzip_bytes(payload: &[u8]) -> Result<Vec<u8>, SpeechError> {
|
||
let mut decoder = GzDecoder::new(payload);
|
||
let mut output = Vec::new();
|
||
decoder
|
||
.read_to_end(&mut output)
|
||
.map_err(|error| SpeechError::Io(format!("gzip 解压语音帧失败:{error}")))?;
|
||
Ok(output)
|
||
}
|
||
|
||
fn header_value(value: &str) -> Result<HeaderValue, SpeechError> {
|
||
HeaderValue::from_str(value)
|
||
.map_err(|error| SpeechError::InvalidHeader(format!("构造火山语音请求头失败:{error}")))
|
||
}
|
||
|
||
/// Trims a possibly-absent secret.
///
/// Returns `None` when the input is `None` or contains only whitespace.
fn normalize_optional_secret(value: Option<String>) -> Option<String> {
    let trimmed = value?.trim().to_owned();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed)
    }
}
|
||
|
||
/// Returns `value` with surrounding whitespace trimmed.
///
/// If nothing remains after trimming, returns `default_value` unchanged (it is not
/// trimmed).
fn default_if_empty(value: String, default_value: &str) -> String {
    match value.trim() {
        "" => default_value.to_owned(),
        trimmed => trimmed.to_owned(),
    }
}
|
||
|
||
fn merge_json_object(target: &mut Value, source: Value) {
|
||
match (target, source) {
|
||
(Value::Object(target), Value::Object(source)) => {
|
||
for (key, value) in source {
|
||
merge_json_object(target.entry(key).or_insert(Value::Null), value);
|
||
}
|
||
}
|
||
(target, source) => *target = source,
|
||
}
|
||
}
|
||
|
||
/// Decodes `bytes` as UTF-8, replacing invalid sequences with U+FFFD, and trims
/// surrounding whitespace.
fn decode_lossy(bytes: &[u8]) -> String {
    let decoded = String::from_utf8_lossy(bytes);
    decoded.trim().to_owned()
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config that authenticates with an API key.
    ///
    /// Every string field is left empty. The tests below expect empty fields to fall back
    /// to the built-in defaults.
    fn test_config_with_api_key() -> VolcengineSpeechConfig {
        VolcengineSpeechConfig::new(
            Some("api-key".to_string()),
            None,
            None,
            String::new(),
            String::new(),
            String::new(),
            String::new(),
            String::new(),
            DEFAULT_REQUEST_TIMEOUT_MS,
        )
        .expect("config should build")
    }

    #[test]
    fn config_prefers_api_key_auth_and_exposes_no_secret_in_public_config() {
        let config = test_config_with_api_key();

        assert_eq!(
            config.auth_mode().unwrap(),
            VolcengineSpeechAuthMode::ApiKey
        );
        let public_config = config.public_config();
        assert_eq!(public_config.asr_resource_id, DEFAULT_ASR_RESOURCE_ID);
        assert_eq!(public_config.tts_resource_id, DEFAULT_TTS_RESOURCE_ID);
        assert_eq!(
            public_config.endpoints.asr_stream,
            "/api/speech/volcengine/asr/stream"
        );
    }

    #[test]
    fn config_accepts_legacy_auth_when_api_key_missing() {
        let config = VolcengineSpeechConfig::new(
            None,
            Some("app-id".to_string()),
            Some("access-key".to_string()),
            String::new(),
            String::new(),
            String::new(),
            String::new(),
            String::new(),
            DEFAULT_REQUEST_TIMEOUT_MS,
        )
        .expect("legacy config should build");

        assert_eq!(
            config.auth_mode().unwrap(),
            VolcengineSpeechAuthMode::LegacyApp
        );
    }

    #[test]
    fn asr_frame_roundtrip_parses_gzip_json_response() {
        let payload = json!({ "result": { "text": "你好" } });
        let payload_bytes = serde_json::to_vec(&payload).unwrap();
        let compressed = gzip_bytes(&payload_bytes).unwrap();
        // Header byte layout: version|header size, message type|flags,
        // serialization|compression, reserved.
        let mut frame = vec![
            (PROTOCOL_VERSION << 4) | HEADER_SIZE_FOUR_BYTES,
            (MESSAGE_FULL_SERVER_RESPONSE << 4) | FLAG_WITH_SEQUENCE,
            (SERIALIZATION_JSON << 4) | COMPRESSION_GZIP,
            0,
        ];
        frame.extend_from_slice(&7_i32.to_be_bytes());
        frame.extend_from_slice(&(compressed.len() as u32).to_be_bytes());
        frame.extend_from_slice(&compressed);

        let parsed = parse_asr_response_frame(&frame).expect("asr response should parse");

        assert_eq!(parsed.sequence, Some(7));
        assert_eq!(parsed.payload["result"]["text"], "你好");
    }

    #[test]
    fn asr_full_request_frame_uses_expected_header() {
        let frame = build_asr_full_client_request(&default_asr_request_payload("user-1", None))
            .expect("frame should build");

        // 0x11: protocol version 1, four-byte header.
        assert_eq!(frame[0], 0x11);
        // 0x10: full client request, no flags.
        assert_eq!(frame[1], 0x10);
        // 0x11: JSON serialization, gzip compression.
        assert_eq!(frame[2], 0x11);
        assert!(frame.len() > 8);
    }

    #[test]
    fn tts_start_session_frame_contains_event_and_session_id() {
        let frame = build_tts_bidirection_frame(
            TtsEvent::StartSession,
            Some("session-1"),
            Some(&json!({ "req_params": { "speaker": "voice" } })),
        )
        .expect("tts frame should build");

        assert_eq!(frame[0], 0x11);
        // 0x14: full client request carrying an event number.
        assert_eq!(frame[1], 0x14);
        // 0x10: JSON serialization, no compression.
        assert_eq!(frame[2], 0x10);
        assert_eq!(
            i32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]),
            100
        );
        assert_eq!(
            u32::from_be_bytes([frame[8], frame[9], frame[10], frame[11]]),
            9
        );
        // The session id bytes follow immediately after their length prefix.
        assert_eq!(&frame[12..21], "session-1".as_bytes());
    }

    #[test]
    fn tts_response_frame_parses_json_event() {
        let payload = json!({ "status_code": 20000000, "message": "ok" });
        let payload_bytes = serde_json::to_vec(&payload).unwrap();
        let frame = build_tts_event_frame(
            MESSAGE_FULL_SERVER_RESPONSE,
            SERIALIZATION_JSON,
            COMPRESSION_NONE,
            TtsEvent::SessionFinished.to_i32(),
            Some("session-1"),
            &payload_bytes,
        )
        .expect("response frame should build");

        let parsed = parse_tts_response_frame(&frame).expect("tts response should parse");

        assert_eq!(parsed.event, Some(TtsEvent::SessionFinished));
        assert_eq!(parsed.session_id.as_deref(), Some("session-1"));
        assert_eq!(parsed.payload.unwrap()["status_code"], 20000000);
    }

    #[test]
    fn tts_sse_body_uses_snake_case_audio_params() {
        let body = build_tts_sse_body(
            TtsSseRequest {
                text: "你好".to_string(),
                speaker: "voice".to_string(),
                model: None,
                audio_params: Some(TtsAudioParams {
                    format: "mp3".to_string(),
                    sample_rate: 24_000,
                    bit_rate: Some(64_000),
                }),
                additions: None,
                ssml: None,
            },
            "user-1",
        )
        .expect("sse body should build");

        assert_eq!(body["user"]["uid"], "user-1");
        assert_eq!(body["req_params"]["audio_params"]["sample_rate"], 24_000);
        assert_eq!(body["req_params"]["audio_params"]["bit_rate"], 64_000);
    }

    #[test]
    fn read_payload_bytes_returns_payload_and_advances_offset() {
        let frame = [0, 0, 0, 2, 0xAB, 0xCD, 0xEF];
        let mut offset = 0;

        let payload = read_payload_bytes(&frame, &mut offset).expect("payload should parse");

        assert_eq!(payload, vec![0xAB, 0xCD]);
        assert_eq!(offset, 6);
    }

    #[test]
    fn read_payload_bytes_rejects_length_past_end_of_frame() {
        let frame = [0, 0, 0, 5, 1, 2];
        let mut offset = 0;

        let result = read_payload_bytes(&frame, &mut offset);

        assert!(matches!(result, Err(SpeechError::InvalidFrame(_))));
    }

    #[test]
    fn optional_string_reads_trimmed_text_and_advances_offset() {
        let frame = [0, 0, 0, 4, b' ', b'h', b'i', b' ', 0xFF];
        let mut offset = 0;

        let text = read_optional_length_prefixed_string(&frame, &mut offset)
            .expect("optional string should parse");

        assert_eq!(text.as_deref(), Some("hi"));
        assert_eq!(offset, 8);
    }

    #[test]
    fn optional_string_leaves_offset_untouched_when_truncated() {
        let frame = [0, 0, 0, 9, b'a'];
        let mut offset = 0;

        let text = read_optional_length_prefixed_string(&frame, &mut offset)
            .expect("truncated optional string is not an error");

        assert_eq!(text, None);
        assert_eq!(offset, 0);
    }

    #[test]
    fn merge_json_object_merges_nested_objects_and_replaces_leaves() {
        let mut target = json!({ "a": { "x": 1, "y": 2 }, "b": 1 });

        merge_json_object(&mut target, json!({ "a": { "y": 3, "z": 4 }, "c": [1] }));

        assert_eq!(
            target,
            json!({ "a": { "x": 1, "y": 3, "z": 4 }, "b": 1, "c": [1] })
        );
    }

    #[test]
    fn string_normalizers_trim_and_fall_back() {
        assert_eq!(default_if_empty("  x ".to_string(), "d"), "x");
        assert_eq!(default_if_empty("   ".to_string(), "d"), "d");
        assert_eq!(
            normalize_optional_secret(Some(" key ".to_string())).as_deref(),
            Some("key")
        );
        assert_eq!(normalize_optional_secret(Some("  ".to_string())), None);
        assert_eq!(normalize_optional_secret(None), None);
    }
}
|