use axum::{ http::{HeaderName, StatusCode, header}, response::{IntoResponse, Response}, }; use serde::Serialize; use serde_json::json; use crate::http_error::AppError; /// 最小缓冲式 SSE builder,适用于“先完成业务,再一次性返回完整 SSE 文本”的兼容链路。 #[derive(Default)] pub struct SseEventBuffer { body: String, } impl SseEventBuffer { pub fn new() -> Self { Self::default() } pub fn push_json(&mut self, event: &str, payload: &T) -> Result<(), AppError> where T: Serialize, { encode_sse_event(&mut self.body, event, payload) } pub fn into_response(self) -> Response { build_sse_response(self.body) } } pub fn encode_sse_event(body: &mut String, event: &str, payload: &T) -> Result<(), AppError> where T: Serialize, { let payload_text = serde_json::to_string(payload).map_err(|error| { AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({ "provider": "sse", "message": format!("SSE payload 序列化失败:{error}"), })) })?; body.push_str("event: "); body.push_str(event); body.push('\n'); body.push_str("data: "); body.push_str(&payload_text); body.push_str("\n\n"); Ok(()) } pub fn build_sse_response(body: String) -> Response { ( [ (header::CONTENT_TYPE, "text/event-stream; charset=utf-8"), (header::CACHE_CONTROL, "no-cache"), // 反向代理场景下显式关闭缓冲,避免 SSE 事件被聚合后才下发。 (HeaderName::from_static("x-accel-buffering"), "no"), ], body, ) .into_response() } #[cfg(test)] mod tests { use super::{SseEventBuffer, build_sse_response, encode_sse_event}; use axum::body::to_bytes; use serde_json::json; #[tokio::test] async fn encode_sse_event_writes_standard_format() { let mut body = String::new(); encode_sse_event(&mut body, "reply_delta", &json!({ "text": "hello" })) .expect("encoding should succeed"); assert_eq!(body, "event: reply_delta\ndata: {\"text\":\"hello\"}\n\n"); } #[tokio::test] async fn build_sse_response_sets_standard_headers() { let response = build_sse_response("event: done\ndata: {\"ok\":true}\n\n".to_string()); assert_eq!( response .headers() .get(header::CONTENT_TYPE) .and_then(|value| value.to_str().ok()), Some("text/event-stream; charset=utf-8") ); assert_eq!( response .headers() .get(header::CACHE_CONTROL) .and_then(|value| value.to_str().ok()), Some("no-cache") ); assert_eq!( response .headers() .get(HeaderName::from_static("x-accel-buffering")) .and_then(|value| value.to_str().ok()), Some("no") ); } #[tokio::test] async fn sse_event_buffer_collects_events_and_returns_response() { let mut buffer = SseEventBuffer::new(); buffer .push_json("reply_delta", &json!({ "text": "hello" })) .expect("first event should encode"); buffer .push_json("done", &json!({ "ok": true })) .expect("second event should encode"); let response = buffer.into_response(); let body = to_bytes(response.into_body(), usize::MAX) .await .expect("response body should read"); let text = String::from_utf8(body.to_vec()).expect("body should be utf8"); assert_eq!( text, "event: reply_delta\ndata: {\"text\":\"hello\"}\n\nevent: done\ndata: {\"ok\":true}\n\n" ); } }