443 lines
15 KiB
Rust
443 lines
15 KiB
Rust
use axum::{
|
|
Json,
|
|
extract::{Extension, State},
|
|
http::StatusCode,
|
|
response::{
|
|
IntoResponse, Response,
|
|
sse::{Event, Sse},
|
|
},
|
|
};
|
|
use platform_llm::{LlmMessage, LlmMessageRole, LlmTextProtocol, LlmTextRequest};
|
|
use serde_json::{Value, json};
|
|
use shared_contracts::llm::{
|
|
LlmChatCompletionRequest, LlmChatCompletionResponse, LlmChatMessagePayload, LlmChatMessageRole,
|
|
};
|
|
use std::convert::Infallible;
|
|
|
|
use crate::{
|
|
api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
|
|
platform_errors::map_llm_error, request_context::RequestContext, state::AppState,
|
|
};
|
|
|
|
pub async fn proxy_llm_chat_completions(
|
|
State(state): State<AppState>,
|
|
Extension(request_context): Extension<RequestContext>,
|
|
Extension(_authenticated): Extension<AuthenticatedAccessToken>,
|
|
Json(payload): Json<LlmChatCompletionRequest>,
|
|
) -> Result<Response, Response> {
|
|
let llm_client = state.llm_client().ok_or_else(|| {
|
|
llm_error_response(
|
|
&request_context,
|
|
AppError::from_status(StatusCode::SERVICE_UNAVAILABLE)
|
|
.with_message("服务端尚未配置可用的 LLM API Key"),
|
|
)
|
|
})?;
|
|
|
|
let request = LlmTextRequest {
|
|
model: payload.model,
|
|
protocol: LlmTextProtocol::ChatCompletions,
|
|
messages: payload
|
|
.messages
|
|
.into_iter()
|
|
.map(map_chat_message)
|
|
.collect::<Vec<_>>(),
|
|
max_tokens: None,
|
|
enable_web_search: false,
|
|
};
|
|
|
|
if payload.stream {
|
|
return Ok(stream_llm_chat_completions(llm_client.clone(), request).into_response());
|
|
}
|
|
|
|
let response = llm_client
|
|
.request_text(request)
|
|
.await
|
|
.map_err(|error| llm_error_response(&request_context, map_llm_error(error)))?;
|
|
|
|
Ok(json_success_body(
|
|
Some(&request_context),
|
|
LlmChatCompletionResponse {
|
|
id: response.response_id,
|
|
model: response.model,
|
|
content: response.content,
|
|
finish_reason: response.finish_reason,
|
|
},
|
|
)
|
|
.into_response())
|
|
}
|
|
|
|
fn stream_llm_chat_completions(
|
|
llm_client: platform_llm::LlmClient,
|
|
request: LlmTextRequest,
|
|
) -> Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>>> {
|
|
let stream = async_stream::stream! {
|
|
let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<Value>();
|
|
let llm_stream = llm_client.stream_text(request, move |delta| {
|
|
let _ = delta_tx.send(json!({
|
|
"delta": delta.delta_text,
|
|
"content": delta.accumulated_text,
|
|
"finishReason": delta.finish_reason,
|
|
}));
|
|
});
|
|
tokio::pin!(llm_stream);
|
|
|
|
let llm_result = loop {
|
|
// `platform-llm` 负责上游 SSE 解析;这里尽快把增量转成 API 层 SSE 事件。
|
|
tokio::select! {
|
|
result = &mut llm_stream => break result,
|
|
maybe_delta = delta_rx.recv() => {
|
|
if let Some(delta) = maybe_delta {
|
|
yield Ok::<Event, Infallible>(llm_sse_json_event_or_error("delta", delta));
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
while let Some(delta) = delta_rx.recv().await {
|
|
yield Ok::<Event, Infallible>(llm_sse_json_event_or_error("delta", delta));
|
|
}
|
|
|
|
match llm_result {
|
|
Ok(response) => {
|
|
yield Ok::<Event, Infallible>(llm_sse_json_event_or_error(
|
|
"complete",
|
|
json!(LlmChatCompletionResponse {
|
|
id: response.response_id,
|
|
model: response.model,
|
|
content: response.content,
|
|
finish_reason: response.finish_reason,
|
|
}),
|
|
));
|
|
}
|
|
Err(error) => {
|
|
let app_error = map_llm_error(error);
|
|
yield Ok::<Event, Infallible>(llm_sse_json_event_or_error(
|
|
"error",
|
|
json!({
|
|
"code": app_error.code(),
|
|
"message": app_error.message(),
|
|
}),
|
|
));
|
|
}
|
|
}
|
|
|
|
yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
|
|
};
|
|
|
|
Sse::new(stream)
|
|
}
|
|
|
|
fn llm_sse_json_event_or_error(event_name: &str, payload: Value) -> Event {
|
|
match serde_json::to_string(&payload) {
|
|
Ok(payload_text) => Event::default().event(event_name).data(payload_text),
|
|
Err(_) => Event::default()
|
|
.event("error")
|
|
.data("{\"code\":\"INTERNAL_SERVER_ERROR\",\"message\":\"SSE payload 序列化失败\"}"),
|
|
}
|
|
}
|
|
|
|
fn map_chat_message(message: LlmChatMessagePayload) -> LlmMessage {
|
|
let role = match message.role {
|
|
LlmChatMessageRole::System => LlmMessageRole::System,
|
|
LlmChatMessageRole::User => LlmMessageRole::User,
|
|
LlmChatMessageRole::Assistant => LlmMessageRole::Assistant,
|
|
};
|
|
|
|
LlmMessage::new(role, message.content)
|
|
}
|
|
|
|
/// Turns `error` into an HTTP response via `AppError::into_response_with_context`,
/// supplying the current request's context.
fn llm_error_response(request_context: &RequestContext, error: AppError) -> Response {
    let context = Some(request_context);
    error.into_response_with_context(context)
}
|
|
|
|
// End-to-end tests that drive the router against a blocking mock LLM HTTP server.
#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        net::TcpListener,
        thread,
        time::Duration as StdDuration,
    };

    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use platform_auth::{
        AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus, sign_access_token,
    };
    use serde_json::{Value, json};
    use time::OffsetDateTime;
    use tower::ServiceExt;

    use crate::{app::build_router, config::AppConfig, state::AppState};

    /// One canned HTTP/1.1 response served by `spawn_mock_server`.
    struct MockResponse {
        /// Status text written after `HTTP/1.1 `, e.g. `"200 OK"`.
        status_line: &'static str,
        /// Value of the `Content-Type` header.
        content_type: &'static str,
        /// Raw response body; its byte length is sent as `Content-Length`.
        body: String,
        /// Extra `(name, value)` header pairs appended after the fixed headers.
        extra_headers: Vec<(&'static str, &'static str)>,
    }

    /// Non-streaming request: the upstream chat-completions JSON is mapped onto the API
    /// response fields and returned inside the `{ ok, data }` envelope.
    #[tokio::test]
    async fn llm_chat_completions_returns_non_stream_text_payload() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "application/json; charset=utf-8",
            body: r#"{"id":"resp_api_server_01","model":"ark-router-test","choices":[{"message":{"content":"代理成功"},"finish_reason":"stop"}]}"#.to_string(),
            extra_headers: Vec::new(),
        }]);
        let state = seed_authenticated_state(AppConfig {
            llm_base_url: server_url,
            llm_api_key: Some("test-key".to_string()),
            ..AppConfig::default()
        })
        .await;
        let token = issue_access_token(&state);
        let app = build_router(state);

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/llm/chat/completions")
                    .header("authorization", format!("Bearer {token}"))
                    .header("content-type", "application/json")
                    // Requests the `v1` envelope; the assertions below expect `{ ok, data }`.
                    .header("x-genarrative-response-envelope", "v1")
                    .body(Body::from(
                        json!({
                            "messages": [
                                { "role": "system", "content": "系统" },
                                { "role": "user", "content": "用户" }
                            ]
                        })
                        .to_string(),
                    ))
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::OK);

        let body = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let payload: Value =
            serde_json::from_slice(&body).expect("response body should be valid json");

        assert_eq!(payload["ok"], Value::Bool(true));
        assert_eq!(
            payload["data"]["id"],
            Value::String("resp_api_server_01".to_string())
        );
        assert_eq!(
            payload["data"]["model"],
            Value::String("ark-router-test".to_string())
        );
        assert_eq!(
            payload["data"]["content"],
            Value::String("代理成功".to_string())
        );
        assert_eq!(
            payload["data"]["finishReason"],
            Value::String("stop".to_string())
        );
    }

    /// Streaming request (`"stream": true`): upstream SSE chunks are re-emitted as `delta`
    /// events, followed by a `complete` event and a trailing `data: [DONE]`.
    #[tokio::test]
    async fn llm_chat_completions_streams_sse_payload() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "text/event-stream; charset=utf-8",
            body: concat!(
                "data: {\"choices\":[{\"delta\":{\"content\":\"你\"}}]}\n\n",
                "data: {\"choices\":[{\"delta\":{\"content\":\"好\"}}]}\n\n",
                "data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n",
                "data: [DONE]\n\n"
            )
            .to_string(),
            // The upstream SSE body carries no id; the `"id"` assertion below expects the
            // `complete` event's id to come from this header.
            extra_headers: vec![("x-request-id", "req_llm_stream_01")],
        }]);
        let state = seed_authenticated_state(AppConfig {
            llm_base_url: server_url,
            llm_api_key: Some("test-key".to_string()),
            ..AppConfig::default()
        })
        .await;
        let token = issue_access_token(&state);
        let app = build_router(state);

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/llm/chat/completions")
                    .header("authorization", format!("Bearer {token}"))
                    .header("content-type", "application/json")
                    .body(Body::from(
                        json!({
                            "stream": true,
                            "messages": [
                                { "role": "user", "content": "用户" }
                            ]
                        })
                        .to_string(),
                    ))
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get("content-type")
                .and_then(|value| value.to_str().ok()),
            Some("text/event-stream")
        );

        let body = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let body_text = String::from_utf8(body.to_vec()).expect("body should be utf8");

        // Substring checks: exact event framing and ordering are not asserted here.
        assert!(body_text.contains("event: delta"));
        assert!(body_text.contains(r#""delta":"你""#));
        assert!(body_text.contains(r#""content":"你好""#));
        assert!(body_text.contains("event: complete"));
        assert!(body_text.contains(r#""id":"req_llm_stream_01""#));
        assert!(body_text.contains(r#""finishReason":"stop""#));
        assert!(body_text.contains("data: [DONE]"));
    }

    /// Builds an `AppState` from `config` and seeds one phone/password test user.
    ///
    /// The seeded user's `id` is read and discarded.
    /// NOTE(review): `issue_access_token` hard-codes `user_00000001`, presumably the id
    /// assigned to this first seeded user — confirm.
    async fn seed_authenticated_state(config: AppConfig) -> AppState {
        let state = AppState::new(config).expect("state should build");
        state
            .seed_test_phone_user_with_password("13800138101", "secret123")
            .await
            .id;
        state
    }

    /// Signs an access token for `user_00000001` with the state's JWT config, used as the
    /// `Bearer` credential in the requests above.
    fn issue_access_token(state: &AppState) -> String {
        let claims = AccessTokenClaims::from_input(
            AccessTokenClaimsInput {
                user_id: "user_00000001".to_string(),
                session_id: "sess_llm_proxy".to_string(),
                provider: AuthProvider::Password,
                roles: vec!["user".to_string()],
                token_version: 2,
                phone_verified: true,
                binding_status: BindingStatus::Active,
                display_name: Some("LLM 代理用户".to_string()),
            },
            state.auth_jwt_config(),
            OffsetDateTime::now_utc(),
        )
        .expect("claims should build");

        sign_access_token(&claims, state.auth_jwt_config()).expect("token should sign")
    }

    /// Starts a blocking mock HTTP server on an ephemeral `127.0.0.1` port.
    ///
    /// A background thread accepts one connection per entry in `responses`, in order,
    /// reads the request and writes the matching response; the thread ends after the last
    /// entry. Returns the base URL `http://127.0.0.1:<port>`.
    fn spawn_mock_server(responses: Vec<MockResponse>) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");

        thread::spawn(move || {
            for response in responses {
                let (mut stream, _) = listener.accept().expect("request should connect");
                read_request(&mut stream);
                write_response(&mut stream, response);
                // `stream` is dropped here, closing the connection.
            }
        });

        format!("http://{address}")
    }

    /// Consumes one HTTP request from `stream`, discarding its contents.
    ///
    /// Stops once the headers plus `Content-Length` body bytes have arrived, when the peer
    /// closes the connection, or when a read times out after 1 second. Panics on any other
    /// I/O error.
    fn read_request(stream: &mut std::net::TcpStream) {
        stream
            .set_read_timeout(Some(StdDuration::from_secs(1)))
            .expect("read timeout should be set");
        let mut buffer = Vec::new();
        let mut chunk = [0_u8; 1024];
        // Total request size (headers + body), known once the header terminator is seen.
        let mut expected_total = None;

        loop {
            match stream.read(&mut chunk) {
                Ok(0) => break,
                Ok(bytes_read) => {
                    buffer.extend_from_slice(&chunk[..bytes_read]);

                    // A missing or unparsable Content-Length is treated as an empty body.
                    if expected_total.is_none()
                        && let Some(header_end) = find_header_end(&buffer)
                    {
                        let content_length =
                            read_content_length(&buffer[..header_end]).unwrap_or(0);
                        expected_total = Some(header_end + content_length);
                    }

                    if let Some(total_bytes) = expected_total
                        && buffer.len() >= total_bytes
                    {
                        break;
                    }
                }
                // The error kind reported for a read timeout is platform-dependent
                // (typically `WouldBlock` on Unix, `TimedOut` on Windows).
                Err(error)
                    if error.kind() == std::io::ErrorKind::WouldBlock
                        || error.kind() == std::io::ErrorKind::TimedOut =>
                {
                    break;
                }
                Err(error) => panic!("mock server failed to read request: {error}"),
            }
        }
    }

    /// Writes `response` to `stream` as a raw HTTP/1.1 response with `Connection: close`
    /// and an explicit `Content-Length`.
    fn write_response(stream: &mut std::net::TcpStream, response: MockResponse) {
        let body = response.body;
        // `String::len` is the UTF-8 byte length, which is what Content-Length requires.
        let mut raw_response = format!(
            "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n",
            response.status_line,
            response.content_type,
            body.len()
        );
        for (name, value) in response.extra_headers {
            raw_response.push_str(format!("{name}: {value}\r\n").as_str());
        }
        raw_response.push_str("\r\n");
        raw_response.push_str(body.as_str());

        stream
            .write_all(raw_response.as_bytes())
            .expect("mock response should be written");
        stream.flush().expect("mock response should flush");
    }

    /// Returns the index just past the first `\r\n\r\n` (where the body starts), or `None`
    /// if the header section is not complete yet.
    fn find_header_end(buffer: &[u8]) -> Option<usize> {
        buffer
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|index| index + 4)
    }

    /// Returns the value of the first `Content-Length` header (name matched
    /// case-insensitively), or `None` if it is absent or not a valid `usize`.
    fn read_content_length(headers: &[u8]) -> Option<usize> {
        let text = String::from_utf8_lossy(headers);
        text.lines().find_map(|line| {
            let (name, value) = line.split_once(':')?;
            if name.eq_ignore_ascii_case("content-length") {
                return value.trim().parse::<usize>().ok();
            }
            None
        })
    }
}
|