use axum::{
    Json,
    extract::{Extension, State},
    http::StatusCode,
    response::{
        IntoResponse, Response,
        sse::{Event, Sse},
    },
};
use platform_llm::{LlmMessage, LlmMessageRole, LlmTextProtocol, LlmTextRequest};
use serde_json::{Value, json};
use shared_contracts::llm::{
    LlmChatCompletionRequest, LlmChatCompletionResponse, LlmChatMessagePayload,
    LlmChatMessageRole,
};
use std::convert::Infallible;

use crate::{
    api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError,
    platform_errors::map_llm_error, request_context::RequestContext, state::AppState,
};

/// Axum handler that proxies a chat-completion request to the server-side LLM client.
///
/// Behavior:
/// - If `state.llm_client()` is `None`, responds with `503 SERVICE_UNAVAILABLE`.
///   The error message says the server has no usable LLM API key configured.
/// - The request is always sent with `LlmTextProtocol::ChatCompletions`, no
///   `max_tokens` limit, and web search disabled. Only the payload's `model` and
///   `messages` are forwarded.
/// - If `payload.stream` is true, returns an SSE response (see
///   [`stream_llm_chat_completions`]). Upstream failures are then reported inside
///   the stream, not as an HTTP error status.
/// - Otherwise awaits the full completion and returns it through
///   `json_success_body` as `{ id, model, content, finish_reason }`.
///
/// # Errors
/// Error responses are built with [`llm_error_response`]. Upstream errors in the
/// non-streaming path are first converted with `map_llm_error`.
pub async fn proxy_llm_chat_completions(
    State(state): State,
    Extension(request_context): Extension,
    // Not read. Extracting it means the handler only runs when this extension is
    // present on the request (presumably inserted by auth middleware; confirm
    // in the router setup).
    Extension(_authenticated): Extension,
    Json(payload): Json,
) -> Result {
    let llm_client = state.llm_client().ok_or_else(|| {
        llm_error_response(
            &request_context,
            // Message: "the server has not yet configured a usable LLM API Key".
            AppError::from_status(StatusCode::SERVICE_UNAVAILABLE)
                .with_message("服务端尚未配置可用的 LLM API Key"),
        )
    })?;

    let request = LlmTextRequest {
        model: payload.model,
        protocol: LlmTextProtocol::ChatCompletions,
        messages: payload
            .messages
            .into_iter()
            .map(map_chat_message)
            .collect::>(),
        max_tokens: None,
        enable_web_search: false,
    };

    if payload.stream {
        // The streaming path does not use `request_context`. Errors become SSE
        // `error` events instead of an error response.
        return Ok(stream_llm_chat_completions(llm_client.clone(), request).into_response());
    }

    let response = llm_client
        .request_text(request)
        .await
        .map_err(|error| llm_error_response(&request_context, map_llm_error(error)))?;

    Ok(json_success_body(
        Some(&request_context),
        LlmChatCompletionResponse {
            id: response.response_id,
            model: response.model,
            content: response.content,
            finish_reason: response.finish_reason,
        },
    )
    .into_response())
}

