use axum::{ Json, extract::{Extension, State}, http::StatusCode, response::Response, }; use platform_llm::{LlmError, LlmMessage, LlmMessageRole, LlmTextRequest}; use serde_json::Value; use shared_contracts::llm::{ LlmChatCompletionRequest, LlmChatCompletionResponse, LlmChatMessagePayload, LlmChatMessageRole, }; use crate::{ api_response::json_success_body, auth::AuthenticatedAccessToken, http_error::AppError, request_context::RequestContext, state::AppState, }; pub async fn proxy_llm_chat_completions( State(state): State, Extension(request_context): Extension, Extension(_authenticated): Extension, Json(payload): Json, ) -> Result, Response> { if payload.stream { return Err(llm_error_response( &request_context, AppError::from_status(StatusCode::NOT_IMPLEMENTED) .with_message("Rust `api-server` 首版暂不支持流式 LLM 代理"), )); } let llm_client = state.llm_client().ok_or_else(|| { llm_error_response( &request_context, AppError::from_status(StatusCode::SERVICE_UNAVAILABLE) .with_message("服务端尚未配置可用的 LLM API Key"), ) })?; let request = LlmTextRequest { model: payload.model, messages: payload .messages .into_iter() .map(map_chat_message) .collect::>(), max_tokens: None, }; let response = llm_client .request_text(request) .await .map_err(|error| llm_error_response(&request_context, map_llm_error(error)))?; Ok(json_success_body( Some(&request_context), LlmChatCompletionResponse { id: response.response_id, model: response.model, content: response.content, finish_reason: response.finish_reason, }, )) } fn map_chat_message(message: LlmChatMessagePayload) -> LlmMessage { let role = match message.role { LlmChatMessageRole::System => LlmMessageRole::System, LlmChatMessageRole::User => LlmMessageRole::User, LlmChatMessageRole::Assistant => LlmMessageRole::Assistant, }; LlmMessage::new(role, message.content) } fn map_llm_error(error: LlmError) -> AppError { match error { LlmError::InvalidRequest(message) => { AppError::from_status(StatusCode::BAD_REQUEST).with_message(message) } LlmError::InvalidConfig(message) => { AppError::from_status(StatusCode::SERVICE_UNAVAILABLE).with_message(message) } LlmError::Upstream { status_code: 429, message, } => AppError::from_status(StatusCode::TOO_MANY_REQUESTS).with_message(message), LlmError::Upstream { message, .. } => { AppError::from_status(StatusCode::BAD_GATEWAY).with_message(message) } LlmError::Timeout { attempts } => AppError::from_status(StatusCode::BAD_GATEWAY) .with_message(format!("LLM 请求超时,累计尝试 {attempts} 次")), LlmError::Connectivity { attempts, message } => { AppError::from_status(StatusCode::BAD_GATEWAY) .with_message(format!("LLM 连接失败,累计尝试 {attempts} 次:{message}")) } LlmError::StreamUnavailable => { AppError::from_status(StatusCode::BAD_GATEWAY).with_message("LLM 流式响应体不可用") } LlmError::EmptyResponse => { AppError::from_status(StatusCode::BAD_GATEWAY).with_message("LLM 返回内容为空") } LlmError::Transport(message) | LlmError::Deserialize(message) => { AppError::from_status(StatusCode::BAD_GATEWAY).with_message(message) } } } fn llm_error_response(request_context: &RequestContext, error: AppError) -> Response { error.into_response_with_context(Some(request_context)) } #[cfg(test)] mod tests { use std::{ io::{Read, Write}, net::TcpListener, thread, time::Duration as StdDuration, }; use axum::{ body::Body, http::{Request, StatusCode}, }; use http_body_util::BodyExt; use platform_auth::{ AccessTokenClaims, AccessTokenClaimsInput, AuthProvider, BindingStatus, sign_access_token, }; use serde_json::{Value, json}; use time::OffsetDateTime; use tower::ServiceExt; use crate::{app::build_router, config::AppConfig, state::AppState}; struct MockResponse { status_line: &'static str, content_type: &'static str, body: String, } #[tokio::test] async fn llm_chat_completions_returns_non_stream_text_payload() { let server_url = spawn_mock_server(vec![MockResponse { status_line: "200 OK", content_type: "application/json; charset=utf-8", body: r#"{"id":"resp_api_server_01","model":"ark-router-test","choices":[{"message":{"content":"代理成功"},"finish_reason":"stop"}]}"#.to_string(), }]); let state = seed_authenticated_state(AppConfig { llm_base_url: server_url, llm_api_key: Some("test-key".to_string()), ..AppConfig::default() }) .await; let token = issue_access_token(&state); let app = build_router(state); let response = app .oneshot( Request::builder() .method("POST") .uri("/api/llm/chat/completions") .header("authorization", format!("Bearer {token}")) .header("content-type", "application/json") .header("x-genarrative-response-envelope", "v1") .body(Body::from( json!({ "messages": [ { "role": "system", "content": "系统" }, { "role": "user", "content": "用户" } ] }) .to_string(), )) .expect("request should build"), ) .await .expect("request should succeed"); assert_eq!(response.status(), StatusCode::OK); let body = response .into_body() .collect() .await .expect("body should collect") .to_bytes(); let payload: Value = serde_json::from_slice(&body).expect("response body should be valid json"); assert_eq!(payload["ok"], Value::Bool(true)); assert_eq!( payload["data"]["id"], Value::String("resp_api_server_01".to_string()) ); assert_eq!( payload["data"]["model"], Value::String("ark-router-test".to_string()) ); assert_eq!( payload["data"]["content"], Value::String("代理成功".to_string()) ); assert_eq!( payload["data"]["finishReason"], Value::String("stop".to_string()) ); } #[tokio::test] async fn llm_chat_completions_rejects_stream_mode() { let state = seed_authenticated_state(AppConfig::default()).await; let token = issue_access_token(&state); let app = build_router(state); let response = app .oneshot( Request::builder() .method("POST") .uri("/api/llm/chat/completions") .header("authorization", format!("Bearer {token}")) .header("content-type", "application/json") .header("x-genarrative-response-envelope", "v1") .body(Body::from( json!({ "stream": true, "messages": [ { "role": "user", "content": "用户" } ] }) .to_string(), )) .expect("request should build"), ) .await .expect("request should succeed"); assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED); let body = response .into_body() .collect() .await .expect("body should collect") .to_bytes(); let payload: Value = serde_json::from_slice(&body).expect("response body should be valid json"); assert_eq!(payload["ok"], Value::Bool(false)); assert_eq!( payload["error"]["code"], Value::String("NOT_IMPLEMENTED".to_string()) ); } async fn seed_authenticated_state(config: AppConfig) -> AppState { let state = AppState::new(config).expect("state should build"); state .password_entry_service() .execute(module_auth::PasswordEntryInput { username: "llm_proxy_user".to_string(), password: "secret123".to_string(), }) .await .expect("seed login should succeed"); state } fn issue_access_token(state: &AppState) -> String { let claims = AccessTokenClaims::from_input( AccessTokenClaimsInput { user_id: "user_00000001".to_string(), session_id: "sess_llm_proxy".to_string(), provider: AuthProvider::Password, roles: vec!["user".to_string()], token_version: 1, phone_verified: true, binding_status: BindingStatus::Active, display_name: Some("LLM 代理用户".to_string()), }, state.auth_jwt_config(), OffsetDateTime::now_utc(), ) .expect("claims should build"); sign_access_token(&claims, state.auth_jwt_config()).expect("token should sign") } fn spawn_mock_server(responses: Vec) -> String { let listener = TcpListener::bind("127.0.0.1:0").expect("listener should bind"); let address = listener.local_addr().expect("listener should have addr"); thread::spawn(move || { for response in responses { let (mut stream, _) = listener.accept().expect("request should connect"); read_request(&mut stream); write_response(&mut stream, response); } }); format!("http://{address}") } fn read_request(stream: &mut std::net::TcpStream) { stream .set_read_timeout(Some(StdDuration::from_secs(1))) .expect("read timeout should be set"); let mut buffer = Vec::new(); let mut chunk = [0_u8; 1024]; let mut expected_total = None; loop { match stream.read(&mut chunk) { Ok(0) => break, Ok(bytes_read) => { buffer.extend_from_slice(&chunk[..bytes_read]); if expected_total.is_none() && let Some(header_end) = find_header_end(&buffer) { let content_length = read_content_length(&buffer[..header_end]).unwrap_or(0); expected_total = Some(header_end + content_length); } if let Some(total_bytes) = expected_total && buffer.len() >= total_bytes { break; } } Err(error) if error.kind() == std::io::ErrorKind::WouldBlock || error.kind() == std::io::ErrorKind::TimedOut => { break; } Err(error) => panic!("mock server failed to read request: {error}"), } } } fn write_response(stream: &mut std::net::TcpStream, response: MockResponse) { let body = response.body; let raw_response = format!( "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", response.status_line, response.content_type, body.len(), body ); stream .write_all(raw_response.as_bytes()) .expect("mock response should be written"); stream.flush().expect("mock response should flush"); } fn find_header_end(buffer: &[u8]) -> Option { buffer .windows(4) .position(|window| window == b"\r\n\r\n") .map(|index| index + 4) } fn read_content_length(headers: &[u8]) -> Option { let text = String::from_utf8_lossy(headers); text.lines().find_map(|line| { let (name, value) = line.split_once(':')?; if name.eq_ignore_ascii_case("content-length") { return value.trim().parse::().ok(); } None }) } }