extract rust sse infrastructure

This commit is contained in:
2026-04-22 14:52:30 +08:00
parent 91fb8edee7
commit 28ba990123
4 changed files with 140 additions and 23 deletions

View File

@@ -6,6 +6,7 @@ license.workspace = true
[dependencies]
axum = "0.8"
bytes = "1"
dotenvy = "0.15"
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
module-ai = { path = "../module-ai" }
@@ -27,7 +28,8 @@ shared-contracts = { path = "../shared-contracts" }
shared-kernel = { path = "../shared-kernel" }
shared-logging = { path = "../shared-logging" }
spacetime-client = { path = "../spacetime-client" }
tokio = { version = "1", features = ["macros", "rt-multi-thread", "net"] }
tokio = { version = "1", features = ["macros", "rt-multi-thread", "net", "sync"] }
tokio-stream = "0.1"
time = { version = "0.3", features = ["formatting"] }
tower-http = { version = "0.6", features = ["trace"] }
tracing = "0.1"

View File

@@ -1,9 +1,14 @@
use axum::{
body::Body,
http::{HeaderName, StatusCode, header},
response::{IntoResponse, Response},
};
use bytes::Bytes;
use serde::Serialize;
use serde_json::json;
use std::convert::Infallible;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::http_error::AppError;
@@ -30,6 +35,46 @@ impl SseEventBuffer {
}
}
/// 实时 SSE writer适用于“先返回响应再逐步推送事件”的真流式链路。
#[derive(Clone)]
#[allow(dead_code)]
pub struct SseStreamWriter {
sender: mpsc::UnboundedSender<Result<Bytes, Infallible>>,
}
#[allow(dead_code)]
impl SseStreamWriter {
pub fn push_json<T>(&self, event: &str, payload: &T) -> Result<(), AppError>
where
T: Serialize,
{
let mut body = String::new();
encode_sse_event(&mut body, event, payload)?;
self.sender.send(Ok(Bytes::from(body))).map_err(|_| {
AppError::from_status(StatusCode::GONE).with_details(json!({
"provider": "sse",
"message": "实时 SSE 通道已关闭,无法继续写入事件",
}))
})
}
}
/// 创建一条实时 SSE 响应和对应 writer。
///
/// 典型用法:
/// 1. handler 先调用本函数拿到 `(writer, response)`
/// 2. 立即把 `response` 返回给客户端
/// 3. 在后台任务里持续调用 `writer.push_json(...)`
/// 4. 所有 writer 被 drop 后SSE 流自动结束
#[allow(dead_code)]
pub fn new_sse_stream() -> (SseStreamWriter, Response) {
let (sender, receiver) = mpsc::unbounded_channel::<Result<Bytes, Infallible>>();
let body = Body::from_stream(UnboundedReceiverStream::new(receiver));
let response = build_sse_body_response(body);
(SseStreamWriter { sender }, response)
}
pub fn encode_sse_event<T>(body: &mut String, event: &str, payload: &T) -> Result<(), AppError>
where
T: Serialize,
@@ -52,21 +97,34 @@ where
}
pub fn build_sse_response(body: String) -> Response {
(
[
(header::CONTENT_TYPE, "text/event-stream; charset=utf-8"),
(header::CACHE_CONTROL, "no-cache"),
// 反向代理场景下显式关闭缓冲,避免 SSE 事件被聚合后才下发。
(HeaderName::from_static("x-accel-buffering"), "no"),
],
body,
)
.into_response()
build_sse_body_response(body)
}
fn build_sse_body_response(body: impl IntoResponse) -> Response {
let mut response = body.into_response();
let headers = response.headers_mut();
headers.insert(
header::CONTENT_TYPE,
"text/event-stream; charset=utf-8"
.parse()
.expect("valid sse content-type"),
);
headers.insert(
header::CACHE_CONTROL,
"no-cache".parse().expect("valid cache-control"),
);
// 反向代理场景下显式关闭缓冲,避免 SSE 事件被聚合后才下发。
headers.insert(
HeaderName::from_static("x-accel-buffering"),
"no".parse().expect("valid x-accel-buffering header"),
);
response
}
#[cfg(test)]
mod tests {
use super::{SseEventBuffer, build_sse_response, encode_sse_event};
use super::{SseEventBuffer, build_sse_response, encode_sse_event, new_sse_stream};
use axum::body::to_bytes;
use serde_json::json;
@@ -127,4 +185,35 @@ mod tests {
"event: reply_delta\ndata: {\"text\":\"hello\"}\n\nevent: done\ndata: {\"ok\":true}\n\n"
);
}
#[tokio::test]
async fn sse_stream_writer_writes_events_into_live_response_body() {
let (writer, response) = new_sse_stream();
writer
.push_json("reply_delta", &json!({ "text": "hello" }))
.expect("first live event should encode");
writer
.push_json("done", &json!({ "ok": true }))
.expect("second live event should encode");
drop(writer);
assert_eq!(
response
.headers()
.get(header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok()),
Some("text/event-stream; charset=utf-8")
);
let body = to_bytes(response.into_body(), usize::MAX)
.await
.expect("live response body should read");
let text = String::from_utf8(body.to_vec()).expect("live body should be utf8");
assert_eq!(
text,
"event: reply_delta\ndata: {\"text\":\"hello\"}\n\nevent: done\ndata: {\"ok\":true}\n\n"
);
}
}