/// Builds an SSE response that relays a streaming LLM completion.
///
/// Event sequence written to the client:
/// 1. Zero or more `delta` events. Each payload is
///    `{ "delta": <new text>, "content": <accumulated text>, "finishReason": ... }`.
/// 2. Exactly one terminal event: either `complete`, whose payload is the
///    final `LlmChatCompletionResponse`, or `error`, with `{ "code", "message" }`
///    taken from `map_llm_error`.
/// 3. A final unnamed event whose data is `[DONE]`.
///
/// Every item is yielded as `Ok`, so the stream never produces an error item.
/// Failures are only reported through the `error` event.
fn stream_llm_chat_completions(
    llm_client: platform_llm::LlmClient,
    request: LlmTextRequest,
) -> Sse>> {
    let stream = async_stream::stream! {
        // The `stream_text` callback cannot `yield` itself, so it pushes each
        // delta into an unbounded channel that this generator reads from.
        let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::();
        let llm_stream = llm_client.stream_text(request, move |delta| {
            // Sending only fails once the receiver is gone. The result is
            // deliberately ignored.
            let _ = delta_tx.send(json!({
                "delta": delta.delta_text,
                "content": delta.accumulated_text,
                "finishReason": delta.finish_reason,
            }));
        });
        tokio::pin!(llm_stream);

        let llm_result = loop {
            // `platform-llm` parses the upstream SSE. Here we convert each delta
            // into an API-layer SSE event as soon as it arrives.
            // NOTE(review): if `delta_rx.recv()` ever yielded `None` while
            // `llm_stream` is still pending (i.e. `stream_text` dropped the
            // callback early), this arm would fire repeatedly and busy-loop.
            // A `Some(delta) = delta_rx.recv()` pattern would disable the arm
            // instead.
            tokio::select! {
                result = &mut llm_stream => break result,
                maybe_delta = delta_rx.recv() => {
                    if let Some(delta) = maybe_delta {
                        yield Ok::(llm_sse_json_event_or_error("delta", delta));
                    }
                }
            }
        };

        // Flush any deltas still buffered when the upstream future resolved.
        // NOTE(review): this loop only ends once `delta_tx` (owned by the
        // callback) is dropped. That relies on the `stream_text` future
        // releasing its callback on completion. `delta_rx.try_recv()` would
        // drain the buffer without that assumption.
        while let Some(delta) = delta_rx.recv().await {
            yield Ok::(llm_sse_json_event_or_error("delta", delta));
        }

        match llm_result {
            Ok(response) => {
                yield Ok::(llm_sse_json_event_or_error(
                    "complete",
                    json!(LlmChatCompletionResponse {
                        id: response.response_id,
                        model: response.model,
                        content: response.content,
                        finish_reason: response.finish_reason,
                    }),
                ));
            }
            Err(error) => {
                let app_error = map_llm_error(error);
                yield Ok::(llm_sse_json_event_or_error(
                    "error",
                    json!({
                        "code": app_error.code(),
                        "message": app_error.message(),
                    }),
                ));
            }
        }

        // Terminal sentinel, sent on both success and error paths.
        yield Ok::(Event::default().data("[DONE]"));
    };

    Sse::new(stream)
}

/// Serializes `payload` to JSON text and wraps it in an SSE event named `event_name`.
///
/// If serialization fails, returns an `error` event with a fixed
/// `INTERNAL_SERVER_ERROR` JSON body instead. That body's message means
/// "SSE payload serialization failed".
fn llm_sse_json_event_or_error(event_name: &str, payload: Value) -> Event {
    match serde_json::to_string(&payload) {
        Ok(payload_text) => Event::default().event(event_name).data(payload_text),
        Err(_) => Event::default()
            .event("error")
            .data("{\"code\":\"INTERNAL_SERVER_ERROR\",\"message\":\"SSE payload 序列化失败\"}"),
    }
}

/// Converts an API-contract chat message into a `platform_llm` message.
///
/// Roles map one-to-one: System, User, Assistant. Content is passed through unchanged.
fn map_chat_message(message: LlmChatMessagePayload) -> LlmMessage {
    let role = match message.role {
        LlmChatMessageRole::System => LlmMessageRole::System,
        LlmChatMessageRole::User => LlmMessageRole::User,
        LlmChatMessageRole::Assistant => LlmMessageRole::Assistant,
    };

    LlmMessage::new(role, message.content)
}

/// Renders `error` as an HTTP response that carries the given request context.
fn llm_error_response(request_context: &RequestContext, error: AppError) -> Response {
    error.into_response_with_context(Some(request_context))
}

