Files
Genarrative/server-rs/crates/api-server/src/sse.rs

131 lines
3.8 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
http::{HeaderName, StatusCode, header},
response::{IntoResponse, Response},
};
use serde::Serialize;
use serde_json::json;
use crate::http_error::AppError;
/// Minimal buffered SSE builder for compatibility paths that finish all business logic first
/// and then return the complete SSE text in a single response.
#[derive(Debug, Default)]
pub struct SseEventBuffer {
    /// Accumulated, already-encoded SSE frames.
    body: String,
}
impl SseEventBuffer {
pub fn new() -> Self {
Self::default()
}
pub fn push_json<T>(&mut self, event: &str, payload: &T) -> Result<(), AppError>
where
T: Serialize,
{
encode_sse_event(&mut self.body, event, payload)
}
pub fn into_response(self) -> Response {
build_sse_response(self.body)
}
}
/// Appends one Server-Sent Events frame to `body`.
///
/// The frame has the form `event: <event>\ndata: <json>\n\n`, where `<json>` is the compact
/// `serde_json` serialization of `payload`. If the serialized text contains raw line breaks
/// (valid JSON only allows them as whitespace between tokens, e.g. from a pre-formatted raw
/// value), each line is written as its own `data:` field. SSE clients rejoin consecutive
/// `data:` lines with `\n`, so the receiver still gets valid JSON and the frame cannot be
/// terminated early.
///
/// # Errors
///
/// Returns an `INTERNAL_SERVER_ERROR` [`AppError`] when `event` contains `\r` or `\n`
/// (which would corrupt the stream framing) or when `payload` fails to serialize.
pub fn encode_sse_event<T>(body: &mut String, event: &str, payload: &T) -> Result<(), AppError>
where
    T: Serialize,
{
    // A line break in the event name would end the `event:` field early and let the rest be
    // parsed as arbitrary SSE fields, so reject it instead of emitting a broken frame.
    if event.contains(['\r', '\n']) {
        return Err(
            AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({
                "provider": "sse",
                "message": "SSE 事件名不能包含换行符",
            })),
        );
    }
    let payload_text = serde_json::to_string(payload).map_err(|error| {
        AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({
            "provider": "sse",
            "message": format!("SSE payload 序列化失败:{error}"),
        }))
    })?;
    body.push_str("event: ");
    body.push_str(event);
    body.push('\n');
    // Compact JSON normally has no raw line breaks, so this usually writes one `data:` line.
    for line in payload_text.split(['\r', '\n']) {
        body.push_str("data: ");
        body.push_str(line);
        body.push('\n');
    }
    // A blank line terminates the event.
    body.push('\n');
    Ok(())
}
/// Wraps an already-encoded SSE `body` in an HTTP response with event-stream headers.
///
/// Sets `Content-Type: text/event-stream; charset=utf-8`, `Cache-Control: no-cache`,
/// and `X-Accel-Buffering: no`.
pub fn build_sse_response(body: String) -> Response {
    let content_type = (header::CONTENT_TYPE, "text/event-stream; charset=utf-8");
    let cache_control = (header::CACHE_CONTROL, "no-cache");
    // Explicitly disable buffering for reverse proxies that honour this header (e.g. nginx),
    // so events are not aggregated before being forwarded.
    let proxy_buffering = (HeaderName::from_static("x-accel-buffering"), "no");
    let headers = [content_type, cache_control, proxy_buffering];
    (headers, body).into_response()
}
#[cfg(test)]
mod tests {
    use super::{SseEventBuffer, build_sse_response, encode_sse_event};
    use axum::{
        body::to_bytes,
        http::{HeaderName, header},
    };
    use serde_json::json;

    // Synchronous: nothing here awaits, so no async runtime is needed.
    #[test]
    fn encode_sse_event_writes_standard_format() {
        let mut body = String::new();
        encode_sse_event(&mut body, "reply_delta", &json!({ "text": "hello" }))
            .expect("encoding should succeed");
        assert_eq!(body, "event: reply_delta\ndata: {\"text\":\"hello\"}\n\n");
    }

    // Synchronous: only inspects response headers.
    #[test]
    fn build_sse_response_sets_standard_headers() {
        let response = build_sse_response("event: done\ndata: {\"ok\":true}\n\n".to_string());
        assert_eq!(
            response
                .headers()
                .get(header::CONTENT_TYPE)
                .and_then(|value| value.to_str().ok()),
            Some("text/event-stream; charset=utf-8")
        );
        assert_eq!(
            response
                .headers()
                .get(header::CACHE_CONTROL)
                .and_then(|value| value.to_str().ok()),
            Some("no-cache")
        );
        assert_eq!(
            response
                .headers()
                .get(HeaderName::from_static("x-accel-buffering"))
                .and_then(|value| value.to_str().ok()),
            Some("no")
        );
    }

    // Async because collecting the response body is a future.
    #[tokio::test]
    async fn sse_event_buffer_collects_events_and_returns_response() {
        let mut buffer = SseEventBuffer::new();
        buffer
            .push_json("reply_delta", &json!({ "text": "hello" }))
            .expect("first event should encode");
        buffer
            .push_json("done", &json!({ "ok": true }))
            .expect("second event should encode");
        let response = buffer.into_response();
        let body = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("response body should read");
        let text = String::from_utf8(body.to_vec()).expect("body should be utf8");
        assert_eq!(
            text,
            "event: reply_delta\ndata: {\"text\":\"hello\"}\n\nevent: done\ndata: {\"ok\":true}\n\n"
        );
    }
}