switch api server sse to axum builtin
This commit is contained in:
@@ -2,7 +2,10 @@ use axum::{
|
||||
Json,
|
||||
extract::{Extension, Path, State, rejection::JsonRejection},
|
||||
http::StatusCode,
|
||||
response::Response,
|
||||
response::{
|
||||
IntoResponse, Response,
|
||||
sse::{Event, Sse},
|
||||
},
|
||||
};
|
||||
use module_custom_world::{
|
||||
CustomWorldThemeMode, empty_agent_anchor_content_json, empty_agent_asset_coverage_json,
|
||||
@@ -34,10 +37,11 @@ use spacetime_client::{
|
||||
CustomWorldResultPreviewBlockerRecord, CustomWorldSupportedActionRecord,
|
||||
CustomWorldWorkSummaryRecord, SpacetimeClientError,
|
||||
};
|
||||
use std::convert::Infallible;
|
||||
|
||||
use crate::{
|
||||
api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
|
||||
request_context::RequestContext, sse::SseEventBuffer, state::AppState,
|
||||
request_context::RequestContext, state::AppState,
|
||||
};
|
||||
|
||||
pub async fn get_custom_world_library(
|
||||
@@ -576,17 +580,23 @@ pub async fn stream_custom_world_agent_message(
|
||||
let session_response = map_custom_world_agent_session_response(session);
|
||||
let reply_text = resolve_stream_reply_text(&session_response);
|
||||
|
||||
// 这里先用“一次性构造完整 SSE 文本”的最小兼容方案,
|
||||
// 复用 Stage 7 的同步 deterministic 写表逻辑,保证前端当前的 reader 协议可直接消费。
|
||||
let mut sse = SseEventBuffer::new();
|
||||
sse.push_json("reply_delta", &json!({ "text": reply_text }))
|
||||
.map_err(|error| custom_world_error_response(&request_context, error))?;
|
||||
sse.push_json("session", &json!({ "session": session_response }))
|
||||
.map_err(|error| custom_world_error_response(&request_context, error))?;
|
||||
sse.push_json("done", &json!({ "ok": true }))
|
||||
.map_err(|error| custom_world_error_response(&request_context, error))?;
|
||||
// 这里仍保持“一次性返回完整事件序列”的兼容语义;
|
||||
// SSE 编码、标准响应头与 body frame 交给 Axum 内建实现维护。
|
||||
let events = vec![
|
||||
custom_world_sse_json_event("reply_delta", json!({ "text": reply_text }))
|
||||
.map_err(|error| custom_world_error_response(&request_context, error))?,
|
||||
custom_world_sse_json_event("session", json!({ "session": session_response }))
|
||||
.map_err(|error| custom_world_error_response(&request_context, error))?,
|
||||
custom_world_sse_json_event("done", json!({ "ok": true }))
|
||||
.map_err(|error| custom_world_error_response(&request_context, error))?,
|
||||
];
|
||||
let stream = tokio_stream::iter(
|
||||
events
|
||||
.into_iter()
|
||||
.map(|event| Ok::<Event, Infallible>(event)),
|
||||
);
|
||||
|
||||
Ok(sse.into_response())
|
||||
Ok(Sse::new(stream).into_response())
|
||||
}
|
||||
|
||||
pub async fn get_custom_world_agent_operation(
|
||||
@@ -983,6 +993,18 @@ fn custom_world_error_response(request_context: &RequestContext, error: AppError
|
||||
error.into_response_with_context(Some(request_context))
|
||||
}
|
||||
|
||||
fn custom_world_sse_json_event(event_name: &str, payload: Value) -> Result<Event, AppError> {
|
||||
Event::default()
|
||||
.event(event_name)
|
||||
.json_data(payload)
|
||||
.map_err(|error| {
|
||||
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({
|
||||
"provider": "sse",
|
||||
"message": format!("SSE payload 序列化失败:{error}"),
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
fn resolve_author_display_name(_authenticated: &AuthenticatedAccessToken) -> String {
|
||||
"玩家".to_string()
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ mod runtime_save;
|
||||
mod runtime_settings;
|
||||
mod runtime_story;
|
||||
mod session_client;
|
||||
mod sse;
|
||||
mod state;
|
||||
mod story_battles;
|
||||
mod story_sessions;
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{HeaderName, StatusCode, header},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use std::convert::Infallible;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
|
||||
use crate::http_error::AppError;
|
||||
|
||||
/// 最小缓冲式 SSE builder,适用于“先完成业务,再一次性返回完整 SSE 文本”的兼容链路。
|
||||
#[derive(Default)]
|
||||
pub struct SseEventBuffer {
|
||||
body: String,
|
||||
}
|
||||
|
||||
impl SseEventBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn push_json<T>(&mut self, event: &str, payload: &T) -> Result<(), AppError>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
encode_sse_event(&mut self.body, event, payload)
|
||||
}
|
||||
|
||||
pub fn into_response(self) -> Response {
|
||||
build_sse_response(self.body)
|
||||
}
|
||||
}
|
||||
|
||||
/// 实时 SSE writer,适用于“先返回响应,再逐步推送事件”的真流式链路。
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct SseStreamWriter {
|
||||
sender: mpsc::UnboundedSender<Result<Bytes, Infallible>>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl SseStreamWriter {
|
||||
pub fn push_json<T>(&self, event: &str, payload: &T) -> Result<(), AppError>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let mut body = String::new();
|
||||
encode_sse_event(&mut body, event, payload)?;
|
||||
self.sender.send(Ok(Bytes::from(body))).map_err(|_| {
|
||||
AppError::from_status(StatusCode::GONE).with_details(json!({
|
||||
"provider": "sse",
|
||||
"message": "实时 SSE 通道已关闭,无法继续写入事件",
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建一条实时 SSE 响应和对应 writer。
|
||||
///
|
||||
/// 典型用法:
|
||||
/// 1. handler 先调用本函数拿到 `(writer, response)`
|
||||
/// 2. 立即把 `response` 返回给客户端
|
||||
/// 3. 在后台任务里持续调用 `writer.push_json(...)`
|
||||
/// 4. 所有 writer 被 drop 后,SSE 流自动结束
|
||||
#[allow(dead_code)]
|
||||
pub fn new_sse_stream() -> (SseStreamWriter, Response) {
|
||||
let (sender, receiver) = mpsc::unbounded_channel::<Result<Bytes, Infallible>>();
|
||||
let body = Body::from_stream(UnboundedReceiverStream::new(receiver));
|
||||
let response = build_sse_body_response(body);
|
||||
|
||||
(SseStreamWriter { sender }, response)
|
||||
}
|
||||
|
||||
pub fn encode_sse_event<T>(body: &mut String, event: &str, payload: &T) -> Result<(), AppError>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let payload_text = serde_json::to_string(payload).map_err(|error| {
|
||||
AppError::from_status(StatusCode::INTERNAL_SERVER_ERROR).with_details(json!({
|
||||
"provider": "sse",
|
||||
"message": format!("SSE payload 序列化失败:{error}"),
|
||||
}))
|
||||
})?;
|
||||
|
||||
body.push_str("event: ");
|
||||
body.push_str(event);
|
||||
body.push('\n');
|
||||
body.push_str("data: ");
|
||||
body.push_str(&payload_text);
|
||||
body.push_str("\n\n");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn build_sse_response(body: String) -> Response {
|
||||
build_sse_body_response(body)
|
||||
}
|
||||
|
||||
fn build_sse_body_response(body: impl IntoResponse) -> Response {
|
||||
let mut response = body.into_response();
|
||||
let headers = response.headers_mut();
|
||||
headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
"text/event-stream; charset=utf-8"
|
||||
.parse()
|
||||
.expect("valid sse content-type"),
|
||||
);
|
||||
headers.insert(
|
||||
header::CACHE_CONTROL,
|
||||
"no-cache".parse().expect("valid cache-control"),
|
||||
);
|
||||
// 反向代理场景下显式关闭缓冲,避免 SSE 事件被聚合后才下发。
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-accel-buffering"),
|
||||
"no".parse().expect("valid x-accel-buffering header"),
|
||||
);
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{SseEventBuffer, build_sse_response, encode_sse_event, new_sse_stream};
|
||||
use axum::body::to_bytes;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn encode_sse_event_writes_standard_format() {
|
||||
let mut body = String::new();
|
||||
encode_sse_event(&mut body, "reply_delta", &json!({ "text": "hello" }))
|
||||
.expect("encoding should succeed");
|
||||
|
||||
assert_eq!(body, "event: reply_delta\ndata: {\"text\":\"hello\"}\n\n");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn build_sse_response_sets_standard_headers() {
|
||||
let response = build_sse_response("event: done\ndata: {\"ok\":true}\n\n".to_string());
|
||||
|
||||
assert_eq!(
|
||||
response
|
||||
.headers()
|
||||
.get(header::CONTENT_TYPE)
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("text/event-stream; charset=utf-8")
|
||||
);
|
||||
assert_eq!(
|
||||
response
|
||||
.headers()
|
||||
.get(header::CACHE_CONTROL)
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("no-cache")
|
||||
);
|
||||
assert_eq!(
|
||||
response
|
||||
.headers()
|
||||
.get(HeaderName::from_static("x-accel-buffering"))
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("no")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sse_event_buffer_collects_events_and_returns_response() {
|
||||
let mut buffer = SseEventBuffer::new();
|
||||
buffer
|
||||
.push_json("reply_delta", &json!({ "text": "hello" }))
|
||||
.expect("first event should encode");
|
||||
buffer
|
||||
.push_json("done", &json!({ "ok": true }))
|
||||
.expect("second event should encode");
|
||||
|
||||
let response = buffer.into_response();
|
||||
let body = to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("response body should read");
|
||||
let text = String::from_utf8(body.to_vec()).expect("body should be utf8");
|
||||
|
||||
assert_eq!(
|
||||
text,
|
||||
"event: reply_delta\ndata: {\"text\":\"hello\"}\n\nevent: done\ndata: {\"ok\":true}\n\n"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sse_stream_writer_writes_events_into_live_response_body() {
|
||||
let (writer, response) = new_sse_stream();
|
||||
|
||||
writer
|
||||
.push_json("reply_delta", &json!({ "text": "hello" }))
|
||||
.expect("first live event should encode");
|
||||
writer
|
||||
.push_json("done", &json!({ "ok": true }))
|
||||
.expect("second live event should encode");
|
||||
drop(writer);
|
||||
|
||||
assert_eq!(
|
||||
response
|
||||
.headers()
|
||||
.get(header::CONTENT_TYPE)
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("text/event-stream; charset=utf-8")
|
||||
);
|
||||
|
||||
let body = to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("live response body should read");
|
||||
let text = String::from_utf8(body.to_vec()).expect("live body should be utf8");
|
||||
|
||||
assert_eq!(
|
||||
text,
|
||||
"event: reply_delta\ndata: {\"text\":\"hello\"}\n\nevent: done\ndata: {\"ok\":true}\n\n"
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user