#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        net::TcpListener,
        thread,
        time::Duration as StdDuration,
    };

    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use http_body_util::BodyExt;
    use platform_auth::{
        AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus, sign_access_token,
    };
    use serde_json::{Value, json};
    use time::OffsetDateTime;
    use tower::ServiceExt;

    use crate::{app::build_router, config::AppConfig, state::AppState};

    /// A canned HTTP/1.1 response served by [`spawn_mock_server`].
    struct MockResponse {
        /// Status code and reason phrase, e.g. `"200 OK"`.
        status_line: &'static str,
        content_type: &'static str,
        body: String,
        /// Additional `name: value` headers written after the standard ones.
        extra_headers: Vec<(&'static str, &'static str)>,
    }

    /// A non-streaming request returns the upstream completion inside the
    /// `{ ok, data }` envelope.
    ///
    /// The request sets `x-genarrative-response-envelope: v1`, presumably to
    /// select that envelope.
    #[tokio::test]
    async fn llm_chat_completions_returns_non_stream_text_payload() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "application/json; charset=utf-8",
            body: r#"{"id":"resp_api_server_01","model":"ark-router-test","choices":[{"message":{"content":"代理成功"},"finish_reason":"stop"}]}"#.to_string(),
            extra_headers: Vec::new(),
        }]);
        let state = seed_authenticated_state(AppConfig {
            llm_base_url: server_url,
            llm_api_key: Some("test-key".to_string()),
            ..AppConfig::default()
        })
        .await;
        let token = issue_access_token(&state);
        let app = build_router(state);

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/llm/chat/completions")
                    .header("authorization", format!("Bearer {token}"))
                    .header("content-type", "application/json")
                    .header("x-genarrative-response-envelope", "v1")
                    .body(Body::from(
                        json!({
                            "messages": [
                                { "role": "system", "content": "系统" },
                                { "role": "user", "content": "用户" }
                            ]
                        })
                        .to_string(),
                    ))
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::OK);
        let body = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let payload: Value =
            serde_json::from_slice(&body).expect("response body should be valid json");
        assert_eq!(payload["ok"], Value::Bool(true));
        assert_eq!(
            payload["data"]["id"],
            Value::String("resp_api_server_01".to_string())
        );
        assert_eq!(
            payload["data"]["model"],
            Value::String("ark-router-test".to_string())
        );
        assert_eq!(
            payload["data"]["content"],
            Value::String("代理成功".to_string())
        );
        assert_eq!(
            payload["data"]["finishReason"],
            Value::String("stop".to_string())
        );
    }

    /// A `"stream": true` request is answered as `text/event-stream`.
    ///
    /// The test expects `delta` events (per-chunk and accumulated text), a
    /// `complete` event, and the `[DONE]` sentinel. It also expects the upstream
    /// `x-request-id` header value to appear as the completion `id`.
    #[tokio::test]
    async fn llm_chat_completions_streams_sse_payload() {
        let server_url = spawn_mock_server(vec![MockResponse {
            status_line: "200 OK",
            content_type: "text/event-stream; charset=utf-8",
            body: concat!(
                "data: {\"choices\":[{\"delta\":{\"content\":\"你\"}}]}\n\n",
                "data: {\"choices\":[{\"delta\":{\"content\":\"好\"}}]}\n\n",
                "data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n\n",
                "data: [DONE]\n\n"
            )
            .to_string(),
            extra_headers: vec![("x-request-id", "req_llm_stream_01")],
        }]);
        let state = seed_authenticated_state(AppConfig {
            llm_base_url: server_url,
            llm_api_key: Some("test-key".to_string()),
            ..AppConfig::default()
        })
        .await;
        let token = issue_access_token(&state);
        let app = build_router(state);

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/api/llm/chat/completions")
                    .header("authorization", format!("Bearer {token}"))
                    .header("content-type", "application/json")
                    .body(Body::from(
                        json!({
                            "stream": true,
                            "messages": [
                                { "role": "user", "content": "用户" }
                            ]
                        })
                        .to_string(),
                    ))
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response
                .headers()
                .get("content-type")
                .and_then(|value| value.to_str().ok()),
            Some("text/event-stream")
        );
        let body = response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        let body_text = String::from_utf8(body.to_vec()).expect("body should be utf8");
        assert!(body_text.contains("event: delta"));
        assert!(body_text.contains(r#""delta":"你""#));
        assert!(body_text.contains(r#""content":"你好""#));
        assert!(body_text.contains("event: complete"));
        assert!(body_text.contains(r#""id":"req_llm_stream_01""#));
        assert!(body_text.contains(r#""finishReason":"stop""#));
        assert!(body_text.contains("data: [DONE]"));
    }

    /// Builds an `AppState` from `config` and seeds one phone/password test user.
    async fn seed_authenticated_state(config: AppConfig) -> AppState {
        let state = AppState::new(config).expect("state should build");
        // NOTE(review): the seeded user's `id` is evaluated and discarded.
        // `issue_access_token` hard-codes `user_00000001`, presumably the id
        // assigned to this first seeded user. Confirm.
        state
            .seed_test_phone_user_with_password("13800138101", "secret123")
            .await
            .id;
        state
    }

    /// Signs an access token with fixed test claims, using the state's JWT configuration.
    fn issue_access_token(state: &AppState) -> String {
        let claims = AccessTokenClaims::from_input(
            AccessTokenClaimsInput {
                user_id: "user_00000001".to_string(),
                session_id: "sess_llm_proxy".to_string(),
                provider: AuthProvider::Password,
                roles: vec!["user".to_string()],
                token_version: 2,
                phone_verified: true,
                binding_status: BindingStatus::Active,
                display_name: Some("LLM 代理用户".to_string()),
            },
            state.auth_jwt_config(),
            OffsetDateTime::now_utc(),
        )
        .expect("claims should build");

        sign_access_token(&claims, state.auth_jwt_config()).expect("token should sign")
    }

    /// Starts a blocking mock HTTP server on an ephemeral localhost port.
    ///
    /// A background thread accepts one connection per entry in `responses`, in
    /// order. For each connection it reads the request and writes the canned
    /// response. Returns the server's base URL, `http://127.0.0.1:<port>`.
    fn spawn_mock_server(responses: Vec) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind");
        let address = listener.local_addr().expect("listener should have addr");
        thread::spawn(move || {
            for response in responses {
                let (mut stream, _) = listener.accept().expect("request should connect");
                read_request(&mut stream);
                write_response(&mut stream, response);
            }
        });
        format!("http://{address}")
    }

    /// Reads one HTTP request from `stream` and discards it.
    ///
    /// Stops when any of these happens:
    /// - the headers plus `Content-Length` body bytes have arrived (0 if the
    ///   header is absent or unparsable);
    /// - the peer closes the connection;
    /// - a read times out (1 second).
    ///
    /// Any other I/O error panics.
    fn read_request(stream: &mut std::net::TcpStream) {
        stream
            .set_read_timeout(Some(StdDuration::from_secs(1)))
            .expect("read timeout should be set");
        let mut buffer = Vec::new();
        let mut chunk = [0_u8; 1024];
        let mut expected_total = None;
        loop {
            match stream.read(&mut chunk) {
                Ok(0) => break,
                Ok(bytes_read) => {
                    buffer.extend_from_slice(&chunk[..bytes_read]);
                    // Compute the total request size once, as soon as the
                    // header terminator has been seen.
                    if expected_total.is_none()
                        && let Some(header_end) = find_header_end(&buffer)
                    {
                        let content_length =
                            read_content_length(&buffer[..header_end]).unwrap_or(0);
                        expected_total = Some(header_end + content_length);
                    }
                    if let Some(total_bytes) = expected_total
                        && buffer.len() >= total_bytes
                    {
                        break;
                    }
                }
                // A timed-out read can surface as either kind, depending on platform.
                Err(error)
                    if error.kind() == std::io::ErrorKind::WouldBlock
                        || error.kind() == std::io::ErrorKind::TimedOut =>
                {
                    break;
                }
                Err(error) => panic!("mock server failed to read request: {error}"),
            }
        }
    }

    /// Writes `response` as a complete HTTP/1.1 message.
    ///
    /// The message includes `Content-Type`, `Content-Length`, `Connection: close`,
    /// and any extra headers.
    fn write_response(stream: &mut std::net::TcpStream, response: MockResponse) {
        let body = response.body;
        let mut raw_response = format!(
            "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n",
            response.status_line,
            response.content_type,
            body.len()
        );
        for (name, value) in response.extra_headers {
            raw_response.push_str(format!("{name}: {value}\r\n").as_str());
        }
        raw_response.push_str("\r\n");
        raw_response.push_str(body.as_str());
        stream
            .write_all(raw_response.as_bytes())
            .expect("mock response should be written");
        stream.flush().expect("mock response should flush");
    }

    /// Returns the byte offset just past the first `\r\n\r\n` in `buffer`, if present.
    fn find_header_end(buffer: &[u8]) -> Option {
        buffer
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|index| index + 4)
    }

    /// Finds a `Content-Length` header (case-insensitive name) in raw header
    /// bytes and parses its trimmed value.
    ///
    /// Returns `None` if the header is missing or its value does not parse.
    fn read_content_length(headers: &[u8]) -> Option {
        let text = String::from_utf8_lossy(headers);
        text.lines().find_map(|line| {
            let (name, value) = line.split_once(':')?;
            if name.eq_ignore_ascii_case("content-length") {
                return value.trim().parse::().ok();
            }
            None
        })
    }
}